[PyTorch] wrongly defined graph input names cause KeyError

Hi,

I have been following the PyTorch frontend additions to TVM with interest, including the new quantization support that is in progress (PR #4977).

In my testing of this support, I have found that after the refactoring in PR #4944, from_pytorch now needs input names that exactly match the traced graph (it used to work with any supplied names, though it probably just happened to work).

You now need to call tvm.relay.frontend.pytorch.get_graph_input_names() to populate the input shape dictionary correctly, which is reasonable. However, if you do not, you will get a KeyError in _get_op_inputs() because the first graph operator can’t find its input:

Traceback (most recent call last):
  File ".../lib/python3.6/site-packages/ipdb/__main__.py", line 169, in main
    pdb._runscript(mainpyfile)
  File "/usr/lib/python3.6/pdb.py", line 1548, in _runscript
    self.run(statement)
  File "/usr/lib/python3.6/bdb.py", line 434, in run
    exec(cmd, globals, locals)
  File "<string>", line 1, in <module>
  File ".../incubator-tvm/tests/python/frontend/pytorch/test_forward.py", line 857, in <module>
    test_quantized_modules()
  File ".../incubator-tvm/tests/python/frontend/pytorch/qnn_test.py", line 280, in test_quantized_modules
    runtime = get_tvm_runtime(script_module, input_name, ishape)
  File ".../incubator-tvm/tests/python/frontend/pytorch/qnn_test.py", line 43, in get_tvm_runtime
    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
  File ".../lib/python3.6/site-packages/tvm-0.7.dev1-py3.6-linux-x86_64.egg/tvm/relay/frontend/pytorch.py", line 1171, in from_pytorch
    output_index_map, ret_name)
  File ".../lib/python3.6/site-packages/tvm-0.7.dev1-py3.6-linux-x86_64.egg/tvm/relay/frontend/pytorch.py", line 1079, in parse_operators
    inputs = _get_op_inputs(op_node, outputs, output_index_map)
  File ".../lib/python3.6/site-packages/tvm-0.7.dev1-py3.6-linux-x86_64.egg/tvm/relay/frontend/pytorch.py", line 880, in _get_op_inputs
    for name in _get_input_names(op_node)]
  File ".../lib/python3.6/site-packages/tvm-0.7.dev1-py3.6-linux-x86_64.egg/tvm/relay/frontend/pytorch.py", line 880, in <listcomp>
    for name in _get_input_names(op_node)]
KeyError: 'X'

It would be good if from_pytorch checked that the supplied input names match the graph and raised a clear error earlier.

Alternatively, parse_inputs() already has code that could map each input_shapes name to the corresponding graph variable name. Because the input is a dictionary, its order isn’t guaranteed to match the graph’s inputs, but for my single-input case I fixed it by keying the input_vars dictionary with ir_input.debugName() instead of input_name:

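from tvm.relay import expr as _expr  # _expr refers to tvm.relay.expr

def parse_inputs(graph_inputs, input_shapes):
    """ Return Relay vars from torch input vars """
    ir_inputs = list(graph_inputs)
    input_vars = {}

    # ir_inputs[0] is the module itself ("self"), so skip it and pair the
    # remaining graph inputs with the user-supplied names by position.
    for input_name, ir_input in zip(input_shapes, ir_inputs[1:]):
        # Key by the graph's debugName(), which is what _get_op_inputs()
        # looks up, rather than by the user-supplied input_name.
        input_vars[ir_input.debugName()] = _expr.var(
            input_name, shape=input_shapes[input_name])
    return input_vars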

See original github code.

Anyway, great stuff.

Jeremy.


@jjohnson-arm Thanks for being interested in using the converter. @alexwong and @masahi have been mainly working on it. Are you also interested in sending a fix/improvement and tagging some people for review?

@jjohnson-arm Thank you for your interest in the torch frontend. I seriously want to make it the best frontend ever 🙂

You are right that in the original implementation the keys of the input shape dict could be arbitrary, and that we now require the same input names as in the IR. This is because the original implementation was broken when there was more than one input. We have such an example in our tests: converting roi_align, which needs rois as a second input.

The example also shows how to use a custom convert map, which you might find useful.

To create the input_shapes dict, I recommend zipping the output of get_graph_input_names with a list of shapes in the correct order. Even if you don’t know the names of your inputs, you should know their shapes and order.
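A minimal sketch of that recommendation in plain Python. `make_input_shapes` is a hypothetical helper and the input names below are made up; in practice the names would come from `get_graph_input_names()` on the traced module. The length check is a design choice: `zip` silently drops unmatched items, which would quietly reproduce the missing-input problem.

```python
def make_input_shapes(input_names, shapes):
    """Pair graph input names with shapes supplied in the same order (hypothetical helper)."""
    if len(input_names) != len(shapes):
        # zip() would silently truncate, so fail loudly instead.
        raise ValueError("expected {} shapes, got {}".format(len(input_names), len(shapes)))
    return dict(zip(input_names, shapes))


# In practice input_names would come from
# relay.frontend.pytorch.get_graph_input_names(script_module);
# these names are illustrative only.
input_names = ["input.1", "rois"]
shapes = [(1, 3, 224, 224), (10, 5)]
input_shapes = make_input_shapes(input_names, shapes)
# mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
```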

You can send a PR to add a check on the input names and I’ll merge it immediately. Otherwise, I’ll add the check in my ongoing PR https://github.com/apache/incubator-tvm/pull/4964, which I’ll update after my QNN PR is merged.

Thanks, I thought that support for multiple inputs was the reason you didn’t add the conversion back in.

I have a patch that performs the check in parse_inputs(); it’s just going through internal review, and I will post a PR soon.

parse_inputs was refactored in #4964; it is now called _get_relay_input_vars and contains only:

def _get_relay_input_vars(input_shapes):
    """ Return Relay vars from input shapes """
    return {iname: _expr.var(iname, shape=ishape)
            for iname, ishape in input_shapes.items()}

I think a good place to add the input name check is

We just need to go through graph.inputs() and make sure each debugName is a key in input_shapes. I’d add a _check_input_names(...) function to do this.

OK, understood, I will refactor. I just thought it gave an opportunity to warn about inputs that weren’t needed and ignore them, but maybe that’s a step too far.

PR: https://github.com/apache/incubator-tvm/pull/4992