Do NNVM broadcasting operators such as nnvm.symbol.broadcast_div support scalars?
When I converted an ONNX ResNet model to TVM using nnvm.frontend.from_onnx(), the exception nnvm._base.NNVMError: ...: Check failed: out[i].ndim() == out_info[i].ndim() (2 vs. 0): broadcast_mul was thrown.
The trigger is passing a scalar (0-dimensional) operand to a broadcasting operator, which is legal under NumPy’s broadcasting rules.
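For comparison, here is a minimal NumPy-only sketch of the behaviour I expected. It does not use NNVM at all; the array names are just illustrative, and the shapes mirror the ones in my reproduction code (the result of summing an [11, 2048, 7, 7] activation over its last two axes, divided by a 0-dimensional array).

```python
import numpy as np

# Shape of an (11, 2048, 7, 7) activation after summing over the H and W axes.
summed = np.ones((11, 2048), dtype=np.float32)

# A 0-dimensional (scalar) array, the same kind of value that SHAPE = [] produces.
area = np.full([], 49, dtype=np.float32)

# NumPy broadcasts the 0-d operand against the 2-d operand without complaint.
result = summed / area

print(area.ndim)     # 0
print(result.shape)  # (11, 2048)
```

The "(2 vs. 0)" in the error message appears to match these two ranks: 2 for the broadcast result and 0 for the scalar operand.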
It works correctly (including inference results) when I:
- set SHAPE to a non-scalar shape, i.e., SHAPE = [1] or SHAPE = [1, 1], in the reproduction code below,
- use nnvm.symbol.mean() instead of nnvm.symbol.sum() and nnvm.symbol.broadcast_div(), or
- use tvm.relay.frontend.from_onnx() to convert the ONNX model.
However, it’s hard to modify the model in any of these ways, because the ONNX model is generated automatically by a tool.
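To make it clear that the first two workarounds compute the same thing as the failing graph, here is a small NumPy-only sketch (again, no NNVM involved). The helper pool_with_div is a hypothetical name I made up for illustration; it mimics the sum-then-divide pattern with an area array of a given shape, and the loop compares it against a plain mean over the H and W axes.

```python
import numpy as np

def pool_with_div(x, shape):
    """Sum over the H and W axes, then divide by an `area` array of `shape`."""
    area = np.full(shape, np.prod(x.shape[-2:]), dtype=x.dtype)
    return x.sum(axis=(2, 3)) / area

# Random NCHW activation with the same shape as in the reproduction code.
rng = np.random.default_rng(0)
x = rng.random((11, 2048, 7, 7), dtype=np.float32)

# Global average pooling expressed directly as a mean.
reference = x.mean(axis=(2, 3))

# Scalar, [1], and [1, 1] shaped divisors should all agree with the mean.
for shape in ([], [1], [1, 1]):
    out = pool_with_div(x, shape)
    print(shape, out.shape, np.allclose(out, reference, rtol=1e-4))
```

In NumPy all three divisor shapes give a result of shape (11, 2048), so as far as I can tell the only difference between the failing and the working variants is the rank of the second operand.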
The following is my minimal reproduction code for the problem:
import nnvm
import numpy as np
BATCH_SIZE = 11
DTYPE = np.float32
SHAPE = []
# global average pooling in ResNets:
# (1) sum activations over the area axes (HW axes of NCHW),
# (2) divide it by the area (H*W=49 here).
x = np.empty([BATCH_SIZE,2048,7,7], dtype=DTYPE)
conv5_x = nnvm.symbol.Variable('conv5_x', x)
area = nnvm.symbol.Variable('area', np.full(SHAPE, np.prod(x.shape[-2:]), dtype=DTYPE))
pool6 = nnvm.symbol.broadcast_div(nnvm.symbol.sum(data=conv5_x, axis=[2,3]), area)
graph = nnvm.graph.create(pool6)
with nnvm.compiler.build_config(opt_level=3):
    graph, lib, params = nnvm.compiler.build(
        graph, 'llvm', { 'conv5_x': x.shape }, params={})
The full stack trace in my environment (TVM 0.5 / CentOS 7.5) is as follows:
Traceback (most recent call last):
  File "test5.py", line 19, in <module>
    graph, 'llvm', { 'conv5_x': x.shape }, params={})
  File "/home/xxxxxxxx/anaconda3/envs/work/lib/python3.6/site-packages/nnvm/compiler/build_module.py", line 321, in build
    graph = graph.apply("GraphCompile")
  File "/home/xxxxxxxx/anaconda3/envs/work/lib/python3.6/site-packages/nnvm/graph.py", line 250, in apply
    check_call(_LIB.NNGraphApplyPasses(self.handle, npass, cpass, ctypes.byref(ghandle)))
  File "/home/xxxxxxxx/anaconda3/envs/work/lib/python3.6/site-packages/nnvm/_base.py", line 91, in check_call
    raise NNVMError(py_str(_LIB.NNGetLastError()))
nnvm._base.NNVMError: [15:29:34] /home/xxxxxxxx/local/src/tvm/nnvm/src/compiler/compile_engine.cc:212: Check failed: out[i].ndim() == out_info[i].ndim() (2 vs. 0) : broadcast_div
Stack trace:
[bt] (0) /home/xxxxxxxx/local/tvm/lib/libnnvm_compiler.so(+0x15f3e9) [0x7f4fdf9753e9]
[bt] (1) /home/xxxxxxxx/local/tvm/lib/libnnvm_compiler.so(+0x19281c) [0x7f4fdf9a881c]
[bt] (2) /home/xxxxxxxx/local/tvm/lib/libnnvm_compiler.so(+0x19355c) [0x7f4fdf9a955c]
[bt] (3) /home/xxxxxxxx/local/tvm/lib/libnnvm_compiler.so(+0x190cc5) [0x7f4fdf9a6cc5]
[bt] (4) /home/xxxxxxxx/local/tvm/lib/libnnvm_compiler.so(+0x18b948) [0x7f4fdf9a1948]
[bt] (5) /home/xxxxxxxx/local/tvm/lib/libnnvm_compiler.so(+0x1bc81b) [0x7f4fdf9d281b]
[bt] (6) /home/xxxxxxxx/local/tvm/lib/libnnvm_compiler.so(+0x17fc88) [0x7f4fdf995c88]
[bt] (7) /home/xxxxxxxx/local/tvm/lib/libnnvm_compiler.so(+0x217c9f) [0x7f4fdfa2dc9f]
[bt] (8) /home/xxxxxxxx/local/tvm/lib/libnnvm_compiler.so(+0x2174df) [0x7f4fdfa2d4df]