On Apple Silicon, the JAX Metal backend generates random numbers about 55x slower than an NVIDIA GPU running CUDA on the same workload. At this speed, Monte Carlo simulations and other sampling-heavy scientific computing are impractical on Apple Silicon.
Performance Comparison
NVIDIA GPU: 0.475s for 12.6M random elements
M1 Max Metal: 26.3s for same workload
Performance gap: about 55x slower (26.3s / 0.475s)
Environment
Apple M1 Max, 64GB RAM, macOS Sequoia Version 15.6.1
JAX 0.4.34, jax-metal latest
Backend: Metal
Reproduction Code
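import time
import jax
from jax import random

# Show which devices JAX is using (expected: the Metal device)
print(jax.devices())
key = random.PRNGKey(42)
start_time = time.time()
# 50000 x 252 = 12.6M normally distributed samples
random_array = random.normal(key, (50000, 252))
# JAX dispatches work asynchronously, so wait for the result before stopping the timer
random_array.block_until_ready()
duration = time.time() - start_time
print(f"Duration: {duration:.3f}s")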
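A single timed call mixes one-time tracing and compilation cost with the actual generation time, so it may help to report the two separately. The following is a minimal sketch, not part of the original measurements: `time_normal` is a hypothetical helper name, and the choice of three repeats and of wrapping the call in `jax.jit` are assumptions made for illustration.

```python
import time

import jax
from jax import random


def time_normal(shape, repeats=3, seed=42):
    """Return (first_call_seconds, best_steady_state_seconds) for random.normal."""
    key = random.PRNGKey(seed)
    # Hypothetical wrapper: jit the sampler so later calls reuse one executable.
    sample = jax.jit(lambda k: random.normal(k, shape))

    # The first call includes tracing and compilation.
    start = time.perf_counter()
    sample(key).block_until_ready()
    first_call = time.perf_counter() - start

    # Later calls reuse the compiled executable; use a fresh key each time.
    steady = []
    for _ in range(repeats):
        key, subkey = random.split(key)
        start = time.perf_counter()
        sample(subkey).block_until_ready()
        steady.append(time.perf_counter() - start)
    return first_call, min(steady)


if __name__ == "__main__":
    first, best = time_normal((50000, 252))
    print(f"Backend: {jax.default_backend()}")
    print(f"First call (with compile): {first:.3f}s")
    print(f"Best steady-state call:    {best:.3f}s")
```

If the steady-state time on Metal is still far above the CUDA figure, the gap is likely in the generation itself rather than in compilation.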