JAX Metal: Random Number Generation Performance Issue on M1 Max

JAX on the Metal backend generates random numbers about 55x slower than JAX on an NVIDIA GPU (CUDA) for the same workload. At this speed, Monte Carlo simulations and other sampling-heavy scientific computing are impractical on Apple Silicon.

Performance Comparison

  • NVIDIA GPU: 0.475s for 12.6M random elements
  • M1 Max Metal: 26.3s for the same workload
  • Performance gap: 55x slower
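
For reference, the gap and the implied throughput can be recomputed from the two timings above. This is only arithmetic on the reported numbers; the variable names are mine and nothing here is a new measurement.

```python
# Recompute the slowdown and throughput from the timings reported above.
n_elements = 50000 * 252   # shape used in the reproduction code
cuda_seconds = 0.475       # reported NVIDIA GPU time
metal_seconds = 26.3       # reported M1 Max Metal time

slowdown = metal_seconds / cuda_seconds
cuda_throughput = n_elements / cuda_seconds    # elements per second
metal_throughput = n_elements / metal_seconds  # elements per second

print(f"Elements: {n_elements:,}")
print(f"Slowdown: {slowdown:.1f}x")
print(f"CUDA:  {cuda_throughput / 1e6:.2f}M elements/s")
print(f"Metal: {metal_throughput / 1e6:.2f}M elements/s")
```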

Environment

  • Apple M1 Max, 64GB RAM, macOS Sequoia Version 15.6.1
  • JAX 0.4.34, jax-metal latest
  • Backend: Metal

Reproduction Code

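import time
import jax
from jax import random

# Confirm which backend is actually in use
print(f"Backend: {jax.default_backend()}")

key = random.PRNGKey(42)
start_time = time.perf_counter()
# 50,000 x 252 = 12.6M normal samples
random_array = random.normal(key, (50000, 252))
# JAX dispatches asynchronously; wait for the result before stopping the clock
random_array.block_until_ready()
duration = time.perf_counter() - start_time
# Note: this first call also includes one-time compilation
print(f"Duration: {duration:.3f}s")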
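
A single timed call like the one above mixes one-time compilation with the actual generation cost, so it may help to report both separately. The following is a rough sketch of how that could be done, not a measurement I have run on Metal; benchmark_normal and its parameters are hypothetical names chosen for illustration.

```python
import time
import jax
from jax import random


def benchmark_normal(shape=(50000, 252), repeats=5, seed=42):
    """Return (first_call_seconds, best_steady_state_seconds).

    Hypothetical helper: the first call includes compilation, later
    calls reuse the compiled function.
    """
    key = random.PRNGKey(seed)
    # The shape is captured by the closure, so it is static under jit.
    sample = jax.jit(lambda k: random.normal(k, shape))

    # First call: compilation plus execution.
    start = time.perf_counter()
    sample(key).block_until_ready()
    first_call = time.perf_counter() - start

    # Later calls: execution only, each with a fresh key.
    steady = []
    for _ in range(repeats):
        key, subkey = random.split(key)
        subkey.block_until_ready()  # keep key splitting out of the timing
        start = time.perf_counter()
        sample(subkey).block_until_ready()
        steady.append(time.perf_counter() - start)

    return first_call, min(steady)


first_call, steady = benchmark_normal()
print(f"Backend: {jax.default_backend()}")
print(f"First call (incl. compile): {first_call:.3f}s")
print(f"Best steady-state call:     {steady:.3f}s")
```

If the steady-state number is still in the tens of seconds on Metal, the slowdown is in the generation itself rather than in compilation.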