I'm having the exact same issue, using jax 0.4.11, jaxlib 0.4.10 and jax-metal 0.0.2. The GPU is not recognised on my 2020 M1 MacBook Air.
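To double-check which versions are actually installed and which devices JAX detects, a small script along these lines may help when comparing setups. It is a minimal sketch that assumes a pip-based install. The helper names `package_versions` and `detected_devices` are hypothetical. With jax-metal working, the device list should show a Metal device rather than only CPU devices.

```python
# Minimal diagnostic sketch (assumes a pip-based install; helper names are hypothetical).
# Reports installed versions of jax, jaxlib and jax-metal, and the devices JAX detects.
from importlib import metadata


def package_versions(names=("jax", "jaxlib", "jax-metal")):
    """Return {package: version string, or None if the package is not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def detected_devices():
    """Return (default backend, list of device strings), or (None, []) on failure."""
    try:
        import jax
    except Exception:
        # jax missing, or a broken install / incompatible plugin failing at import time
        return None, []
    try:
        return jax.default_backend(), [str(d) for d in jax.devices()]
    except Exception:
        # Backend initialisation errors (e.g. a plugin that fails to load) end up here
        return None, []


if __name__ == "__main__":
    for pkg, ver in package_versions().items():
        print(f"{pkg}: {ver or 'not installed'}")
    backend, devices = detected_devices()
    print("default backend:", backend)
    print("devices:", devices)
```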
Is there anything else I could try?