[Relay][Concatenate] Missing shape checking for non-concat axes

Minimal code to reproduce

For example, I am concatenating two tensors with shapes (1, 7, 32, 32) and (1, 7, 16, 16) along axis=1. Such an operation obviously shouldn't be allowed, but the following code executes without error.

import tvm
import numpy as np
import tvm.relay as relay
from tvm.contrib import graph_runtime

dshape1 = (1, 7, 32, 32)
dshape2 = (1, 7, 16, 16)

p1 = relay.var('p1', shape=dshape1)
p2 = relay.var('p2', shape=dshape2)
c = relay.concatenate([p1, p2], axis=1)  # silently accepted despite mismatched non-concat axes

func = relay.Function([p1, p2], c)
relay_module = relay.Module.from_expr(func)
params = {
    'p1': tvm.nd.array(np.random.rand(*dshape1).astype(np.float32)),
    'p2': tvm.nd.array(np.random.rand(*dshape2).astype(np.float32))
}
print(func.astext())

with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(relay_module, target='llvm', params=params)

ctx = tvm.cpu()
module = graph_runtime.create(graph, lib, ctx)
module.set_input(**params)
# run
module.run()
# get output
out = module.get_output(0, tvm.nd.empty((1, 14, 32, 32))).asnumpy()
print(out.shape)
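
For comparison, numpy rejects the same operation up front (error text abridged; exact wording varies by numpy version):

import numpy as np

a = np.random.rand(1, 7, 32, 32).astype(np.float32)
b = np.random.rand(1, 7, 16, 16).astype(np.float32)
np.concatenate([a, b], axis=1)
# ValueError: all the input array dimensions except for the concatenation
# axis must match exactly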

Problem

Looking into the ConcatenateRel() and ConcatenateCompute() C++ implementation, I notice the code only checks ndim (actually 3 times :expressionless: transform.cc#L252, transform.cc#L260, transform.h#L299) and sums up the sizes along the concat axis (transform.cc#L280). However, there is no check that the remaining axes have the same sizes, which is why the invalid program above slips through.

Solution

  1. Add checks for the non-concat axes, along with related test cases.
  2. Instead of checking -ndim <= axis < ndim multiple times, do axis %= ndim at the beginning.

Also, would it be better to add some shape checking on the Relay Python side? That way, users can find where the error is without touching the C++ code. Let me know what you think so I can prepare a PR.

It would be great if we can check it in C++. Some early checks in C++ (when building the nodes) would be helpful as well. A PR is more than welcome.

Instead of checking -ndim <= axis < ndim multiple times, do axis %= ndim at the beginning.

This check exists to stay compatible with numpy semantics. Imagine the case where you have ndim = 2 but axis = -5: numpy rejects it, while axis %= ndim would silently remap it to a valid axis.
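
A quick numpy session illustrating the point (np.AxisError as the exception type is an assumption based on recent numpy versions):

import numpy as np

a = np.zeros((2, 3))
b = np.zeros((2, 3))

# numpy rejects an out-of-range negative axis outright...
try:
    np.concatenate([a, b], axis=-5)
except np.AxisError as e:
    print(e)  # axis -5 is out of bounds for array of dimension 2

# ...while `axis %= ndim` would have silently remapped it to a valid axis:
print(-5 % 2)  # 1 under Python's modulo semantics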

Got it. I agree it would be better to stay compatible with numpy.

What do you think about adding some dtype / dshape checks in the Python layer? For example, in my case, if there were a check in tvm/python/tvm/relay/op/tensor.py, the error would be raised in Python and make debugging much easier.

Current Relay

def concatenate(data, axis):
    data = list(data)
    if not data:
        raise ValueError("relay.concatenate requires data to be non-empty.")
    if not isinstance(axis, int):
        raise ValueError("For now, we only support integer axis")
    return _make.concatenate(Tuple(data), axis)

Current Error message

Traceback (most recent call last):
  File "test.py", line 24, in <module>
    graph, lib, params = relay.build(relay_module, target='llvm', params=params)
  File "/Users/ligeng/Workspace/tvm/python/tvm/relay/build_module.py", line 207, in build
    graph_json, mod, params = bld_mod.build(func, target, target_host, params)
  File "/Users/ligeng/Workspace/tvm/python/tvm/relay/build_module.py", line 108, in build
    self._build(func, target, target_host)
  File "/Users/ligeng/Workspace/tvm/python/tvm/_ffi/_ctypes/function.py", line 210, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) 9   libtvm.dylib                        0x0000000103da3086 tvm::relay::transform::PassNode::operator()(tvm::relay::Module const&) const + 54
  [bt] (7) 8   libtvm.dylib                        0x000000010413c7fe tvm::relay::transform::SequentialNode::operator()(tvm::relay::Module const&, tvm::relay::transform::PassContext const&) const + 1022
  [bt] (6) 7   libtvm.dylib                        0x000000010413cb7c tvm::relay::transform::Pass::operator()(tvm::relay::Module const&, tvm::relay::transform::PassContext const&) const + 156
  [bt] (5) 6   libtvm.dylib                        0x000000010413b237 tvm::relay::transform::FunctionPassNode::operator()(tvm::relay::Module const&, tvm::relay::transform::PassContext const&) const + 1223
  [bt] (4) 5   libtvm.dylib                        0x0000000103e67e00 tvm::relay::ModuleNode::Add(tvm::relay::GlobalVar const&, tvm::relay::Function const&, bool) + 1600
  [bt] (3) 4   libtvm.dylib                        0x0000000104186b48 tvm::relay::InferType(tvm::relay::Function const&, tvm::relay::Module const&, tvm::relay::GlobalVar const&) + 472
  [bt] (2) 3   libtvm.dylib                        0x00000001041859d7 tvm::relay::TypeInferencer::Infer(tvm::relay::Expr) + 135
  [bt] (1) 2   libtvm.dylib                        0x0000000103e32446 tvm::relay::ErrorReporter::RenderErrors(tvm::relay::Module const&, bool) + 5574
  [bt] (0) 1   libtvm.dylib                        0x00000001039aeba9 dmlc::LogMessageFatal::~LogMessageFatal() + 57
  File "/Users/ligeng/Workspace/tvm/src/relay/ir/error.cc", line 133
TVMError: 
Error(s) have occurred. The program has been annotated with them:
v0.0.3
fn (%a: Tensor[(1, 5, 32, 32), float32]) -> Tensor[(1, 7, 32, 32), float32] {
  %0 = layout_transform(%a, src_layout="NCHW", dst_layout="NCHW1c");
  %1 = layout_transform(meta[relay.Constant][0], src_layout="OIHW", dst_layout="OIHW1i4o");
  %2 = nn.contrib_conv2d_NCHWc(%0, %1, meta[relay.attrs.Conv2DAttrs][0]);
  %3 = layout_transform(meta[relay.Constant][1], src_layout="OIHW", dst_layout="OIHW1i3o");
  %4 = nn.contrib_conv2d_NCHWc(%0, %3, meta[relay.attrs.Conv2DAttrs][1]);
  %5 = layout_transform(%4, src_layout="NCHW3c", dst_layout="NCHW4c");
  %6 = (%2, %5);
  %7 = concatenate(%6, axis=1);
  layout_transform(%7, src_layout="NCHW4c", dst_layout="NCHW") in particular dimension 1 conflicts 4 does not match 7; unable to unify: `Tensor[(1, 4, 32, 32), float32]` and `Tensor[(1, 7, 32, 32), float32]`; 
}

Proposed Solution

def concatenate(data, axis):
    data = list(data)
    if not data:
        raise ValueError("relay.concatenate requires data to be non-empty.")
    if not isinstance(axis, int):
        raise ValueError("For now, we only support integer axis")

    # NOTE: this assumes every input is a var with an explicit shape annotation
    shapes = [d.type_annotation.shape for d in data]

    # check that all inputs have the same number of dimensions
    ndims = [len(s) for s in shapes]
    if ndims.count(ndims[0]) != len(ndims):
        raise ValueError('relay.concatenate requires data to have same dimension.')

    # check that all non-concat axes have the same size
    for i in range(ndims[0]):
        if i == (axis % ndims[0]):  # skip the axis being concatenated
            continue
        dim_sizes = [int(s[i]) for s in shapes]
        if len(set(dim_sizes)) != 1:
            raise ValueError("relay.concatenate requires data to have same shape for non-concat axes.")

    return _make.concatenate(Tuple(data), axis)

New Error message

Traceback (most recent call last):
  File "test_3.py", line 11, in <module>
    c = relay.concatenate([p1, p2], axis=0)  # (1, 5, 32, 32)
  File "/Users/ligeng/Workspace/tvm/python/tvm/relay/op/tensor.py", line 722, in concatenate
    raise ValueError("relay.concatenate requires data to have same shape for non-concat axes.")
ValueError: relay.concatenate requires data to have same shape for non-concat axes.

This is more readable and user-friendly.

I suggest keeping this type-checking related work in a single place. Right now it lives inside the C++ Rel function for every op. By doing this we support both the C++ and Python compilation interfaces.

Given that not all expressions have type annotations, and some of them have to be inferred, we still need to have the same check in the type relation.
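
For example (a hypothetical snippet for illustration): only relay.var carries a type_annotation, so a Python-side check cannot see the shape of an intermediate expression such as a Call; that shape only exists after type inference runs:

import tvm.relay as relay

p1 = relay.var('p1', shape=(1, 7, 32, 32))
p2 = relay.var('p2', shape=(1, 7, 16, 16))
mid = relay.nn.relu(p2)  # a Call, not a Var: it has no type_annotation

# p1.type_annotation.shape works, but accessing mid.type_annotation raises
# AttributeError; the shape of `mid` is only known after type inference,
# so the same check must live in the C++ type relation.
c = relay.concatenate([p1, mid], axis=1)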

Got it. I am now drafting a PR to fix it.

What is the proper way to compare the values of tvm::Expr? I notice that at line 280 there is

      concat_dim += e->shape[axis];

where concat_dim and the elements of e->shape are both IndexExpr. However, when I tried to compare them, it threw the error

  transform.cc:281:11: error: value of type 'tvm::Expr' is not contextually convertible to 'bool'

After going through the code, it seems tvm::Expr is just a placeholder and the value is unknown until the model is executed?

Expr can be either an immediate value or an expression that needs to be evaluated at runtime. You can use reporter->AssertEQ.
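
A minimal sketch of how such a check could look inside ConcatenateRel (assuming the TypeReporter API of this TVM version; the merged PR may differ):

// Sketch for src/relay/op/tensor/transform.cc, inside ConcatenateRel.
// `axis` is assumed to already be normalized to [0, ndim).
const auto* first = tensor_tuple->fields[0].as<TensorTypeNode>();
int ndim = static_cast<int>(first->shape.size());
for (const Type& ele : tensor_tuple->fields) {
  const auto* e = ele.as<TensorTypeNode>();
  for (int j = 0; j < ndim; ++j) {
    if (j == axis) continue;  // sizes along the concat axis may differ
    // AssertEQ handles both immediate values and symbolic expressions.
    reporter->AssertEQ(first->shape[j], e->shape[j]);
  }
}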

Thanks for the hint! I have finished the pull request. @tqchen can you have a look?

With the fix, relay.stack reports mismatched non-stacking axes clearly as well:

import tvm.relay as relay

dshape1 = (2, 5, 3, 3)
dshape2 = (2, 3, 4, 3)

p1 = relay.var('p1', shape=dshape1)
p2 = relay.var('p2', shape=dshape2)
c = relay.stack([p1, p2], axis=1)

func = relay.Function([p1, p2], c)
relay_module = relay.Module.from_expr(func)  # type inference reports the error below
# TVMError: 
# Error(s) have occurred. The program has been annotated with them:
# 
# In `main`: 
# v0.0.3
# fn (%p1: Tensor[(2, 5, 3, 3), float32], %p2: Tensor[(2, 3, 4, 3), float32]) {
#   %0 = (%p1, %p2);
#   stack(%0, axis=1) relay.stack requires all tensors have the same shape on non-stacking axes; 
# }