Recent JAX versions fail on Metal
Hi, I'm not sure whether this is the appropriate forum for this topic; I followed a link from the JAX Metal plugin page (https://developer.apple.com/metal/jax/).

I'm writing a Python app with JAX, and recent JAX versions (e.g. v0.8.2) fail on Metal. I have to downgrade JAX substantially to make it work:

```
pip install jax==0.4.35 jaxlib==0.4.35 jax-metal==0.1.1
```

Can we get an updated release of jax-metal that fixes this issue?

Here is the error I get with JAX v0.8.2:

```
WARNING:2025-12-26 09:55:28,117:jax._src.xla_bridge:881: Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1766771728.118004 207582 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M3 Max
systemMemory: 36.00 GB
maxCacheSize: 13.50 GB
I0000 00:00:1766771728.129886 207582 service.cc:145] XLA service 0x600001fad300 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1766771728.129893 207582 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1766771728.130856 207582 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1766771728.130864 207582 mps_client.cc:384] XLA backend will use up to 28990554112 bytes on device 0 for SimpleAllocator.
Traceback (most recent call last):
  File "<string>", line 1, in <module>
    import jax; print(jax.numpy.arange(10))
                      ~~~~~~~~~~~~~~~~^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py", line 5951, in arange
    return _arange(start, stop=stop, step=step, dtype=dtype, out_sharding=sharding)
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py", line 6012, in _arange
    return lax.broadcasted_iota(dtype, (size,), 0, out_sharding=out_sharding)
           ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 3415, in broadcasted_iota
    return iota_p.bind(dtype=dtype, shape=shape,
           ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
                       dimension=dimension, sharding=out_sharding)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/core.py", line 633, in bind
    return self._true_bind(*args, **params)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/core.py", line 649, in _true_bind
    return self.bind_with_trace(prev_trace, args, params)
           ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/core.py", line 661, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/core.py", line 1210, in process_primitive
    return primitive.impl(*args, **params)
           ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/dispatch.py", line 91, in apply_primitive
    outs = fun(*args)
jax.errors.JaxRuntimeError: UNKNOWN: -:0:0: error: unknown attribute code: 22
-:0:0: note: in bytecode version 6 produced by: StableHLO_v1.13.0
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
I0000 00:00:1766771728.149951 207582 mps_client.h:209] MetalClient destroyed.
```
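For anyone trying to reproduce this, here is a minimal diagnostic sketch (a hypothetical helper of my own, not part of JAX or jax-metal). It records the installed versions of `jax`, `jaxlib` and `jax-metal`, then reruns the failing one-liner in a fresh interpreter twice: once with JAX's default backend selection and once with `JAX_PLATFORMS=cpu`. It also sets `JAX_TRACEBACK_FILTERING=off`, as the error message suggests. The function and constant names are illustrative assumptions. The error text appears to point at a StableHLO bytecode mismatch between JAX and the plugin, so if the CPU run succeeds while the default run fails, that would suggest the plugin rather than JAX itself is at fault.

```python
"""Hypothetical diagnostic for the jax-metal failure (not an official tool).

Prints the installed jax / jaxlib / jax-metal versions, then runs the failing
one-liner in a child interpreter with the default backend and with the CPU
backend forced via JAX_PLATFORMS=cpu.
"""
import os
import subprocess
import sys
from importlib.metadata import PackageNotFoundError, version

PACKAGES = ("jax", "jaxlib", "jax-metal")
# The combination reported to work in this post.
KNOWN_GOOD = {"jax": "0.4.35", "jaxlib": "0.4.35", "jax-metal": "0.1.1"}
# The same one-liner that produced the traceback.
REPRO = "import jax; print(jax.numpy.arange(10))"


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def run_repro(platforms=None, timeout=120):
    """Run REPRO in a child interpreter; return (exit code, stdout + stderr).

    platforms=None leaves JAX's default backend selection alone (Metal when
    the plugin is installed); "cpu" forces the CPU backend via JAX_PLATFORMS.
    """
    env = dict(os.environ, JAX_TRACEBACK_FILTERING="off")
    env.pop("JAX_PLATFORMS", None)  # start from the default selection
    if platforms is not None:
        env["JAX_PLATFORMS"] = platforms
    proc = subprocess.run(
        [sys.executable, "-c", REPRO],
        env=env,
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return proc.returncode, proc.stdout + proc.stderr


def last_error_line(output):
    """Return the last line mentioning 'Error' (e.g. the JaxRuntimeError), or None.

    The plain last line is not used because the Metal client logs
    'MetalClient destroyed.' after the exception.
    """
    lines = [line for line in output.splitlines() if "Error" in line]
    return lines[-1] if lines else None


if __name__ == "__main__":
    versions = installed_versions()
    for name, ver in versions.items():
        print(f"{name}: {ver or 'not installed'}")
    if versions != KNOWN_GOOD:
        print("Note: not the jax/jaxlib/jax-metal combination reported to work.")
    for label, platforms in (("default", None), ("cpu", "cpu")):
        code, output = run_repro(platforms)
        print(f"[{label}] {'OK' if code == 0 else f'FAILED (exit {code})'}")
        if code != 0:
            print("   ", last_error_line(output) or "(no error line found)")
```

Run it with the same virtualenv interpreter that the app uses, so the reported versions match the environment that fails.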
Replies: 0 · Boosts: 0 · Views: 460 · Activity: Dec ’25