Reply to jax.lax.conv_transpose not correctly implemented
I see the same problem whenever the stride is not (1, 1). As a workaround, I upscale manually and then call conv_transpose with a stride of (1, 1):

import jax.numpy as jnp
from jax import lax

def upscale_nearest_neighbor(x, scale_factor=2):
    # Assuming x has shape (batch, height, width, channels)
    b, h, w, c = x.shape
    x = x.reshape(b, h, 1, w, 1, c)
    # lax.tie_in is deprecated in JAX and not needed here; broadcast directly
    x = jnp.broadcast_to(x, (b, h, scale_factor, w, scale_factor, c))
    return x.reshape(b, h * scale_factor, w * scale_factor, c)

def deconv2d(x, w):
    x_upscaled = upscale_nearest_neighbor(x)
    return lax.conv_transpose(
        x_upscaled, w, strides=(1, 1), padding='SAME',
        dimension_numbers=("NHWC", "HWIO", "NHWC"))
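For anyone who wants to sanity-check the workaround, here is a minimal sketch. It assumes a backend where stride-(1, 1) conv_transpose works, and the array sizes and the all-ones kernel are made up purely for illustration. It compares the reshape/broadcast upscaling against a plain jnp.repeat reference and confirms that the output has doubled height and width. The upscaling function is repeated (without lax.tie_in) so the snippet runs on its own.

```python
import jax.numpy as jnp
from jax import lax


def upscale_nearest_neighbor(x, scale_factor=2):
    # x is assumed to be NHWC: (batch, height, width, channels).
    b, h, w, c = x.shape
    x = x.reshape(b, h, 1, w, 1, c)
    x = jnp.broadcast_to(x, (b, h, scale_factor, w, scale_factor, c))
    return x.reshape(b, h * scale_factor, w * scale_factor, c)


# Hypothetical sizes, chosen only for illustration.
x = jnp.arange(2 * 3 * 4 * 5, dtype=jnp.float32).reshape(2, 3, 4, 5)
kernel = jnp.ones((3, 3, 5, 7), dtype=jnp.float32)  # HWIO: 5 in, 7 out

up = upscale_nearest_neighbor(x)

# Reference nearest-neighbor upscaling: repeat each row, then each column.
reference = jnp.repeat(jnp.repeat(x, 2, axis=1), 2, axis=2)
matches_reference = bool(jnp.array_equal(up, reference))

# Stride-(1, 1) transposed convolution on the upscaled input.
y = lax.conv_transpose(
    up, kernel, strides=(1, 1), padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"))

print(up.shape, matches_reference, y.shape)
```

Note that this is not numerically the same operation as a true stride-2 transposed convolution (which inserts zeros between input pixels rather than repeating them); it only produces the same output resolution.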
Topic: Machine Learning & AI SubTopic: General Tags:
Oct ’23
Reply to Jax-metal - whisper-jax
I get this problem when using Conv2DTranspose. It seems Metal does not support all operations yet. Did you find a fix?
Topic: Machine Learning & AI SubTopic: General Tags:
Sep ’23