I have been testing the PyTorch frontend and have found an issue when using saved TorchScript as opposed to in-memory traced TorchScript.
What I have observed is that the graph's input names can be altered by the call to torch._C._jit_pass_inline(). This means that get_graph_input_names() can return the wrong input names: the ones from before the inline pass changed them.
Steps to reproduce (pseudocode):
- Load a pretrained model (e.g. from torchvision)
- Perform torch.jit.trace
- Perform torch.jit.save
- Load in saved torchscript using torch.jit.load
- Call get_graph_input_names() on the loaded TorchScript; it returns an input name, e.g.
- Call from_pytorch; its jit_pass_inline() call changes that input name to, e.g.
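The steps above can be sketched roughly as the script below. It assumes a tiny stand-in module (TinyNet, a hypothetical name) rather than a pretrained torchvision model, to avoid a download, and it applies torch._C._jit_pass_inline directly to a copy of the loaded graph instead of going through from_pytorch. It collects the graph input debug names before and after the pass; whether they actually differ may depend on the PyTorch version, so no particular output is claimed here.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a pretrained torchvision model."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


def input_names(graph):
    # Debug names of every graph input; the first one is the module's `self`.
    return [inp.debugName() for inp in graph.inputs()]


def compare_names():
    model = TinyNet().eval()
    example = torch.randn(1, 3, 32, 32)
    traced = torch.jit.trace(model, example)

    # Round-trip through a file, as in the torch.jit.save / torch.jit.load steps.
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "model.pt")
        torch.jit.save(traced, path)
        loaded = torch.jit.load(path)

    # Copy the graph so the loaded module itself is not mutated.
    graph = loaded.graph.copy()
    before = input_names(graph)          # names before any JIT pass
    torch._C._jit_pass_inline(graph)     # mutates `graph` in place
    after = input_names(graph)           # names after inlining
    return before, after


if __name__ == "__main__":
    print(compare_names())
```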
A simple fix I propose is to update get_graph_input_names() so that it calls _run_jit_passes() before collecting and returning the input names.
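A minimal sketch of that proposal is below. It is not the frontend's actual code: _run_jit_passes here is a hypothetical stand-in that only runs the inline pass (the one pass the post identifies), and the helper assumes, as the frontend does, that the first graph input is the implicit self argument.

```python
import torch


def _run_jit_passes(graph):
    # Hypothetical stand-in for the frontend helper of the same name;
    # here it only runs the inline pass discussed above (in place).
    torch._C._jit_pass_inline(graph)


def get_graph_input_names(script_module):
    """Sketch of the proposed fix: run the same JIT passes that
    from_pytorch runs before reading the input names."""
    # Work on a copy so the module's own graph is left untouched.
    graph = script_module.graph.copy()
    _run_jit_passes(graph)
    names = [inp.debugName() for inp in graph.inputs()]
    return names[1:]  # drop the implicit `self` input
```

Running the passes on a copy keeps the function free of side effects, so from_pytorch can still apply its own passes to the module's graph afterwards and should see the same names.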
To make sure this is covered by tests, the frontend tests could run each model test a second time (selected via a flag argument), performing the following in the verify_model test function after load_model:
- traced = torch.jit.trace(model, input_data)
- torch.jit.save(traced, temporary)
- traced = torch.jit.load(temporary)
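A possible shape for that flag is sketched below. The function name prepare_model and the roundtrip flag are hypothetical, and an in-memory io.BytesIO buffer is swapped in for the temporary file; torch.jit.save and torch.jit.load both accept file-like objects, so the round trip still goes through serialization.

```python
import io

import torch


def prepare_model(model, input_data, roundtrip=False):
    """Trace `model`; if the hypothetical `roundtrip` flag is set,
    also save and reload it so the test exercises loaded TorchScript."""
    traced = torch.jit.trace(model, input_data)
    if roundtrip:
        buffer = io.BytesIO()  # in-memory stand-in for a temporary file
        torch.jit.save(traced, buffer)
        buffer.seek(0)         # rewind before reading back
        traced = torch.jit.load(buffer)
    return traced
```

Each model test could then call this once with roundtrip=False and once with roundtrip=True, so both the in-memory and the saved paths go through from_pytorch.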
If these changes sound reasonable, I can supply a PR for them.