[PyTorch] [Frontend] graph input names can change using loaded torchscript


I have been testing the PyTorch frontend and have found an issue with using saved torchscript versus in-memory traced torchscript.

What I have observed is that the input names to the graph can be altered by the call to torch._C._jit_pass_inline(), which means that the get_graph_input_names() function can return the wrong input names, i.e. the ones from before the pass ran.

Pseudo steps to recreate:

  • Load pretrained model (e.g. torch vision)
  • Perform torch.jit.trace
  • Perform torch.jit.save
  • Load in saved torchscript using torch.jit.load
  • Use the loaded torchscript - get_graph_input_names() will return an input name e.g. input.1
  • On performing from_pytorch, the _jit_pass_inline() call will change the input name to e.g. input.11, so the earlier name no longer matches

I propose a simple fix: update get_graph_input_names() to call _run_jit_passes() before reading and returning the input names.
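A minimal sketch of the proposed ordering, using tiny fake objects in place of torch._C.Graph so the renaming behaviour is visible without TorchScript (all names here are illustrative stand-ins for the frontend's helpers, not the real API):

```python
class FakeInput:
    """Stands in for a Torch IR input value."""
    def __init__(self, name):
        self.name = name

    def debugName(self):
        return self.name


class FakeGraph:
    """Stands in for torch._C.Graph; real graphs come from tracing."""
    def __init__(self, names):
        self._inputs = [FakeInput(n) for n in names]

    def inputs(self):
        return self._inputs


def run_jit_passes(graph):
    # Stand-in for torch._C._jit_pass_inline, which can rename graph
    # inputs on a loaded torchscript module (e.g. input.1 -> input.11).
    for inp in graph.inputs():
        if inp.name == "input.1":
            inp.name = "input.11"


def get_graph_input_names(graph):
    # Proposed fix: run the passes *before* reading the names, so
    # callers see the names that from_pytorch will actually use.
    run_jit_passes(graph)
    return [inp.debugName() for inp in graph.inputs()]
```

With this ordering, querying the input names of a freshly loaded module yields the post-pass name (input.11 in the fake above) rather than the stale one.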

To make sure this is tested, the frontend tests could have a second invocation of the model tests (via a flag argument) that performs the following in the verify_model test function after load_model:

  • traced = torch.jit.trace(model, input_data) (tracing requires example inputs)
  • traced.save(temporary)
  • traced = torch.jit.load(temporary)
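The round trip above could be wrapped in a small helper gated by the proposed flag; a sketch (the helper and flag names are assumptions, and torch is only imported when the flag is actually set):

```python
import os
import tempfile


def maybe_roundtrip(traced, via_file=False):
    """Optionally serialize and reload a traced module, mimicking the
    torch.jit.save / torch.jit.load steps from the repro above."""
    if not via_file:
        return traced
    import torch  # local import: only needed for the round trip
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "model.pt")
        traced.save(path)
        return torch.jit.load(path)
```

verify_model would then run the conversion on maybe_roundtrip(traced, via_file=flag), so both the in-memory and the saved-then-loaded paths are covered by the same tests.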

I can supply a PR for these changes if they sound reasonable?


Thanks, it seems to me the best way to fix this issue is not to require users to supply the correct input names, since these are an implementation detail of Torch and we can figure out the names in from_pytorch.

Instead, users should just supply a list of input shapes in the correct order. We call get_graph_input_names() after we run _run_jit_passes, as you proposed, and zip the names with the list of shapes to get input_shapes. This way, users don’t have to bother with input names anymore, and we can also remove _check_input_names.

The only downside is that the API would deviate from other frontends. Personally I don’t think that is an issue; we should make an API that makes the most sense for Torch. Since the Torch frontend is still new, we can fix it now.

What do you think? cc @alexwong @zhiics @jwfromm @pyjhzwh

If we all agree to change the API (input_shapes would become a list of shapes, not a dict as in other frontends), we can proceed to fix it.

That sounds reasonable. Maybe we could at least check that the inputs are valid “shape tuples”? Though this is probably a common need across frontends that could be shared.

yes, we can repurpose and rename _check_input_names to do necessary input validation.

For other frontends, I also remember being annoyed for having to supply input names. Unfortunately for them it is too late to fix. We shouldn’t repeat the same mistake :wink:


The input names are really annoying. I think one use case of the name-to-shape dict is to avoid getting the inputs in the wrong order. How hard is it for users to supply the inputs in the correct order? And is it possible to connect the names after _run_jit_passes?

PyTorch users should know the correct order of inputs, because a PyTorch module’s forward(...) method expects its inputs in the correct order (otherwise they could not run any training code).

Yes, it is very straightforward as long as the user supplied input shape list is correct. Something like below:
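A sketch of the idea, assuming _run_jit_passes has already been applied so the graph input names are final (the helper name is illustrative):

```python
def build_input_shapes(graph_input_names, shape_list):
    """Pair the frontend-derived input names (post jit passes) with
    the user-supplied shapes, which must be in forward() order."""
    if len(graph_input_names) != len(shape_list):
        raise ValueError("need exactly one shape per graph input")
    return dict(zip(graph_input_names, shape_list))
```

For example, build_input_shapes(["input.11"], [[1, 3, 224, 224]]) produces the same name-to-shape dict the frontend uses today, without the user ever seeing the Torch-internal name.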

Thanks for clarification. I think this change makes sense to me.

@jjohnson-arm Do you want to send a PR about it? Otherwise I will, no problem

Yes - I can send a PR.


Hmm… one issue that remains if we do this is that the user still needs to know the input names to set the input data on the relay model, i.e. relay_model.set_input(input_name, data)

So I presume we still need some way of sorting that out: either we need a way of querying the relay model for the names of its inputs - is there something already?

Or maybe it is just better to supply the post-_run_jit_passes input names via the original call, the way you were doing it originally.

Oh, you are right… I also realized that in a typical AOT deployment use case, we just load compiled models directly from exported libs, so there are no torchscript or relay models around. But users still need to keep the input names around somehow.

I agree that an ideal solution is for compiled runtime modules to enable querying a list of input names in the correct order, but right now there is no way to do that. There is GraphRuntime::GetInputIndex(...) (used in set_input), but we would need an “inverse” of this function.

A solution that does not touch the runtime is to ask users to give us a list of (input_name, input_shape) pairs, and we override the Torch IR input names with the names provided by users. Users can choose arbitrary names (“input0”, “input1”, etc.).

I think this is better than returning whatever names Torch chooses from our frontend and asking users to somehow keep these names around until deployment.
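A sketch of the override, again with a fake object standing in for the Torch IR input value (the real one exposes setDebugName); the helper name and the `self`-already-dropped assumption are mine:

```python
class FakeIRInput:
    """Stands in for a Torch IR input value with a mutable debug name."""
    def __init__(self, name):
        self._name = name

    def debugName(self):
        return self._name

    def setDebugName(self, name):
        self._name = name


def override_input_names(ir_inputs, input_infos):
    """Rename graph inputs (excluding `self`, assumed already dropped)
    to the user-chosen names from a [(name, shape), ...] list."""
    if len(ir_inputs) != len(input_infos):
        raise ValueError("need one (name, shape) pair per graph input")
    for ir_input, (user_name, _shape) in zip(ir_inputs, input_infos):
        ir_input.setDebugName(user_name)
```

After this runs, every later lookup in the conversion sees only the user-chosen names, regardless of what _jit_pass_inline did to the originals.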

Ok. So just to see if I understand, you are proposing:

  • User supplies something like: [('input0', [1,2,3]), ('input1', [4,5])]
  • from_pytorch() changes the relay_graph to use these names on conversion
  • User then uses the same names when using compiled models.

Is that right?

I have some code working as above, but I am using an input conversion map (created after reading the input_shapes) in _get_op_inputs() to convert the op inputs to the user supplied names.

I am wondering if it would be possible just to append some conversion entries to the output_index_map instead, which would achieve the same thing?

Yes, exactly right.
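Put together, the intended user-facing flow would look roughly like this (the relay calls are shown as comments, since only the shape of the API matters here):

```python
# User picks arbitrary names, in forward() argument order.
shape_list = [("input0", [1, 2, 3]), ("input1", [4, 5])]

# mod, params = relay.frontend.from_pytorch(traced_module, shape_list)
# ...compile with relay.build, create the graph runtime...
# runtime_mod.set_input("input0", data0)
# runtime_mod.set_input("input1", data1)

# The names the user must remember at deploy time are exactly the
# ones they chose above:
deploy_names = [name for name, _shape in shape_list]
```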

I’m not completely sure what you mean here, but since _get_op_inputs looks for the original Torch IR input names, we need to overwrite the input names or add additional entries to outputs and output_index_map.

Overwriting can be done with the setDebugName method. For the latter solution, we can add something like:

    # input_vars: relay vars created from the user-supplied shapes
    for torch_input_name, relay_var in zip(get_graph_input_names(script_module), input_vars):
        output_index_map[torch_input_name] = len(outputs)
        outputs.append(relay_var)


I realized that output_index_map is completely redundant if we make outputs a dict instead of a list, because outputs is always accessed via output_index_map like this (here, outputs is a list):

    outputs[output_index_map[var_name]]

Instead it should be just outputs[var_name].

@jjohnson-arm Does this make sense? If yes, feel free to remove output_index_map and make outputs a dict from node name to relay output values.
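A before/after sketch of that refactor (the values are illustrative strings standing in for relay expressions):

```python
# Before: a list plus a separate name-to-index map, kept in sync.
outputs_list = []
output_index_map = {}
output_index_map["x"] = len(outputs_list)
outputs_list.append("relay_var_for_x")
before = outputs_list[output_index_map["x"]]

# After: one dict from IR value name to relay output, same lookup.
outputs = {}
outputs["x"] = "relay_var_for_x"
after = outputs["x"]

assert before == after  # identical result, one structure fewer
```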

Thanks for the comments - they have helped to shed light on things - :bulb:!

I agree with the removal of output_index_map, though my suggestion was actually to use it as a redirect to the same entries for the user-supplied names, i.e. it would have entries for the user-specified names (from _get_relay_input_vars) plus some redirects to those same outputs list entries for the PyTorch names.

But as you say, we could just add some extra output entries if we turn outputs into a dictionary, or just use the setDebugName to change the graph - I will have a look into both.

FYI, my initial method (just as a trial):

  • Read in the user input_shapes and create a simple conversion map from PyTorch names to user names (e.g. from ‘input.1’ to ‘input0’)
  • _get_relay_input_vars is still used to construct relay vars from the user input_shapes, and these get added to outputs - so the user names are already there
  • In _get_op_inputs I use the new conversion map to convert the PyTorch names (from _get_input_names) into the user names before looking them up in output_index_map - here’s where I could have just used output_index_map itself instead of an extra map - I was being a bit overly cautious.
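The conversion map from step 1 and its use in step 3 can be sketched as follows (all names are illustrative):

```python
user_inputs = [("input0", [1, 2, 3]), ("input1", [4, 5])]
torch_names = ["input.11", "input.13"]  # names after _run_jit_passes

# Step 1: map each post-pass Torch name to the user-chosen name.
name_map = dict(zip(torch_names, (name for name, _ in user_inputs)))


def to_user_name(torch_name):
    """Step 3: translate op input names before the output lookup;
    intermediate values are not in the map and pass through unchanged."""
    return name_map.get(torch_name, torch_name)
```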

Posted PR - https://github.com/apache/incubator-tvm/pull/5204

Just to warm this up a bit. While graph input debug names can change, PyTorch does keep the stem stable. This is used e.g. for script_module.code and to give an error for missing inputs (try script_module()).

Unfortunately I think this will not help if you have two inputs called input.0 and input.1 (this is allowed). These will get remapped to something new, like input.X and input.Y, and it would be guesswork to figure out which is which.

Unless I am missing something?

Actually, this can happen in the body of the function, but not here, because the inputs come from a function signature. You can print traced_module.code to witness the translation (that is where I tracked down the function that reproduces the unprocessed names). Another place where you can get the argument names directly and programmatically is the schema of the traced module’s forward method: [a.name for a in traced_module.forward.schema.arguments].
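A sketch of reading the stable names from that schema; a SimpleNamespace stands in for a real traced module here so the access pattern is visible without torch installed:

```python
from types import SimpleNamespace


def stable_input_names(traced_module):
    """Argument names from the forward schema are stable across jit
    passes; the first argument is `self`, so drop it."""
    return [a.name for a in traced_module.forward.schema.arguments][1:]


# Fake module mirroring traced_module.forward.schema.arguments:
fake = SimpleNamespace(forward=SimpleNamespace(schema=SimpleNamespace(
    arguments=[SimpleNamespace(name="self"), SimpleNamespace(name="x")])))
```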

I haven’t fully investigated what it would take to make PyTorch present the signature in a way retrievable by inspect.signature (it currently isn’t available there); that might be the best way to present it.

That said, people appear to prefer the current API, with its requirement to pass names and shapes regardless of whether they are already provided by the module, so I guess it’ll have to stay that way.

I just thought that it would help to analyze what is going on in the JIT to make informed decisions about how to convert models.

Best regards

