Problem Description
I am trying to deploy a PyTorch model to TVM. When loading the ONNX version via relay.frontend.from_onnx, it throws the following error:
%239 = take(%238, int64(0), axis=0)
%240 = expand_dims(%239, axis=0)
%241 = expand_dims(int64(-1), axis=0)
%242 = (%240, %241)
concatenate(%242)
an internal invariant was violated while typechecking your program
[00:12:52] /Users/ligeng/Workspace/tvm/src/relay/op/tensor/transform.cc:204: Check failed: e_dtype == dtype (int64 vs. int32) : relay.concatenate requires all tensors have the same dtype
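The failing check is easy to reproduce in plain NumPy. Below is a sketch of what Relay's type checker enforces; `strict_concatenate` is a hypothetical helper for illustration, not TVM API:

```python
import numpy as np

def strict_concatenate(tensors, axis=0):
    # Mimic Relay's concatenate type check: every input must share
    # the dtype of the first tensor, otherwise type checking fails.
    dtype = tensors[0].dtype
    for t in tensors[1:]:
        if t.dtype != dtype:
            raise TypeError(
                "relay.concatenate requires all tensors have the same dtype "
                f"({t.dtype} vs. {dtype})")
    return np.concatenate(tensors, axis=axis)

batch = np.array([1], dtype=np.int32)       # shape component, int32 in Relay
minus_one = np.array([-1], dtype=np.int64)  # the -1 literal, int64
# strict_concatenate([batch, minus_one]) raises TypeError (int64 vs. int32)
```

NumPy itself would silently promote int32 to int64 here, which is exactly what Relay refuses to do.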
The complete log is attached on Gist. It seems that somewhere in Relay an int64 tensor is being concatenated with an int32 one, causing the dtype error. After some exploration, I located the related ONNX snippet:
%545 : Float(1, 1432, 1, 1) = onnx::GlobalAveragePool(%544), scope: ProxylessNASNets/AdaptiveAvgPool2d[global_avg_pooling]
%546 : Long() = onnx::Constant[value={0}](), scope: ProxylessNASNets
%547 : Tensor = onnx::Shape(%545), scope: ProxylessNASNets
%548 : Long() = onnx::Gather[axis=0](%547, %546), scope: ProxylessNASNets
%549 : Long() = onnx::Constant[value={-1}](), scope: ProxylessNASNets
%550 : Tensor = onnx::Unsqueeze[axes=[0]](%548)
%551 : Tensor = onnx::Unsqueeze[axes=[0]](%549)
%552 : Tensor = onnx::Concat[axis=0](%550, %551)
%553 : Float(1, 1432) = onnx::Reshape(%545, %552), scope: ProxylessNASNets
%output1 : Float(1, 1000) = onnx::Gemm[alpha=1, beta=1, transB=1](%553, %classifier.linear.weight, %classifier.linear.bias), scope: ProxylessNASNets/LinearLayer[classifier]/Linear[linear]
which is generated from the following PyTorch code:
def forward(self, x):
    x = self.first_conv(x)
    for block in self.blocks:
        x = block(x)
    if self.feature_mix_layer:
        x = self.feature_mix_layer(x)
    x = self.global_avg_pooling(x)
    x = x.view(x.size(0), -1)  # flatten
    return x
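The exported Shape → Gather → Unsqueeze → Concat → Reshape sequence is just the graph-level spelling of x.view(x.size(0), -1). A NumPy sketch of what those ONNX ops compute (my re-derivation, not generated code):

```python
import numpy as np

x = np.zeros((1, 1432, 1, 1), dtype=np.float32)  # output of GlobalAveragePool

shape = np.array(x.shape, dtype=np.int64)        # onnx::Shape
batch = shape[0]                                 # onnx::Gather[axis=0] with index 0
minus_one = np.int64(-1)                         # onnx::Constant[value={-1}]
target = np.concatenate([np.expand_dims(batch, 0),
                         np.expand_dims(minus_one, 0)])  # Unsqueeze + Concat
y = x.reshape(target)                            # onnx::Reshape -> shape (1, 1432)
```

The mismatch arises inside that Concat step once the graph is imported into Relay.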
It looks like the x.size(0) component ends up as int32 while the -1 constant stays int64 when the graph is parsed from ONNX, which triggers the mismatch. So I first tried to work around it by making both dimensions explicit instead of using -1:
x = x.view(x.size(0), x.size(1)) # flatten
This time, Halide IR raises a different error:
[00:31:38] /Users/ligeng/Workspace/tvm/src/relay/pass/pass_manager.cc:312: Executing function pass : InferType with opt level: 0
[00:31:38] /Users/ligeng/Workspace/tvm/src/relay/pass/pass_manager.cc:312: Executing function pass : SimplifyInference with opt level: 0
[00:31:38] /Users/ligeng/Workspace/tvm/src/relay/pass/pass_manager.cc:312: Executing function pass : FuseOps with opt level: 1
[00:31:38] /Users/ligeng/Workspace/tvm/src/relay/pass/pass_manager.cc:312: Executing function pass : InferType with opt level: 0
Traceback (most recent call last):
  File "/Users/ligeng/Workspace/ProxylessNAS/load_onnx.py", line 32, in <module>
    sym, params = relay.frontend.from_onnx(onnx_model, shape_dict)
  File "/Users/ligeng/Workspace/tvm/python/tvm/relay/frontend/onnx.py", line 1246, in from_onnx
    sym, params = g.from_onnx(graph, opset)
  File "/Users/ligeng/Workspace/tvm/python/tvm/relay/frontend/onnx.py", line 1074, in from_onnx
    op = self._convert_operator(op_name, inputs, attr, opset)
  File "/Users/ligeng/Workspace/tvm/python/tvm/relay/frontend/onnx.py", line 1180, in _convert_operator
    sym = convert_map[op_name](inputs, attrs, self._params)
  File "/Users/ligeng/Workspace/tvm/python/tvm/relay/frontend/onnx.py", line 417, in _impl_v1
    graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
  File "/Users/ligeng/Workspace/tvm/python/tvm/relay/build_module.py", line 196, in build
    params)
  File "/Users/ligeng/Workspace/tvm/python/tvm/relay/build_module.py", line 107, in build
    self._build(func, target, target_host)
  File "/Users/ligeng/Workspace/tvm/python/tvm/_ffi/_ctypes/function.py", line 209, in __call__
    raise get_last_ffi_error()
  File "/Users/ligeng/Workspace/tvm/3rdparty/HalideIR/src/ir/IR.cpp", line 469
TVMError: Check failed: args[i].type() == Int(32): Args to call to halide function must be type Int(32)
Halide IR requires all call args to be Int32, so my current workaround is to loosen the dtype check in 3rdparty/HalideIR/src/ir/IR.cpp (line 468) from
for (size_t i = 0; i < args.size(); i++) {
    internal_assert(args[i].type() == Int(32))
        << "Args to call to halide function must be type Int(32)\n";
}
to
for (size_t i = 0; i < args.size(); i++) {
    internal_assert(args[i].type() == Int(32) || args[i].type() == Int(64))
        << "Args to call to halide function must be type Int(32) or Int(64)\n";
}
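This relaxation is usually benign for shape computations, since real tensor extents fit comfortably in 32 bits, but a safer patch would also guard against int64 values that overflow int32. The intended predicate can be sketched in Python as follows; `check_halide_arg` is a hypothetical helper, not Halide or TVM API:

```python
import numpy as np

INT32_MIN = np.iinfo(np.int32).min
INT32_MAX = np.iinfo(np.int32).max

def check_halide_arg(value, dtype):
    # Hypothetical analogue of the relaxed HalideIR assert: int32 args are
    # always accepted; int64 args pass only if the value fits in 32 bits.
    if dtype == "int32":
        return True
    if dtype == "int64":
        return INT32_MIN <= value <= INT32_MAX
    return False

# A batch dimension such as 1432 narrows safely:
# check_halide_arg(1432, "int64") -> True
```

With a guard like this, silently truncating an out-of-range int64 shape would be caught instead of producing wrong code.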