Multiple inputs to graph_runtime.GraphModule

Hey,

I am trying to compile and run a traced PyTorch function with TVM using this code:

# Reproduction script: trace `function` (imported from the local `module`) on two
# int32 inputs, compile it with TVM Relay for the LLVM (CPU) target, and run it.
# Usage: python <script> <left_elements> <right_elements>
import sys
import torch
from module import function
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

# Lengths of the two 1-D inputs, taken from the command line.
left_elements = int(sys.argv[1])
right_elements = int(sys.argv[2])

# Fill each input with 0, 1, ..., n-1 (int32); these serve as the example
# inputs for tracing and as the data fed to the compiled module below.
left_side = torch.zeros([left_elements], dtype=torch.int32)
right_side = torch.zeros([right_elements], dtype=torch.int32)
for i in range(left_elements):
    left_side[i] = i

for i in range(right_elements):
    right_side[i] = i

# Trace with concrete example inputs; assumes `function` takes exactly these
# two tensors as positional arguments -- TODO confirm against module.py.
traced_function = torch.jit.trace(function, (left_side,
                                             right_side))

# Relay input names and static shapes, listed in the same order as the traced
# function's arguments (left_side first, then right_side).
input1_name = 'input0'
input2_name = 'input1'
shape_list = [(input1_name, (left_elements,)),
              (input2_name, (right_elements,))]
mod, params = relay.frontend.from_pytorch(traced_function,
                                          shape_list)
target = 'llvm'
target_host = 'llvm'
ctx = tvm.cpu(0)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, target_host=target_host, params=params)
    dtype = 'int32'
    # Instantiate the runtime by calling the built library's 'default' entry
    # with the CPU context, and wrap it in a GraphModule.
    m = graph_runtime.GraphModule(lib['default'](ctx))
    # TODO: pass multiple inputs to set_input
    # NOTE(review): this is the call that raises -- the whole dict is passed as
    # a single positional argument. set_input appears to take one input name
    # (plus its value) positionally and several inputs only as keyword
    # arguments, so m.set_input(**{...}) works; confirm against the TVM
    # GraphModule.set_input docs.
    m.set_input({input1_name : tvm.nd.array(left_side.numpy().astype(dtype)),
                 input2_name : tvm.nd.array(right_side.numpy().astype(dtype))})
    m.run()
    # Fetch output 0 of the compiled graph and print it.
    tvm_output = m.get_output(0)
    print(tvm_output)

My question is in the TODO comment in the code: how can I pass multiple inputs to `set_input`? If I pass them as a dictionary, as shown above, I get an error.

Thanks!

Just put the inputs in a dict and unpack it into `set_input` with `**`. This works for me when I use TVM to optimize a transformer model with two inputs (`src_input` and `tgt_input`).

You can write the code like below:

...
# Compile the Relay module with all opt_level-3 passes enabled, then run it
# on the CPU with both inputs supplied in a single set_input call.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, target_host=target_host, params=params)
    dtype = 'int32'
    # Call the built library's 'default' entry with the CPU context and wrap
    # the result in a GraphModule.
    m = graph_runtime.GraphModule(lib['default'](ctx))
    # A dict passed positionally to set_input is rejected. Unpacking it with
    # ** instead hands each entry over as a keyword argument
    # (input name -> tvm.nd.array), so all inputs are set in one call.
    m.set_input(**{input1_name: tvm.nd.array(left_side.numpy().astype(dtype)),
                   input2_name: tvm.nd.array(right_side.numpy().astype(dtype))})
    m.run()
    # Fetch output 0 of the compiled graph and print it.
    tvm_output = m.get_output(0)
    print(tvm_output)

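Using `**` unpacks the dict into keyword arguments, one per input name. `set_input` accepts those keyword arguments, whereas a dict passed as a single positional argument is not treated as a set of inputs.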