Hi,
I have been following the PyTorch frontend additions to TVM with interest, including the new quantization support that is in progress (PR #4977).
While testing this support, I found that since the refactoring in PR #4944, `from_pytorch` needs input names that exactly match those in the traced graph. Previously it worked with any supplied names, though that was probably by accident.
You now need to call `tvm.relay.frontend.pytorch.get_graph_input_names()` to populate the input shape dictionary correctly. That is reasonable, but if you do not, you get a `KeyError` in `_get_op_inputs()`, because the first graph operator cannot find its input:
```
Traceback (most recent call last):
  File ".../lib/python3.6/site-packages/ipdb/__main__.py", line 169, in main
    pdb._runscript(mainpyfile)
  File "/usr/lib/python3.6/pdb.py", line 1548, in _runscript
    self.run(statement)
  File "/usr/lib/python3.6/bdb.py", line 434, in run
    exec(cmd, globals, locals)
  File "<string>", line 1, in <module>
  File ".../incubator-tvm/tests/python/frontend/pytorch/test_forward.py", line 857, in <module>
    test_quantized_modules()
  File ".../incubator-tvm/tests/python/frontend/pytorch/qnn_test.py", line 280, in test_quantized_modules
    runtime = get_tvm_runtime(script_module, input_name, ishape)
  File ".../incubator-tvm/tests/python/frontend/pytorch/qnn_test.py", line 43, in get_tvm_runtime
    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
  File ".../lib/python3.6/site-packages/tvm-0.7.dev1-py3.6-linux-x86_64.egg/tvm/relay/frontend/pytorch.py", line 1171, in from_pytorch
    output_index_map, ret_name)
  File ".../lib/python3.6/site-packages/tvm-0.7.dev1-py3.6-linux-x86_64.egg/tvm/relay/frontend/pytorch.py", line 1079, in parse_operators
    inputs = _get_op_inputs(op_node, outputs, output_index_map)
  File ".../lib/python3.6/site-packages/tvm-0.7.dev1-py3.6-linux-x86_64.egg/tvm/relay/frontend/pytorch.py", line 880, in _get_op_inputs
    for name in _get_input_names(op_node)]
  File ".../lib/python3.6/site-packages/tvm-0.7.dev1-py3.6-linux-x86_64.egg/tvm/relay/frontend/pytorch.py", line 880, in <listcomp>
    for name in _get_input_names(op_node)]
KeyError: 'X'
```
It would be good if `from_pytorch` could check that the supplied input names match the graph and error out earlier.
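A minimal sketch of such an early check, written in plain Python so it runs without TVM. `validate_input_names` is a hypothetical helper, not part of the TVM API, and it assumes the graph-side names are the list that `get_graph_input_names()` returns (graph inputs excluding `self`).

```python
def validate_input_names(graph_input_names, input_shapes):
    """Hypothetical early check for from_pytorch() (not a TVM API).

    graph_input_names: input names of the traced graph (excluding `self`),
        assumed to come from get_graph_input_names().
    input_shapes: the user-supplied {name: shape} dictionary.
    Raises ValueError before any operator conversion starts.
    """
    supplied = set(input_shapes)
    missing = [name for name in graph_input_names if name not in supplied]
    unexpected = sorted(supplied - set(graph_input_names))
    if missing or unexpected:
        raise ValueError(
            "input_shapes keys do not match the traced graph inputs: "
            f"missing {missing}, unexpected {unexpected}; "
            f"expected {list(graph_input_names)}"
        )


# Matching names pass silently.
validate_input_names(["X"], {"X": (1, 3, 224, 224)})

# A mismatched name fails up front with a descriptive message,
# instead of a bare KeyError deep inside _get_op_inputs().
try:
    validate_input_names(["X"], {"input": (1, 3, 224, 224)})
except ValueError as err:
    print(err)
```

Calling a check like this at the top of `from_pytorch`, before `parse_inputs()` runs, would turn the bare `KeyError: 'X'` into an error that lists the expected input names.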
Alternatively, `parse_inputs()` already has code that could look up the graph variable name for each `input_shapes` name. Because the input is a dictionary, its ordering is not guaranteed to match the graph inputs, but for my single input I fixed it by keying the `input_vars` dictionary with `ir_input.debugName()` instead of `input_name`:
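```python
from tvm.relay import expr as _expr  # the alias pytorch.py uses


def parse_inputs(graph_inputs, input_shapes):
    """ Return Relay vars from torch input vars """
    ir_inputs = list(graph_inputs)
    input_vars = {}
    # ir_inputs[0] is `self`; pair the remaining graph inputs with the
    # user-supplied names in input_shapes, in dictionary order.
    for input_name, ir_input in zip(input_shapes, ir_inputs[1:]):
        # Key by the traced graph's name so _get_op_inputs() can find it,
        # while the Relay var keeps the user-supplied name.
        input_vars[ir_input.debugName()] = _expr.var(input_name,
                                                     shape=input_shapes[input_name])
    return input_vars
```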
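One way to sidestep the ordering caveat, sketched here as a hypothetical standalone helper rather than a patch to TVM: accept an ordered list of `(name, shape)` pairs and match them to the graph inputs by position. `FakeGraphInput` and `map_inputs_by_position` are illustrative names only. The stand-in class models just the `debugName()` method of a TorchScript graph input, and in the real frontend the dictionary values would be `_expr.var(...)` calls rather than tuples.

```python
class FakeGraphInput:
    """Hypothetical stand-in for a TorchScript graph input value.
    Only debugName() is modelled."""

    def __init__(self, name):
        self._name = name

    def debugName(self):
        return self._name


def map_inputs_by_position(graph_inputs, named_shapes):
    """Pair user-supplied (name, shape) entries with graph inputs by position.

    graph_inputs: graph input values, including the leading `self` input.
    named_shapes: ordered list of (user_name, shape) tuples, so the pairing
        does not depend on dictionary ordering.
    Returns {graph_debug_name: (user_name, shape)}.
    """
    ir_inputs = list(graph_inputs)[1:]  # drop `self`, as parse_inputs() does
    if len(ir_inputs) != len(named_shapes):
        raise ValueError(
            f"graph has {len(ir_inputs)} inputs but {len(named_shapes)} were supplied"
        )
    return {
        ir_input.debugName(): (user_name, shape)
        for ir_input, (user_name, shape) in zip(ir_inputs, named_shapes)
    }


graph_inputs = [FakeGraphInput("self.1"), FakeGraphInput("X")]
print(map_inputs_by_position(graph_inputs, [("input", (1, 3, 224, 224))]))
# {'X': ('input', (1, 3, 224, 224))}
```

Using a list makes the positional pairing explicit, and the length check reports a count mismatch up front instead of letting `zip()` silently drop the extra entries.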
Anyway, great stuff.
Jeremy.