I am trying to implement the gather operation similar to how it is done in TensorFlow and my goal is to capture multiple columns/rows along the desired axis. But I am facing errors and the result isn't the same. I also can't understand the difference between the three different kinds of Gather operations since the documentation is so sparse.
TF Code:-
import tensorflow as tf
tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3]
print(tf.gather(tensor, columns, axis=1).numpy())
# [[2. 4.]
# [6. 8.]]
Swift Code:-
import Foundation
import MetalPerformanceShadersGraph
let graph = MPSGraph()
let device = MTLCreateSystemDefaultDevice()!
let b = 2
let w = 4
let inputShape = [NSNumber(value: b), NSNumber(value: w)]
let inputTensor = graph.placeholder(shape: inputShape, dataType: .float32, name: nil)
let desc = MPSNDArrayDescriptor(dataType: .float32, shape: inputShape)
let inputNDArray = MPSNDArray(device: device, descriptor: desc)
var inputValues: [Float] = []
for i in 1...b*w{
inputValues.append(Float(i))
}
print("Input: \(inputValues)")
print("Input Shape: \(inputShape)")
inputNDArray.writeBytes(&inputValues, strideBytes: nil)
let inputs = MPSGraphTensorData(inputNDArray)
var indices: [Int32] = [1, 3]
var remainingIn: [Int32] = [2, 3, 4]
let indicesTensor = graph.constant(Data(bytes: &indices, count: indices.count * 4), shape: [2, 4], dataType: .int32)
print("Indices Shape: \(indicesTensor.shape)")
let gather = graph.gatherAlongAxis(0, updates: inputTensor, indices: indicesTensor, name: nil)
//let gather = graph.gather(withUpdatesTensor: inputTensor, indicesTensor: indicesTensor, axis: 1, batchDimensions: 0, name: nil)
//let gather = graph.gatherND(withUpdatesTensor: inputTensor, indicesTensor: indicesTensor, batchDimensions: 0, name: nil)
print("Gather: \(gather.shape)")
let results = graph.run(feeds: [inputTensor: inputs],
targetTensors: [gather],
targetOperations: nil)
let outputNDArray = results[gather]!.mpsndarray()
var outputValues: [Float32] = .init(repeating: 0, count: Int(truncating: gather.shape![0]) * Int(truncating: gather.shape![1]))
outputNDArray.readBytes(&outputValues, strideBytes: nil)
print("Output: \(outputValues)")
//Output: [5.0, 0.0]