[Frontend][ONNX][TopK] I've added a TopK op


#1

Hi,
I've added a converter for the ONNX TopK op (opset 10). Could someone help review it?

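class TopK(OnnxOpConverter):
    """Operator converter for ONNX TopK (opset 10)."""

    @classmethod
    def _impl_v10(cls, inputs, attrs, params):
        # TopK-10 takes two inputs: the data X and K, a 1-element int64 tensor.
        assert len(inputs) == 2
        new_attrs = {}
        new_attrs["axis"] = attrs.get("axis", -1)
        # Return both values and indices, largest first.
        new_attrs["ret_type"] = "both"
        new_attrs["is_ascend"] = False
        # ONNX defines the Indices output of TopK as int64.
        new_attrs["dtype"] = "int64"
        # K must be a constant initializer so its value is known at conversion time.
        new_attrs["k"] = int(params[inputs[1].name_hint].asnumpy()[0])
        return _op.topk(inputs[0], **new_attrs)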

#2

Could you add a test case for it?


#3
import numpy as np
import onnx
import tvm
import tvm.testing
# get_tvm_output and ctx_list are helpers available in TVM's ONNX frontend
# tests (tests/python/frontend/onnx/test_forward.py).


def test_forward_topk():
    # TopK-10: K is passed as a second input rather than an attribute.
    node = onnx.helper.make_node(
        'TopK',
        inputs=['x', 'k'],
        outputs=['values', 'indices'],
    )
    X = np.array([
        [0, 1, 2, 3],
        [4, 5, 6, 7],
        [8, 9, 10, 11],
    ], dtype=np.float32)
    K = np.array([3], dtype=np.int64)
    # Expected top-3 values and indices along the last axis, largest first.
    values_ref = np.array([
        [3, 2, 1],
        [7, 6, 5],
        [11, 10, 9],
    ], dtype=np.float32)
    indices_ref = np.array([
        [3, 2, 1],
        [3, 2, 1],
        [3, 2, 1],
    ], dtype=np.int64)

    # Store K as an initializer so the converter can read it from params.
    k_tensor = onnx.helper.make_tensor(name='k', data_type=onnx.TensorProto.INT64,
                                       dims=(1,), vals=K)

    graph = onnx.helper.make_graph(
        [node],
        'TopK_test',
        inputs=[onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, list(X.shape)),
                onnx.helper.make_tensor_value_info('k', onnx.TensorProto.INT64, list(K.shape))],
        outputs=[onnx.helper.make_tensor_value_info('values', onnx.TensorProto.FLOAT, list(values_ref.shape)),
                 onnx.helper.make_tensor_value_info('indices', onnx.TensorProto.INT64, list(indices_ref.shape))],
        initializer=[k_tensor],
    )
    onnx.checker.check_graph(graph)

    model = onnx.helper.make_model(graph, producer_name='TopK_test')
    for target, ctx in ctx_list():
        tvm_out = get_tvm_output(model, X, target, ctx,
                                 [values_ref.shape, indices_ref.shape],
                                 [values_ref.dtype, indices_ref.dtype])
        # Compare each output separately so a mismatch reports which one failed.
        for ref, out in zip([values_ref, indices_ref], tvm_out):
            tvm.testing.assert_allclose(ref, out, rtol=1e-5, atol=1e-5)

#4

However, I found an issue: if I set the attribute axis=0, TopK returns unexpected results, and I don't know why. So for now I always set axis to -1 to avoid the problem.
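To pin down what the axis=0 result should be, it can help to compare TVM's output against an independent reference. Below is a minimal NumPy sketch (the name topk_reference is hypothetical, not part of TVM or ONNX) that returns the k largest values and their int64 indices along any axis, sorted in descending order. It assumes numeric, signed or floating-point input, since it sorts the negated data.

```python
import numpy as np


def topk_reference(x, k, axis=-1):
    """Hypothetical helper: k largest values and their indices along `axis`.

    Results are sorted largest first. Tie-breaking (lower index first, from the
    stable sort) may not match every backend exactly.
    """
    x = np.asarray(x)
    # Sorting the negated data ascending gives the original data descending;
    # a stable sort keeps the lower index first among equal values.
    order = np.argsort(-x, axis=axis, kind="stable")
    # Keep only the first k positions along the chosen axis.
    indices = np.take(order, np.arange(k), axis=axis)
    # Gather the matching values with the same index layout.
    values = np.take_along_axis(x, indices, axis=axis)
    return values, indices.astype(np.int64)
```

With the X from the test above, topk_reference(X, 3, axis=-1) reproduces values_ref and indices_ref, and topk_reference(X, 2, axis=0) gives the two largest entries of each column. That can serve as the expected output for an axis=0 variant of the test (onnx.helper.make_node accepts axis=0 as an extra keyword attribute).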


#5

Why did you implement _impl_v10 rather than _impl_v1?
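The version suffix matters because converters are chosen by the model's opset. Here is a minimal sketch, assuming the rule "use the largest _impl_vN whose N does not exceed the model opset". The class names are hypothetical stand-ins, and TVM's actual OnnxOpConverter.get_converter may differ in edge cases.

```python
class ConverterSketch:
    """Hypothetical stand-in for an opset-dispatched ONNX op converter."""

    @classmethod
    def get_converter(cls, opset):
        # Collect N from every _impl_vN method defined on the class.
        versions = sorted(
            int(name[len("_impl_v"):]) for name in dir(cls) if name.startswith("_impl_v")
        )
        # Keep only implementations that are not newer than the model's opset.
        usable = [v for v in versions if v <= opset]
        if not usable:
            raise NotImplementedError(
                f"opset {opset} of {cls.__name__} not implemented (have {versions})"
            )
        return getattr(cls, f"_impl_v{usable[-1]}")


class TopKOnlyV10(ConverterSketch):
    @classmethod
    def _impl_v10(cls, inputs, attrs, params):
        return "v10"


class TopKV1AndV10(ConverterSketch):
    @classmethod
    def _impl_v1(cls, inputs, attrs, params):
        return "v1"

    @classmethod
    def _impl_v10(cls, inputs, attrs, params):
        return "v10"
```

Under this rule, a class with only _impl_v10 covers opset 10 and 11 models but has nothing suitable for opset 1 to 9 models, where k is an attribute and TopK has a single input.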


#6

Sorry, that was a mistake.
The ONNX model file I have uses opset 10, so I started by adding TopK for v10.
But I think v1 and v10 are similar.
I also noticed that the latest ONNX docs already include v11 :joy:
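The versions are indeed close. Per the ONNX operator docs, TopK-1 takes k as a required attribute with a single input, TopK-10 moves k to a second input K (a 1-element int64 tensor), and TopK-11 adds the largest and sorted attributes (both defaulting to 1). The hypothetical helper below sketches how an _impl_v1, _impl_v10 or _impl_v11 would each gather their settings. Plain Python values stand in for the converter's attrs and the constant K.

```python
def topk_attrs_for_opset(opset, attrs, k_input=None):
    """Hypothetical helper: collect TopK settings as each ONNX opset defines them.

    attrs   -- dict of node attributes (stand-in for the converter's attrs)
    k_input -- values of the K input for opset >= 10 (stand-in for params lookup)
    """
    axis = attrs.get("axis", -1)
    if opset < 10:
        # TopK-1: k is a required integer attribute.
        k = int(attrs["k"])
    else:
        # TopK-10/11: k is the single element of the second input K.
        if k_input is None:
            raise ValueError("TopK-10+ expects K as a second input")
        k = int(k_input[0])
    # TopK-11 adds `largest` and `sorted`; earlier versions always return the
    # largest elements in sorted order.
    largest = bool(attrs.get("largest", 1)) if opset >= 11 else True
    is_sorted = bool(attrs.get("sorted", 1)) if opset >= 11 else True
    return {"k": k, "axis": axis, "largest": largest, "sorted": is_sorted}
```

For example, topk_attrs_for_opset(1, {"k": 3}) and topk_attrs_for_opset(10, {}, [3]) produce the same settings, which is why a v1 converter could mostly reuse the v10 body with a different source for k.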


#7

Maybe you could prepare a PR?