Gather Operations

I am trying to implement a gather operation in MPSGraph that behaves like TensorFlow's tf.gather. My goal is to select multiple columns/rows along a chosen axis. However, I get errors, and the result does not match TensorFlow's. I also don't understand the difference between the three kinds of gather operations (gather, gatherAlongAxis and gatherND), because the documentation is so sparse.

TF code:

import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3]

print(tf.gather(tensor, columns, axis=1).numpy())
# [[2. 4.]
#  [6. 8.]]
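The TF result follows tf.gather's shape rule for batch_dims=0: output shape = params.shape[:axis] + indices.shape + params.shape[axis+1:]. The minimal sketch below reproduces the result with NumPy's np.take, assuming it matches tf.gather semantics (which holds for this case). It also shows that 2-D indices add a dimension. The helper gather_shape is hypothetical and exists only to spell out the rule.

```python
import numpy as np

def gather_shape(params_shape, indices_shape, axis):
    """Hypothetical helper: output shape of tf.gather(params, indices, axis=axis), batch_dims=0."""
    return params_shape[:axis] + indices_shape + params_shape[axis + 1:]

tensor = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32)
columns = np.array([1, 3])

# np.take with axis=1 picks the listed columns for every row,
# like tf.gather(tensor, columns, axis=1).
out = np.take(tensor, columns, axis=1)
print(out.tolist())                                       # [[2.0, 4.0], [6.0, 8.0]]
print(gather_shape(tensor.shape, columns.shape, axis=1))  # (2, 2)

# Rank-2 indices are inserted in place of the gathered axis: (2,) + (2, 2) + ().
pairs = np.array([[1, 3], [0, 2]])
print(np.take(tensor, pairs, axis=1).shape)               # (2, 2, 2)
```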

Swift (MPSGraph) code:

import Foundation
import MetalPerformanceShadersGraph

let graph = MPSGraph()
let device = MTLCreateSystemDefaultDevice()!

let b = 2
let w = 4

let inputShape = [NSNumber(value: b), NSNumber(value: w)]
let inputTensor = graph.placeholder(shape: inputShape, dataType: .float32, name: nil)

let desc = MPSNDArrayDescriptor(dataType: .float32, shape: inputShape)
let inputNDArray = MPSNDArray(device: device, descriptor: desc)

var inputValues: [Float] = []

for i in 1...b*w {
  inputValues.append(Float(i))
}

print("Input: \(inputValues)")
print("Input Shape: \(inputShape)")

inputNDArray.writeBytes(&inputValues, strideBytes: nil)
let inputs = MPSGraphTensorData(inputNDArray)


var indices: [Int32] = [1, 3]

var remainingIn: [Int32] = [2, 3, 4]


let indicesTensor = graph.constant(Data(bytes: &indices, count: indices.count * 4), shape: [2, 4], dataType: .int32)
print("Indices Shape: \(indicesTensor.shape)")
let gather = graph.gatherAlongAxis(0, updates: inputTensor, indices: indicesTensor, name: nil)
//let gather = graph.gather(withUpdatesTensor: inputTensor, indicesTensor: indicesTensor, axis: 1, batchDimensions: 0, name: nil)
//let gather = graph.gatherND(withUpdatesTensor: inputTensor, indicesTensor: indicesTensor, batchDimensions: 0, name: nil)

print("Gather: \(gather.shape)")

let results = graph.run(feeds: [inputTensor: inputs],
                        targetTensors: [gather],
                        targetOperations: nil)

let outputNDArray = results[gather]!.mpsndarray()
var outputValues: [Float32] = .init(repeating: 0, count: Int(truncating: gather.shape![0]) * Int(truncating: gather.shape![1]))

outputNDArray.readBytes(&outputValues, strideBytes: nil)

print("Output: \(outputValues)")
//Output: [5.0, 0.0]
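The three MPSGraph calls above appear to follow three different indexing rules. The Python sketch below models each rule with NumPy. The mapping is an assumption based on the TensorFlow/NumPy operations the calls resemble; it does not come from MPSGraph's implementation:

- gather(withUpdatesTensor:indicesTensor:axis:batchDimensions:name:) is modelled as tf.gather / np.take.
- gatherAlongAxis(_:updates:indices:name:) is modelled as np.take_along_axis. The indices have the same rank as the input, and the output takes the shape of the indices.
- gatherND(withUpdatesTensor:indicesTensor:batchDimensions:name:) is modelled as tf.gather_nd. The last dimension of the indices holds coordinates.

The helper gather_nd is hypothetical. The last lines compare byte counts for the indices constant as posted: two Int32 values fill 8 bytes, whereas shape [2, 4] describes eight Int32 elements, i.e. 32 bytes.

```python
import numpy as np

# Same 2x4 input as the Swift code (values 1...8, row-major).
x = np.arange(1, 9, dtype=np.float32).reshape(2, 4)

# 1) gather(axis: 1, batchDimensions: 0), modelled as np.take / tf.gather:
#    rank-1 indices select whole columns along the axis.
gathered = np.take(x, np.array([1, 3], dtype=np.int32), axis=1)

# 2) gatherAlongAxis(1, ...), modelled as np.take_along_axis:
#    indices have the same rank as x, and out[i][j] = x[i][idx[i][j]].
along_idx = np.array([[1, 3], [1, 3]], dtype=np.int32)
along = np.take_along_axis(x, along_idx, axis=1)

# 3) gatherND(batchDimensions: 0), modelled as tf.gather_nd:
#    the last axis of indices holds a (row, column) coordinate.
def gather_nd(params, indices):
    # Hypothetical helper: split the coordinate axis into one index array per dimension.
    return params[tuple(np.moveaxis(indices, -1, 0))]

nd_idx = np.array([[[0, 1], [0, 3]], [[1, 1], [1, 3]]], dtype=np.int32)
nd = gather_nd(x, nd_idx)

for name, result in [("gather", gathered), ("gatherAlongAxis", along), ("gatherND", nd)]:
    print(name, result.tolist())  # each prints [[2.0, 4.0], [6.0, 8.0]]

# Bytes in the posted indices constant vs. bytes implied by its declared shape [2, 4].
indices = np.array([1, 3], dtype=np.int32)
print(indices.nbytes)                           # 8
print(int(np.prod((2, 4))) * indices.itemsize)  # 32
```

Under this model, only gather accepts the rank-1 indices [1, 3] directly with axis 1. The other two variants need the indices expanded to a rank-2 (gatherAlongAxis) or rank-3 (gatherND) tensor first.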