Hey,
I am trying to run a pytorch scripted function using this code:
import sys
import torch
from module import function
import tvm
from tvm import relay
from tvm.contrib import graph_runtime
left_elements = int(sys.argv[1])
right_elements = int(sys.argv[2])
left_side = torch.zeros([left_elements], dtype=torch.int32)
right_side = torch.zeros([right_elements], dtype=torch.int32)
for i in range(left_elements):
left_side[i] = i
for i in range(right_elements):
right_side[i] = i
traced_function = torch.jit.trace(function, (left_side,
right_side))
input1_name = 'input0'
input2_name = 'input1'
shape_list = [(input1_name, (left_elements,)),
(input2_name, (right_elements,))]
mod, params = relay.frontend.from_pytorch(traced_function,
shape_list)
target = 'llvm'
target_host = 'llvm'
ctx = tvm.cpu(0)
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, target_host=target_host, params=params)
dtype = 'int32'
m = graph_runtime.GraphModule(lib['default'](ctx))
# TODO: pass multiple inputs to set_input
m.set_input({input1_name : tvm.nd.array(left_side.numpy().astype(dtype)),
input2_name : tvm.nd.array(right_side.numpy().astype(dtype))})
m.run()
tvm_output = m.get_output(0)
print(tvm_output)
My question is in the comment in the code. How can I pass multiple inputs to set_input
? If I do it using the dictionary I get an error.
Thanks!