This is reproducible with jax-metal 0.0.5. The lowering pattern needs to be extended to handle contracting_dimensions of size > 1. As a workaround, could you make the changes below and give it a try:
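import numpy as np
# merge the i and j axes (11 * 12 = 132) so that only one axis is contracted
a = np.random.rand(11, 12, 13, 11, 12).reshape(132, 13, 11, 12)
b = np.random.rand(11, 12, 13).reshape(132, 13)
# subscripts = 'ijklm,ijk->lmk'  # original: contracts over both i and j
subscripts = 'iklm,ik->lmk'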
It produces matching results on my side.
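Why the workaround is equivalent: i and j are both summed over and appear in the same order in both operands, so collapsing them into one axis of size 11 * 12 = 132 gives the same sum while leaving a single contracting dimension. The sketch below checks this on the CPU with plain NumPy; the helper name `einsum_merged` is hypothetical, and on jax-metal one would presumably call `jax.numpy.einsum` instead of `np.einsum`.

```python
import numpy as np

def einsum_merged(a, b):
    """Hypothetical helper: collapse the two contracted axes (i, j) into one
    axis so the einsum has a single contracting dimension."""
    i, j = a.shape[0], a.shape[1]
    # a: (i, j, k, l, m) -> (i*j, k, l, m); b: (i, j, k) -> (i*j, k)
    a2 = a.reshape(i * j, *a.shape[2:])
    b2 = b.reshape(i * j, *b.shape[2:])
    return np.einsum('iklm,ik->lmk', a2, b2)

rng = np.random.default_rng(0)
a = rng.random((11, 12, 13, 11, 12))
b = rng.random((11, 12, 13))

original = np.einsum('ijklm,ijk->lmk', a, b)  # two contracting dims: i and j
merged = einsum_merged(a, b)                  # one contracting dim: i*j
print(original.shape, np.allclose(original, merged))  # prints "(11, 12, 13) True"
```

The reshape only works because i and j are adjacent, leading axes in both operands; in C order, the merged index is i * 12 + j for both `a` and `b`, so the pairing of elements is unchanged.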
Thanks for reporting this. Several advanced-indexing bugs involving GatherOp and ScatterOp conversion have been fixed at tip, so the example in the post should now work. The fixes will be included in the next release of jax-metal.