def test_forward_topk():
    """Check the ONNX ``TopK`` import against a hand-computed reference.

    Builds a single-node ONNX graph ``(values, indices) = TopK(x, k)`` with
    ``k`` supplied as a constant initializer. It runs the model through TVM
    on every available target and compares both outputs with the expected
    top-3 of each row of ``X``.
    """
    node = onnx.helper.make_node(
        'TopK',
        inputs=['x', 'k'],
        outputs=['values', 'indices'],
    )

    # Each row is strictly ascending, so the top-3 (largest first) are the
    # last three elements in reverse order, at column indices 3, 2, 1.
    X = np.array([
        [0, 1, 2, 3],
        [4, 5, 6, 7],
        [8, 9, 10, 11],
    ], dtype=np.float32)
    K = np.array([3], dtype=np.int64)
    values_ref = np.array([
        [3, 2, 1],
        [7, 6, 5],
        [11, 10, 9],
    ], dtype=np.float32)
    indices_ref = np.array([
        [3, 2, 1],
        [3, 2, 1],
        [3, 2, 1],
    ], dtype=np.int64)

    # 'k' is declared as a graph input but also given an initializer, so its
    # value is baked into the model; only X is fed at run time below.
    k_tensor = onnx.helper.make_tensor(
        name='k',
        data_type=onnx.TensorProto.INT64,
        dims=(1,),
        vals=K.tolist(),
    )
    graph = onnx.helper.make_graph(
        [node],
        'TopK_test',
        inputs=[
            onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, list(X.shape)),
            onnx.helper.make_tensor_value_info('k', onnx.TensorProto.INT64, list(K.shape)),
        ],
        outputs=[
            onnx.helper.make_tensor_value_info('values', onnx.TensorProto.FLOAT, list(values_ref.shape)),
            onnx.helper.make_tensor_value_info('indices', onnx.TensorProto.INT64, list(indices_ref.shape)),
        ],
        initializer=[k_tensor],
    )
    onnx.checker.check_graph(graph)
    # Use the fully-qualified onnx.helper like the rest of this function,
    # rather than relying on a separately imported `helper` name.
    # NOTE(review): TopK with K as an input needs opset >= 10; this relies on
    # make_model's default opset being recent enough — confirm if pinned.
    model = onnx.helper.make_model(graph, producer_name='TopK_test')

    for target, ctx in ctx_list():
        tvm_out = get_tvm_output(
            model, X, target, ctx,
            [values_ref.shape, indices_ref.shape],
            [values_ref.dtype, indices_ref.dtype],
        )
        values_out, indices_out = tvm_out
        # Values are floats: compare with a small tolerance.
        tvm.testing.assert_allclose(values_ref, values_out, rtol=1e-5, atol=1e-5)
        # Indices are integers: they must match exactly, not approximately.
        np.testing.assert_array_equal(indices_ref, indices_out)