Post

Replies

Boosts

Views

Activity

JAX Metal: Random Number Generation Performance Issue on M1 Max
JAX Metal shows 55x slower random number generation compared to NVIDIA CUDA on equivalent workloads. This makes Monte Carlo simulations and scientific computing impractical on Apple Silicon. Performance Comparison NVIDIA GPU: 0.475s for 12.6M random elements M1 Max Metal: 26.3s for same workload Performance gap: 55x slower Environment Apple M1 Max, 64GB RAM, macOS Sequoia Version 15.6.1 JAX 0.4.34, jax-metal latest Backend: Metal Reproduction Code import time import jax import jax.numpy as jnp from jax import random key = random.PRNGKey(42) start_time = time.time() random_array = random.normal(key, (50000, 252)) duration = time.time() - start_time print(f"Duration: {duration:.3f}s")
0
0
250
1w