I'm getting what looks like the same error during the backward pass for a transformer, but in this case it's less clear how to work around it:
layers.py:108:15: error: failed to legalize operation 'mhlo.dot_general'
attended = jnp.einsum('bsSh,bShd->bshd', weights, v)
^
layers.py:108:15: note: see current operation: %0 = "mhlo.dot_general"(%arg2, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [1, 2], rhs_contracting_dimensions = [1, 3]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<16x256x4x64xf32>, tensor<16x256x256x4xf32>) -> tensor<16x64x256xf32>
edit: in my case the issue seems to be caused by broadcasting; if I manually broadcast with jnp.repeat first, the error goes away:
if weights.shape[3] != v.shape[2]:
    v = jnp.repeat(v, weights.shape[3] // v.shape[2], axis=2)
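A minimal, self-contained sketch of that workaround (using NumPy in place of jax.numpy so it runs anywhere; the shapes are hypothetical, chosen so that v has fewer heads than weights, as in a grouped-query attention setup):

```python
import numpy as np

# Hypothetical shapes: weights has 4 heads, v only 2 (assumed for illustration)
b, s, h, d = 2, 5, 4, 8
weights = np.random.rand(b, s, s, h).astype(np.float32)   # 'bsSh'
v = np.random.rand(b, s, h // 2, d).astype(np.float32)    # 'bShd', fewer heads

# Manual broadcast from the post (np.repeat stands in for jnp.repeat):
# repeat each v head so the head axes match before the einsum
if weights.shape[3] != v.shape[2]:
    v = np.repeat(v, weights.shape[3] // v.shape[2], axis=2)

attended = np.einsum('bsSh,bShd->bshd', weights, v)
assert attended.shape == (b, s, h, d)
```

Note that np.repeat duplicates each head in place ([a, b] -> [a, a, b, b]), which is the usual pairing for grouped-query attention; if the implicit broadcasting instead expected tiling, jnp.tile would be the analogous fix.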