`nnvm.broadcast_div` raises `NNVMError` for scalars

Do NNVM broadcasting operators such as `nnvm.symbol.broadcast_div` support scalars?

When I converted an ONNX ResNet model to TVM using `nnvm.frontend.from_onnx()`, the following exception was thrown: `nnvm._base.NNVMError: ...: Check failed: out[i].ndim() == out_info[i].ndim() (2 vs. 0): broadcast_mul`.

The trigger is passing a scalar (0-dimensional) operand to a broadcasting operator, which is legal under NumPy’s broadcasting rules.
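Here is a minimal sketch of the rule I mean. `broadcast_shape` is a hypothetical helper written only for illustration; it is not part of NNVM. Recent NumPy versions (1.20 and later) expose the same computation as `np.broadcast_shapes`. Shapes are aligned from the trailing axis, and a 0-d operand contributes no axes at all. So dividing the (11, 2048) output of the sum by a scalar should give (11, 2048).

```python
from itertools import zip_longest
from typing import Sequence, Tuple


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Result shape of broadcasting `a` with `b` under NumPy's rules.

    Hypothetical illustration helper. Shapes are aligned at their trailing
    axes; a missing leading axis counts as size 1, and two sizes are
    compatible when they are equal or one of them is 1.
    """
    result = []
    # Walk both shapes from the last axis towards the first.
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da != db and 1 not in (da, db):
            raise ValueError(f"shapes {tuple(a)} and {tuple(b)} are not broadcastable")
        # A size-1 axis stretches to match the other operand.
        result.append(db if da == 1 else da)
    return tuple(reversed(result))


# sum(conv5_x, axis=[2, 3]) has shape (11, 2048); `area` has shape SHAPE.
for shape in ([], [1], [1, 1]):
    print(shape, "->", broadcast_shape((11, 2048), shape))
```

Under this rule, all three `SHAPE` values produce the same (11, 2048) output, including the scalar case that NNVM rejects.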

It works correctly (including inference results) when I:

  • make `SHAPE` non-scalar, i.e., `SHAPE = [1]` or `SHAPE = [1, 1]` in the reproduction code below,
  • use `nnvm.symbol.mean()` instead of `nnvm.symbol.sum()` followed by `nnvm.symbol.broadcast_div()`, or
  • use `tvm.relay.frontend.from_onnx()` to convert the ONNX model.

However, these changes are hard to make in my case, because the ONNX model is generated automatically by a tool.
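For comparison, the computation in the reproduction code can be written directly in NumPy. This is only a reference sketch: `global_avg_pool` is a hypothetical helper, not an NNVM API, and random data stands in for real activations. In NumPy, all three `SHAPE` values for the divisor yield a (11, 2048) result that matches the mean over the H and W axes. That is the behavior I would expect from `broadcast_div` as well.

```python
import numpy as np

DTYPE = np.float32


def global_avg_pool(x: np.ndarray, area_shape) -> np.ndarray:
    """Sum over the H and W axes of NCHW, then divide by H*W.

    Hypothetical reference helper. `area_shape` plays the role of SHAPE in
    the reproduction code; NumPy broadcasts the divisor against the (N, C)
    result of the sum.
    """
    area = np.full(area_shape, np.prod(x.shape[-2:]), dtype=x.dtype)
    return x.sum(axis=(2, 3)) / area


rng = np.random.default_rng(0)
x = rng.random((11, 2048, 7, 7), dtype=DTYPE)
# NumPy counterpart of the nnvm.symbol.mean() workaround.
reference = x.mean(axis=(2, 3))

for shape in ([], [1], [1, 1]):
    pool6 = global_avg_pool(x, shape)
    print(shape, pool6.shape, np.allclose(pool6, reference))
```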

The following is my minimal reproduction code for the problem:

```python
import nnvm
import numpy as np

BATCH_SIZE = 11
DTYPE = np.float32
SHAPE = []

# global average pooling in ResNets:
# (1) sum activations over the spatial axes (HW axes of NCHW),
# (2) divide the sum by the area (H*W=49 here).
x = np.empty([BATCH_SIZE,2048,7,7], dtype=DTYPE)
conv5_x = nnvm.symbol.Variable('conv5_x', x)
area = nnvm.symbol.Variable('area', np.full(SHAPE, np.prod(x.shape[-2:]), dtype=DTYPE))
pool6 = nnvm.symbol.broadcast_div(nnvm.symbol.sum(data=conv5_x, axis=[2,3]), area)
graph = nnvm.graph.create(pool6)

with nnvm.compiler.build_config(opt_level=3):
    graph, lib, params = nnvm.compiler.build(
        graph, 'llvm', { 'conv5_x': x.shape }, params={})
```

The full stack trace in my environment (TVM 0.5, CentOS 7.5) is as follows:

```
Traceback (most recent call last):
  File "test5.py", line 19, in <module>
    graph, 'llvm', { 'conv5_x': x.shape }, params={})
  File "/home/xxxxxxxx/anaconda3/envs/work/lib/python3.6/site-packages/nnvm/compiler/build_module.py", line 321, in build
    graph = graph.apply("GraphCompile")
  File "/home/xxxxxxxx/anaconda3/envs/work/lib/python3.6/site-packages/nnvm/graph.py", line 250, in apply
    check_call(_LIB.NNGraphApplyPasses(self.handle, npass, cpass, ctypes.byref(ghandle)))
  File "/home/xxxxxxxx/anaconda3/envs/work/lib/python3.6/site-packages/nnvm/_base.py", line 91, in check_call
    raise NNVMError(py_str(_LIB.NNGetLastError()))
nnvm._base.NNVMError: [15:29:34] /home/xxxxxxxx/local/src/tvm/nnvm/src/compiler/compile_engine.cc:212: Check failed: out[i].ndim() == out_info[i].ndim() (2 vs. 0) : broadcast_div
Stack trace:
  [bt] (0) /home/xxxxxxxx/local/tvm/lib/libnnvm_compiler.so(+0x15f3e9) [0x7f4fdf9753e9]
  [bt] (1) /home/xxxxxxxx/local/tvm/lib/libnnvm_compiler.so(+0x19281c) [0x7f4fdf9a881c]
  [bt] (2) /home/xxxxxxxx/local/tvm/lib/libnnvm_compiler.so(+0x19355c) [0x7f4fdf9a955c]
  [bt] (3) /home/xxxxxxxx/local/tvm/lib/libnnvm_compiler.so(+0x190cc5) [0x7f4fdf9a6cc5]
  [bt] (4) /home/xxxxxxxx/local/tvm/lib/libnnvm_compiler.so(+0x18b948) [0x7f4fdf9a1948]
  [bt] (5) /home/xxxxxxxx/local/tvm/lib/libnnvm_compiler.so(+0x1bc81b) [0x7f4fdf9d281b]
  [bt] (6) /home/xxxxxxxx/local/tvm/lib/libnnvm_compiler.so(+0x17fc88) [0x7f4fdf995c88]
  [bt] (7) /home/xxxxxxxx/local/tvm/lib/libnnvm_compiler.so(+0x217c9f) [0x7f4fdfa2dc9f]
  [bt] (8) /home/xxxxxxxx/local/tvm/lib/libnnvm_compiler.so(+0x2174df) [0x7f4fdfa2d4df]
```