Same problem here, but when I changed the instructions to use
git clone https://github.com/google/jax.git --branch jaxlib-v0.4.11 --single-branch
and
python -m pip install jax==0.4.11
it now seems to recognize the GPU:
>>> from jax.lib import xla_bridge
>>> print(xla_bridge.get_backend().platform)
Metal device set to: Apple M2 Max
systemMemory: 96.00 GB
maxCacheSize: 36.00 GB
METAL
>>> import jax
>>> jax.devices()
[MetalDevice(id=0, process_index=0)]
>>> jax.devices()[0].platform
'METAL'
>>> jax.devices()[0].device_kind
'Metal'
>>> jax.devices()[0].client.platform
'METAL'
>>> jax.devices()[0].client.runtime_type
'tfrt'
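As a side note, jax.lib.xla_bridge is not the documented entry point for these checks. The public jax.default_backend() and jax.devices() report the same information. Below is a minimal sketch, assuming a working JAX install (CPU, GPU, or the Metal plugin). The helper names backend_summary and smoke_test are hypothetical. The small array is only meant to test whether the backend can run anything at all before trying something as large as 10000 x 10000.

```python
import jax
import jax.numpy as jnp


def backend_summary():
    """Return the default backend's platform name and the visible devices."""
    # jax.default_backend() is the public counterpart of
    # xla_bridge.get_backend().platform used in the session above.
    return jax.default_backend(), jax.devices()


def smoke_test(n=4):
    """Hypothetical helper: allocate a small n x n array and reduce it.

    A tiny allocation is a cheap way to check whether the backend can
    compile and execute a program before scaling up the array size.
    """
    x = jnp.ones((n, n))
    # float() copies the 0-d result to the host, which waits for the computation.
    return float(x.sum())


if __name__ == "__main__":
    backend, devices = backend_summary()
    print("default backend:", backend)
    for d in devices:
        print(" ", d.platform, d.device_kind, d.id)
    print("sum of ones:", smoke_test())
```

If the small case already fails with the same XlaRuntimeError, the array size is not the issue.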
But now, x = jnp.ones((10000, 10000)) fails with this error:
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: -:0:0: error: bytecode version 5 is newer than the current version 1
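The message appears to come from MLIR's bytecode reader. The program handed to the backend, presumably the Metal plugin, was serialized in bytecode version 5, but the reader on the receiving side only supports up to version 1. That suggests jax/jaxlib and jax-metal were built against different MLIR/StableHLO revisions. For that reason, it is worth recording exactly which versions are installed when reporting this. Below is a minimal stdlib-only sketch. The helper name installed_versions is hypothetical, and "jax-metal" is assumed to be the plugin's pip distribution name.

```python
from importlib.metadata import PackageNotFoundError, version

# Distributions involved in this setup; "jax-metal" is assumed to be the
# pip distribution name of Apple's Metal plugin.
PACKAGES = ("jax", "jaxlib", "jax-metal")


def installed_versions(names=PACKAGES):
    """Hypothetical helper: map each distribution name to its installed version, or None."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # Not installed in the current environment.
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:10s} {ver or 'not installed'}")
```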