[ONNX][Relay] Inconsistency between Int32 and Int64 after .view() operation

Problem Description

I am trying to deploy a PyTorch model to TVM. When loading the ONNX version via relay.frontend.from_onnx, it throws the following error:

  %239 = take(%238, int64(0), axis=0)
  %240 = expand_dims(%239, axis=0)
  %241 = expand_dims(int64(-1), axis=0)
  %242 = (%240, %241)
  concatenate(%242) an internal invariant was violated while typechecking your program [00:12:52] /Users/ligeng/Workspace/tvm/src/relay/op/tensor/transform.cc:204: Check failed: e_dtype == dtype (int64 vs. int32) : relay.concatenate requires all tensors have the same dtype
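For reference, the loading code is roughly the following sketch (the model path, input name, and shape here are placeholders, not the exact script):

    import onnx
    from tvm import relay

    # Placeholder model path and input name; substitute your own exported model.
    onnx_model = onnx.load("proxyless_nas.onnx")
    shape_dict = {"input.1": (1, 3, 224, 224)}
    sym, params = relay.frontend.from_onnx(onnx_model, shape_dict)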

The complete log is attached on Gist. It seems that somewhere in Relay an int64 and an int32 tensor are being concatenated, which causes the dtype error. After some exploration, I located the related ONNX snippet:

  %545 : Float(1, 1432, 1, 1) = onnx::GlobalAveragePool(%544), scope: ProxylessNASNets/AdaptiveAvgPool2d[global_avg_pooling]
  %546 : Long() = onnx::Constant[value={0}](), scope: ProxylessNASNets
  %547 : Tensor = onnx::Shape(%545), scope: ProxylessNASNets
  %548 : Long() = onnx::Gather[axis=0](%547, %546), scope: ProxylessNASNets
  %549 : Long() = onnx::Constant[value={-1}](), scope: ProxylessNASNets
  %550 : Tensor = onnx::Unsqueeze[axes=[0]](%548)
  %551 : Tensor = onnx::Unsqueeze[axes=[0]](%549)
  %552 : Tensor = onnx::Concat[axis=0](%550, %551)
  %553 : Float(1, 1432) = onnx::Reshape(%545, %552), scope: ProxylessNASNets
  %output1 : Float(1, 1000) = onnx::Gemm[alpha=1, beta=1, transB=1](%553, %classifier.linear.weight, %classifier.linear.bias), scope: ProxylessNASNets/LinearLayer[classifier]/Linear[linear]

which is generated from the following PyTorch code:

    def forward(self, x):
        x = self.first_conv(x)
        for block in self.blocks:
            x = block(x)
        if self.feature_mix_layer:
            x = self.feature_mix_layer(x)
        x = self.global_avg_pooling(x)
        x = x.view(x.size(0), -1)  # flatten
        return x

It looks like x.size(0) is treated as int32 somewhere during the ONNX parsing. So I first tried to manually set the dtype:

         x = x.view(x.size(0), x.size(1))  # flatten

This time, HalideIR raises the following error:

[00:31:38] /Users/ligeng/Workspace/tvm/src/relay/pass/pass_manager.cc:312: Executing function pass : InferType with opt level: 0
[00:31:38] /Users/ligeng/Workspace/tvm/src/relay/pass/pass_manager.cc:312: Executing function pass : SimplifyInference with opt level: 0
[00:31:38] /Users/ligeng/Workspace/tvm/src/relay/pass/pass_manager.cc:312: Executing function pass : FuseOps with opt level: 1
[00:31:38] /Users/ligeng/Workspace/tvm/src/relay/pass/pass_manager.cc:312: Executing function pass : InferType with opt level: 0
Traceback (most recent call last):
  File "/Users/ligeng/Workspace/ProxylessNAS/load_onnx.py", line 32, in <module>
    sym, params = relay.frontend.from_onnx(onnx_model, shape_dict)
  File "/Users/ligeng/Workspace/tvm/python/tvm/relay/frontend/onnx.py", line 1246, in from_onnx
    sym, params = g.from_onnx(graph, opset)
  File "/Users/ligeng/Workspace/tvm/python/tvm/relay/frontend/onnx.py", line 1074, in from_onnx
    op = self._convert_operator(op_name, inputs, attr, opset)
  File "/Users/ligeng/Workspace/tvm/python/tvm/relay/frontend/onnx.py", line 1180, in _convert_operator
    sym = convert_map[op_name](inputs, attrs, self._params)
  File "/Users/ligeng/Workspace/tvm/python/tvm/relay/frontend/onnx.py", line 417, in _impl_v1
    graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
  File "/Users/ligeng/Workspace/tvm/python/tvm/relay/build_module.py", line 196, in build
    params)
  File "/Users/ligeng/Workspace/tvm/python/tvm/relay/build_module.py", line 107, in build
    self._build(func, target, target_host)
  File "/Users/ligeng/Workspace/tvm/python/tvm/_ffi/_ctypes/function.py", line 209, in __call__
    raise get_last_ffi_error()
  File "/Users/ligeng/Workspace/tvm/3rdparty/HalideIR/src/ir/IR.cpp", line 469
TVMError: Check failed: args[i].type() == Int(32): Args to call to halide function must be type Int(32)

HalideIR requires all args to be Int(32), so my current workaround is to loosen the dtype check in 3rdparty/HalideIR/src/ir/IR.cpp (line 468) from

    for (size_t i = 0; i < args.size(); i++) {
        internal_assert(args[i].type() == Int(32))
            << "Args to call to halide function must be type Int(32)\n";
    }

to

    for (size_t i = 0; i < args.size(); i++) {
        internal_assert(args[i].type() == Int(32) || args[i].type() == Int(64))
            << "Args to call to halide function must be type Int(32) or Int(64)\n";
    }

But the weird thing is that the int64/int32 mismatch is not introduced by PyTorch or by ONNX, because in PyTorch tensor.size(0) returns a native Python int (so it should be handled the same way as a constant value). Two observations support this. The first is the ONNX log:

  %545 : Float(1, 1432, 1, 1) = onnx::GlobalAveragePool(%544), scope: ProxylessNASNets/AdaptiveAvgPool2d[global_avg_pooling]
  %546 : Long() = onnx::Constant[value={0}](), scope: ProxylessNASNets
  %547 : Tensor = onnx::Shape(%545), scope: ProxylessNASNets
  %548 : Long() = onnx::Gather[axis=0](%547, %546), scope: ProxylessNASNets
  %549 : Long() = onnx::Constant[value={1}](), scope: ProxylessNASNets
  %550 : Tensor = onnx::Shape(%545), scope: ProxylessNASNets
  %551 : Long() = onnx::Gather[axis=0](%550, %549), scope: ProxylessNASNets
  %552 : Tensor = onnx::Unsqueeze[axes=[0]](%548)
  %553 : Tensor = onnx::Unsqueeze[axes=[0]](%551)
  %554 : Tensor = onnx::Concat[axis=0](%552, %553)
  %555 : Float(1, 1432) = onnx::Reshape(%545, %554), scope: ProxylessNASNets
  %output1 : Float(1, 1000) = onnx::Gemm[alpha=1, beta=1, transB=1](%555, %classifier.linear.weight, %classifier.linear.bias), scope: ProxylessNASNets/LinearLayer[classifier]/Linear[linear]

As shown above, the constant values are consistently emitted as Long. I further validated the guess by simply changing the reshape operation from

     x = x.view(x.size(0),  -1)

to

     x = x.view(int(x.size(0)),  -1)

By doing so, the JIT does not trace .size(), and in this case my code runs without error.
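For anyone who wants to reproduce just this part in isolation, a minimal export sketch would look roughly like the following; the standalone Flatten module and file name below are mine, not from the original model:

    import torch
    import torch.nn as nn

    class Flatten(nn.Module):
        def forward(self, x):
            # int(x.size(0)) is a plain Python int, so the tracer bakes it in as a
            # constant instead of emitting Shape/Gather/Unsqueeze/Concat nodes.
            return x.view(int(x.size(0)), -1)

    dummy = torch.randn(1, 1432, 1, 1)
    torch.onnx.export(Flatten(), dummy, "flatten.onnx", verbose=True)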

The two observations suggest that both PyTorch and ONNX treat .view() and .size() correctly. The problem should be somewhere in the TVM ONNX frontend, where the return value of .size() is forcibly set to int64. But I cannot pinpoint the exact code; could anyone help with this?

For now, I do not think it is necessary to use int64 to describe a tensor size. But to make TVM more user friendly, do you think it would be better to loosen some of the dtype checks? At least, when there is a dtype inconsistency, the program could raise a warning instead of aborting. For example:

  • HalideIR: allow both int32 and int64 as args.
  • TVM binary operations (e.g., sum, concat): allow mixed-dtype operands such as int32/int64 or float16/float32/float64, by casting the lower-precision operand to the higher-precision one (see the sketch below).
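As a rough illustration of the second point, a promotion helper at the Relay level might look like the following sketch; the function name, warning text, and signature are mine, not existing TVM API:

    import warnings
    from tvm import relay

    def concat_with_promotion(tensors, dtypes, axis=0):
        """Concatenate operands that may mix int32 and int64 by casting the
        int32 ones up to int64 (with a warning) instead of rejecting them.
        `dtypes` holds the operand dtypes, e.g. taken from checked_type
        after type inference."""
        if set(dtypes) == {"int32", "int64"}:
            warnings.warn("concatenate: mixed int32/int64 operands, "
                          "casting int32 up to int64 (may cost performance)")
            tensors = [relay.cast(t, "int64") if d == "int32" else t
                       for t, d in zip(tensors, dtypes)]
        return relay.concatenate(tensors, axis=axis)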

This, though, may bring a performance regression; users can first make the program runnable, and the warning messages will tell them where to look to further boost performance.

I agree with most of your points. https://github.com/dmlc/tvm/issues/2588 is an ongoing issue that touches the int64/int32 problem. The gist is to automatically convert int64 -> int32 in the codegen if analysis finds there won't be an overflow issue.
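(Roughly, the idea is something like the following sketch; the helper name and the bound analysis are illustrative only, not the actual pass discussed in that issue:)

    from tvm import relay

    INT32_MIN, INT32_MAX = -2**31, 2**31 - 1

    def narrow_if_safe(expr, proven_bounds):
        """Cast an int64 Relay expression down to int32 when a (hypothetical)
        bound analysis has proven it cannot overflow int32."""
        low, high = proven_bounds
        if low >= INT32_MIN and high <= INT32_MAX:
            return relay.cast(expr, "int32")
        return expr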

That would be great. Thanks for your reply.