jax-metal failing due to incompatibility with jax 0.5.1 or later.

Hello,

I am interested in using jax-metal to train ML models on Apple Silicon. I understand this is experimental.

After installing jax-metal according to https://developer.apple.com/metal/jax/, my Python code fails with the following error:

JaxRuntimeError: UNKNOWN: -:0:0: error: unknown attribute code: 22
-:0:0: note: in bytecode version 6 produced by: StableHLO_v1.12.1

My issue is identical to the one reported here: https://github.com/jax-ml/jax/issues/26968#issuecomment-2733120325. It is fixed by pinning to jax-metal 0.1.1, jax 0.5.0, and jaxlib 0.5.0.
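Before running training code, it can help to confirm that the environment actually matches the known-working combination of jax-metal 0.1.1, jax 0.5.0, and jaxlib 0.5.0. The following is a minimal sketch that assumes the packages were installed with pip under their PyPI names. `installed_versions` and `pin_mismatches` are hypothetical helpers written for this example and are not part of jax or jax-metal.

```python
from importlib.metadata import PackageNotFoundError, version

# Known-working combination from the issue comment.
# Install with, e.g.: pip install jax-metal==0.1.1 jax==0.5.0 jaxlib==0.5.0
PINNED = {"jax-metal": "0.1.1", "jax": "0.5.0", "jaxlib": "0.5.0"}


def installed_versions(packages):
    """Hypothetical helper: map each package to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def pin_mismatches(installed, pinned=PINNED):
    """Hypothetical helper: return {package: (installed, expected)} for every package off its pin."""
    return {
        name: (installed.get(name), expected)
        for name, expected in pinned.items()
        if installed.get(name) != expected
    }


if __name__ == "__main__":
    problems = pin_mismatches(installed_versions(PINNED))
    if problems:
        for name, (have, want) in problems.items():
            print(f"{name}: installed {have}, expected {want}")
    else:
        print("jax, jaxlib and jax-metal match the known-working pins")
```

An empty result from `pin_mismatches` means the environment matches the pins; any jax or jaxlib at 0.5.1 or later shows up as a mismatch.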

Thank you!

Hi, is JAX for Metal / Apple Silicon still supported? It seems the Metal backend is still not officially supported by JAX, according to their repo.
