Gather_nd semantics

Currently, the Relay gather_nd op uses the MXNet semantics, in which each column of indices (i.e. a slice along the first axis) is one index tuple into data. TensorFlow’s gather_nd instead reads each index tuple from a row (a slice along the last axis).
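To make the difference concrete, here is a minimal NumPy sketch of the two conventions as I understand them. The helper names mx_gather_nd and tf_gather_nd are hypothetical, not Relay or TensorFlow APIs. M is the length of each index tuple: MXNet/Relay reads it down the first axis of indices, TensorFlow along the last axis.

```python
import numpy as np


def mx_gather_nd(data, indices):
    """MXNet/Relay convention: indices has shape (M, Y0, ..., YK).

    Each column indices[:, y0, ..., yK] is one index tuple into the first
    M axes of data. Output shape is (Y0, ..., YK) + data.shape[M:].
    """
    indices = np.asarray(indices)
    m = indices.shape[0]
    out = np.empty(indices.shape[1:] + data.shape[m:], dtype=data.dtype)
    for pos in np.ndindex(*indices.shape[1:]):
        # Read the index tuple down the first axis of indices.
        out[pos] = data[tuple(indices[(slice(None),) + pos])]
    return out


def tf_gather_nd(data, indices):
    """TensorFlow convention: indices has shape (Y0, ..., YK, M).

    Each row indices[y0, ..., yK, :] is one index tuple into the first
    M axes of data. Output shape is (Y0, ..., YK) + data.shape[M:].
    """
    indices = np.asarray(indices)
    m = indices.shape[-1]
    out = np.empty(indices.shape[:-1] + data.shape[m:], dtype=data.dtype)
    for pos in np.ndindex(*indices.shape[:-1]):
        # Read the index tuple along the last axis of indices.
        out[pos] = data[tuple(indices[pos])]
    return out


data = np.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
idx = [[0, 0], [1, 1]]
print(tf_gather_nd(data, idx))  # [0 4]: rows (0, 0) and (1, 1)
print(mx_gather_nd(data, idx))  # [1 1]: columns (0, 1) and (0, 1)
```

The same indices pick different elements under the two conventions, which is why a frontend cannot pass TensorFlow indices to the Relay op unchanged.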

Let’s use this topic to discuss whether we should change the current semantics from MXNet gather_nd to TF gather_nd.

cc @kazum @srkreddy1238 @Laurawly

I don’t think we need to change the current semantics. We can easily implement TensorFlow gather_nd with MXNet gather_nd (and vice versa).

Here is some pseudocode:

tf_gather_nd(data, indices) = relay.gather_nd(data, transpose(indices, [N-1, 0, 1, ..., N-2]))

where N is the number of dimensions of indices, so the transpose moves the last (index-tuple) axis to the front.
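The identity can be checked without TVM. In NumPy, the MXNet/Relay convention is plain advanced indexing with tuple(indices) over the first axis, so applying the transpose above to TensorFlow-style indices should reproduce tf.gather_nd. This is a sketch of the NumPy analogue only, not of the Relay op itself; relay_style_gather_nd and tf_via_transpose are hypothetical helper names.

```python
import numpy as np


def relay_style_gather_nd(data, indices):
    # MXNet/Relay convention: axis 0 of indices holds the index tuples,
    # which is exactly NumPy advanced indexing with tuple(indices).
    return data[tuple(np.asarray(indices))]


def tf_via_transpose(data, indices):
    indices = np.asarray(indices)
    n = indices.ndim  # N in the pseudocode
    # axes [N-1, 0, 1, ..., N-2] move the last (index-tuple) axis to the front.
    axes = [n - 1] + list(range(n - 1))
    return relay_style_gather_nd(data, np.transpose(indices, axes))


data = np.arange(6).reshape(2, 3)
print(tf_via_transpose(data, [[0, 0], [1, 1]]))  # [0 4], as tf.gather_nd gives
```

The same permutation can be written as np.moveaxis(indices, -1, 0); converting in the other direction (MXNet indices for a TensorFlow-style op) moves the first axis to the end instead.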

The current TensorFlow frontend doesn’t do this conversion; I’ll send a fix for it.


@kazum This seems incorrect; it can’t pass my test case:

#######################################################################
# GatherNd
# --------------------------
def _gather_nd(in_shape, indices):
    """test operator GatherNd"""
    np_indices = np.asarray(indices, dtype='int32')
    tf.reset_default_graph()
    with tf.Graph().as_default():
        np_data = np.random.uniform(size=in_shape).astype("float32")
        in_data = tf.placeholder(tf.float32, in_shape, name="in_data")
        in_indices = tf.placeholder(tf.float32, np_indices.shape, name="in_indices")
        out = tf.gather_nd(in_data, indices)
        compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'in_indices:0'], out.name)


def test_forward_gather_nd():
    _gather_nd((2, 3), [[0, 0], [1, 1]])
    _gather_nd((2, 3), [[1], [0]])
    _gather_nd((2, 3, 4), [[1]])
    _gather_nd((4, 3, 2), [[0, 1], [1, 0]])
    _gather_nd((2, 4), [[[0, 0]], [[0, 1]]])
    _gather_nd((4, 2), [[[1]], [[0]]])
    _gather_nd((2, 4, 2), [[[1]], [[0]]])
    _gather_nd((2, 2, 3), [[[0, 1], [1, 0]], [[0, 0], [1, 1]]])
    _gather_nd((2, 3, 3), [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]])

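These two lines should be changed so that the indices placeholder is int32 and gather_nd reads from the placeholder in_indices rather than the Python list indices: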

in_indices = tf.placeholder(tf.int32, np_indices.shape, name="in_indices")
out = tf.gather_nd(in_data, in_indices)

With that change, your test passes in my environment.

Ha-ha, great. Can you help me review my pull request for cumsum? Thanks.