Thanks! That helped, and JAX now recognizes the GPU.
Unfortunately, when I try to run the simple Newton's method example from https://jax.quantecon.org/newtons_method.html, it fails with the following error:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<stdin>", line 11, in newton
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:4:0: error: failed to legalize operation 'mhlo.scatter'
<stdin>:4:0: note: called from
<stdin>:4:0: note: see current operation:
%2177 = "mhlo.scatter"(%2052, %2176, %2167) ({
^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>):
"mhlo.return"(%arg7) : (tensor<f32>) -> ()
}) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true} : (tensor<5000x128xf32>, tensor<1xsi32>, tensor<5000xf32>) -> tensor<5000x128xf32>
Running the same code on the CPU works fine.
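To narrow this down, a hypothetical minimal reproducer might isolate the failing op. It assumes the scatter comes from an indexed update such as `x.at[:, j].set(...)`, which JAX lowers to a scatter, and copies the shapes from the `mhlo.scatter` line in the traceback (a 5000x128 float32 operand, a single int32 index, and a 5000-element update). The helper names `set_column` and `run_on` are made up for illustration, and it has not been confirmed that the Newton example emits exactly this scatter.

```python
import jax
import jax.numpy as jnp

# Shapes mirror the failing "mhlo.scatter" in the traceback:
# a 5000x128 float32 operand, one int32 column index, a 5000-element update.
N_ROWS, N_COLS = 5000, 128

@jax.jit
def set_column(x, col, j):
    # Hypothetical helper: x.at[:, j].set(col) is lowered by XLA to a scatter.
    return x.at[:, j].set(col)

def run_on(device):
    # Hypothetical helper: committing the inputs to `device` makes the
    # jitted call execute on that device.
    x = jax.device_put(jnp.zeros((N_ROWS, N_COLS), jnp.float32), device)
    col = jax.device_put(jnp.ones(N_ROWS, jnp.float32), device)
    return set_column(x, col, 3)

# Reference run on the CPU, which works for the full example.
out_cpu = run_on(jax.devices("cpu")[0])
print("CPU OK, column sum:", float(out_cpu[:, 3].sum()))

# Same computation on the default backend (the GPU here). If the scatter
# lowering is the culprit, this is where the error should reappear.
default_device = jax.devices()[0]
try:
    out_default = run_on(default_device)
    print("default backend OK:", default_device.platform)
except Exception as e:  # e.g. jaxlib's XlaRuntimeError
    print("default backend failed:", type(e).__name__, e)
```

If the CPU run succeeds and the GPU run fails with the same "failed to legalize operation 'mhlo.scatter'" message, that would point to the GPU backend's handling of scatter rather than to the Newton code itself.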