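On an M1 Max, random number generation with jax-metal is about 55x slower than with JAX on an NVIDIA CUDA GPU for the same workload (26.3s vs. 0.475s for 12.6M normal samples). At this speed, Monte Carlo simulations and other sampling-heavy scientific computing are impractical on Apple Silicon.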
Performance Comparison
- NVIDIA GPU (CUDA): 0.475s for 12.6M random elements (shape 50000 x 252)
- M1 Max (Metal): 26.3s for the same workload
- Performance gap: about 55x slower (26.3s / 0.475s)
Environment
- Apple M1 Max, 64GB RAM, macOS Sequoia Version 15.6.1
- JAX 0.4.34, jax-metal latest
- Backend: Metal
Reproduction Code
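import time
import jax
from jax import random

key = random.PRNGKey(42)
start_time = time.time()
# JAX dispatches work asynchronously, so block until the array is computed.
# Note: this first call also includes one-time compilation.
random_array = random.normal(key, (50000, 252)).block_until_ready()
duration = time.time() - start_time
print(f"Duration: {duration:.3f}s")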
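A single timed call mixes one-time compilation with the actual sampling cost, and JAX dispatches work asynchronously, so it may help to report the two separately. The following is only a sketch of how that could be measured, not the script used for the numbers above. The helper name `time_normal` and its `repeats` parameter are hypothetical; the JAX calls (`jax.jit`, `random.normal`, `block_until_ready`, `jax.devices`) are standard.

```python
import time

import jax
from jax import random


def time_normal(shape=(50000, 252), seed=42, repeats=3):
    """Return (first_call_seconds, best_steady_state_seconds).

    Hypothetical helper for illustration; the name and parameters
    are not part of the original report.
    """
    key = random.PRNGKey(seed)
    # Jit the sampler so later calls reuse the compiled program.
    sample = jax.jit(lambda k: random.normal(k, shape))

    # First call: includes tracing and compilation.
    start = time.perf_counter()
    sample(key).block_until_ready()
    first_call = time.perf_counter() - start

    # Later calls: compiled code only. Reusing the same key is fine
    # here because only the timing matters, not the samples.
    steady = []
    for _ in range(repeats):
        start = time.perf_counter()
        sample(key).block_until_ready()
        steady.append(time.perf_counter() - start)

    return first_call, min(steady)


if __name__ == "__main__":
    # Confirms which backend is actually in use.
    print("Devices:", jax.devices())
    first_call, steady = time_normal()
    print(f"First call (with compilation): {first_call:.3f}s")
    print(f"Steady state (best of 3):      {steady:.3f}s")
```

If the steady-state time on Metal is still far above the CUDA figure, the gap is in the random number kernels themselves rather than in compilation.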