Hi,
I have been testing the PyTorch frontend and have found an issue when using saved TorchScript instead of in-memory traced TorchScript.
What I have observed is that the graph's input names can be altered by the call to `torch._C._jit_pass_inline()`. This means that the `get_graph_input_names()` function can return the wrong input names, namely the ones from before the change.
Steps to reproduce (pseudocode):
- Load a pretrained model (e.g. from torchvision)
- Perform torch.jit.trace
- Perform torch.jit.save
- Load in saved torchscript using torch.jit.load
- On the loaded TorchScript module, `get_graph_input_names()` returns an input name such as `input.1`
- When `from_pytorch` is called, its `_jit_pass_inline()` call changes that input name to, e.g., `input.11`
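The reproduction steps above can be sketched as a short script. This is a minimal sketch under a few assumptions: a small `nn.Sequential` stands in for the pretrained torchvision model, the save/load round trip goes through an in-memory `io.BytesIO` buffer rather than a file, and the hypothetical helper `input_names` only approximates what `get_graph_input_names()` reports (each graph input's `debugName()`, with `self` dropped). Whether the names actually differ may depend on the PyTorch version, so the script prints both lists instead of asserting a specific rename.

```python
import io

import torch
import torch.nn as nn


def input_names(graph):
    """Hypothetical helper: debug names of the graph inputs, minus `self`."""
    return [value.debugName() for value in graph.inputs()][1:]


def trace_save_load(model, example):
    """Trace `model`, then round-trip it through torch.jit.save/load in memory."""
    traced = torch.jit.trace(model, example)
    buffer = io.BytesIO()
    torch.jit.save(traced, buffer)
    buffer.seek(0)
    return torch.jit.load(buffer)


# Assumption: a small module with submodules stands in for a torchvision model,
# so the loaded forward() still calls submodules via prim::CallMethod.
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU()).eval()
example = torch.randn(1, 3, 32, 32)

loaded = trace_save_load(model, example)

# Names as seen before inlining (what the current get_graph_input_names sees).
before = input_names(loaded.graph)

# Names after the inline pass, run on a copy as the frontend does internally.
graph = loaded.graph.copy()
torch._C._jit_pass_inline(graph)
after = input_names(graph)

print("before inline:", before)
print("after inline: ", after)
```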
I propose a simple fix: update `get_graph_input_names()` to call `_run_jit_passes()` before retrieving and returning the input names.
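The proposed fix could look roughly like the sketch below. It is illustrative only: `get_graph_input_names_fixed` and `_run_jit_passes_sketch` are hypothetical names, and the sketch assumes the relevant pass is just `torch._C._jit_pass_inline`, whereas the frontend's real `_run_jit_passes()` may run other passes depending on the PyTorch version. The idea is that running the same passes on a copy of the graph makes the reported names match the ones `from_pytorch` will see.

```python
import torch
import torch.nn as nn


def _run_jit_passes_sketch(graph):
    """Hypothetical stand-in for the frontend's _run_jit_passes (inline only)."""
    torch._C._jit_pass_inline(graph)


def get_graph_input_names_fixed(script_module):
    """Hypothetical patched get_graph_input_names.

    Runs the same passes the converter runs, on a copy of the graph, so the
    returned names match the post-inlining names.
    """
    graph = script_module.graph.copy()  # leave the module's own graph untouched
    _run_jit_passes_sketch(graph)
    # Drop the leading `self` input.
    return [value.debugName() for value in graph.inputs()][1:]


model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU()).eval()
traced = torch.jit.trace(model, torch.randn(1, 3, 32, 32))

original_names = [value.debugName() for value in traced.graph.inputs()][1:]
names = get_graph_input_names_fixed(traced)
print(names)
```

Because the passes run on a copy, the module's own graph keeps its original names; the sketch relies on the converter producing the same names when it runs the passes on its own copy, which is worth confirming in the real frontend.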
To make sure this is tested, the frontend tests could invoke each model test a second time (controlled by a flag argument), performing the following in the `verify_model` test function after `load_model`:
- traced = torch.jit.trace(model, input_data)
- traced.save(temporary)
- traced = torch.jit.load(temporary)
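The trace/save/load steps above might be wrapped in a small helper that `verify_model` calls with the flag. This is a sketch under assumptions: `trace_for_test` and its `roundtrip` parameter are hypothetical names, a small `nn.Sequential` stands in for a test model, and the temporary file lives in a `tempfile.TemporaryDirectory`.

```python
import os
import tempfile

import torch
import torch.nn as nn


def trace_for_test(model, input_data, roundtrip=False):
    """Hypothetical verify_model helper.

    Traces `model`; when `roundtrip` is True, the traced module is saved to a
    temporary file and reloaded, so the test exercises the saved-TorchScript path.
    """
    traced = torch.jit.trace(model, input_data)
    if roundtrip:
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "model.pt")
            traced.save(path)
            traced = torch.jit.load(path)  # fully loaded before the dir is removed
    return traced


model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU()).eval()
input_data = torch.randn(1, 3, 32, 32)

in_memory = trace_for_test(model, input_data)
reloaded = trace_for_test(model, input_data, roundtrip=True)

# The reloaded module should compute the same result as the in-memory trace.
with torch.no_grad():
    same = torch.allclose(in_memory(input_data), reloaded(input_data))
print(same)
```

In the test suite, the flag could be exposed as an extra `verify_model` argument (or a pytest parametrization) so every model test runs once with and once without the round trip.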
I can supply a PR for these changes if they sound reasonable.
Jeremy.