What do you think about adding dtype / shape checks in the Python layer? For example, in my case, if there were a check in tvm/python/tvm/relay/op/tensor.py, the error would be raised in Python, which would make debugging much easier.
Current Relay implementation
def concatenate(data, axis):
    data = list(data)
    if not data:
        raise ValueError("relay.concatenate requires data to be non-empty.")
    if not isinstance(axis, int):
        raise ValueError("For now, we only support integer axis")
    return _make.concatenate(Tuple(data), axis)
Current Error message
Traceback (most recent call last):
File "test.py", line 24, in <module>
graph, lib, params = relay.build(relay_module, target='llvm', params=params)
File "/Users/ligeng/Workspace/tvm/python/tvm/relay/build_module.py", line 207, in build
graph_json, mod, params = bld_mod.build(func, target, target_host, params)
File "/Users/ligeng/Workspace/tvm/python/tvm/relay/build_module.py", line 108, in build
self._build(func, target, target_host)
File "/Users/ligeng/Workspace/tvm/python/tvm/_ffi/_ctypes/function.py", line 210, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
[bt] (8) 9 libtvm.dylib 0x0000000103da3086 tvm::relay::transform::PassNode::operator()(tvm::relay::Module const&) const + 54
[bt] (7) 8 libtvm.dylib 0x000000010413c7fe tvm::relay::transform::SequentialNode::operator()(tvm::relay::Module const&, tvm::relay::transform::PassContext const&) const + 1022
[bt] (6) 7 libtvm.dylib 0x000000010413cb7c tvm::relay::transform::Pass::operator()(tvm::relay::Module const&, tvm::relay::transform::PassContext const&) const + 156
[bt] (5) 6 libtvm.dylib 0x000000010413b237 tvm::relay::transform::FunctionPassNode::operator()(tvm::relay::Module const&, tvm::relay::transform::PassContext const&) const + 1223
[bt] (4) 5 libtvm.dylib 0x0000000103e67e00 tvm::relay::ModuleNode::Add(tvm::relay::GlobalVar const&, tvm::relay::Function const&, bool) + 1600
[bt] (3) 4 libtvm.dylib 0x0000000104186b48 tvm::relay::InferType(tvm::relay::Function const&, tvm::relay::Module const&, tvm::relay::GlobalVar const&) + 472
[bt] (2) 3 libtvm.dylib 0x00000001041859d7 tvm::relay::TypeInferencer::Infer(tvm::relay::Expr) + 135
[bt] (1) 2 libtvm.dylib 0x0000000103e32446 tvm::relay::ErrorReporter::RenderErrors(tvm::relay::Module const&, bool) + 5574
[bt] (0) 1 libtvm.dylib 0x00000001039aeba9 dmlc::LogMessageFatal::~LogMessageFatal() + 57
File "/Users/ligeng/Workspace/tvm/src/relay/ir/error.cc", line 133
TVMError:
Error(s) have occurred. The program has been annotated with them:
v0.0.3
fn (%a: Tensor[(1, 5, 32, 32), float32]) -> Tensor[(1, 7, 32, 32), float32] {
%0 = layout_transform(%a, src_layout="NCHW", dst_layout="NCHW1c");
%1 = layout_transform(meta[relay.Constant][0], src_layout="OIHW", dst_layout="OIHW1i4o");
%2 = nn.contrib_conv2d_NCHWc(%0, %1, meta[relay.attrs.Conv2DAttrs][0]);
%3 = layout_transform(meta[relay.Constant][1], src_layout="OIHW", dst_layout="OIHW1i3o");
%4 = nn.contrib_conv2d_NCHWc(%0, %3, meta[relay.attrs.Conv2DAttrs][1]);
%5 = layout_transform(%4, src_layout="NCHW3c", dst_layout="NCHW4c");
%6 = (%2, %5);
%7 = concatenate(%6, axis=1);
layout_transform(%7, src_layout="NCHW4c", dst_layout="NCHW") in particular dimension 1 conflicts 4 does not match 7; unable to unify: `Tensor[(1, 4, 32, 32), float32]` and `Tensor[(1, 7, 32, 32), float32]`;
}
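For reference, this class of mismatch can be reproduced in isolation. Below is a minimal sketch (the variable names and shapes are my assumptions, not the actual test.py; relay.Module.from_expr is assumed from the TVM version implied by the trace):

import tvm
from tvm import relay

# Non-concat dimensions deliberately disagree (axis 1: 5 vs 7)
# while concatenating along axis 0.
p1 = relay.var("p1", shape=(1, 5, 32, 32), dtype="float32")
p2 = relay.var("p2", shape=(1, 7, 32, 32), dtype="float32")
c = relay.concatenate([p1, p2], axis=0)  # accepted by the Python wrapper

# The mismatch only surfaces here, inside the C++ type inferencer,
# as a backtrace like the one above rather than a Python ValueError.
func = relay.Function([p1, p2], c)
graph, lib, params = relay.build(relay.Module.from_expr(func), target="llvm")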
Proposed Solution
def concatenate(data, axis):
    data = list(data)
    if not data:
        raise ValueError("relay.concatenate requires data to be non-empty.")
    if not isinstance(axis, int):
        raise ValueError("For now, we only support integer axis")
    shapes = [d.type_annotation.shape for d in data]
    # check that all inputs have the same number of dimensions
    ndims = [len(shape) for shape in shapes]
    if ndims.count(ndims[0]) != len(ndims):
        raise ValueError("relay.concatenate requires all inputs to have the same number of dimensions.")
    # check that all dimensions match, except the one being concatenated
    for i in range(ndims[0]):
        if i == axis % ndims[0]:  # skip the concat axis
            continue
        dims = [int(shape[i]) for shape in shapes]
        if len(set(dims)) != 1:
            raise ValueError("relay.concatenate requires all inputs to have the same shape on non-concat axes.")
    return _make.concatenate(Tuple(data), axis)
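One caveat: d.type_annotation only exists on relay.Var, so the check above would raise AttributeError for intermediate expressions, and int() would fail on symbolic dimensions. A small helper could make the check degrade gracefully; this is only a sketch, and _static_shape is my own name, not an existing TVM API:

import tvm

def _static_shape(expr):
    # Only relay.Var carries a type annotation before type inference runs.
    ann = getattr(expr, "type_annotation", None)
    if ann is None or not hasattr(ann, "shape"):
        return None
    dims = []
    for dim in ann.shape:
        if not isinstance(dim, tvm.expr.IntImm):
            return None  # symbolic dimension: leave it to the type inferencer
        dims.append(int(dim.value))
    return dims

concatenate would then run the eager checks only when _static_shape returns a concrete shape for every input, and otherwise fall through to _make.concatenate as before.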
New Error message
Traceback (most recent call last):
File "test_3.py", line 11, in <module>
c = relay.concatenate([p1, p2], axis=0) # (1, 5, 32, 32)
File "/Users/ligeng/Workspace/tvm/python/tvm/relay/op/tensor.py", line 722, in concatenate
raise ValueError("relay.concatenate requires data to have same shape for non-concat axes.")
ValueError: relay.concatenate requires data to have same shape for non-concat axes.
This is more readable and user-friendly.
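The same pattern could cover the dtype part of the question as well; a sketch of a check that would sit next to the shape checks inside concatenate, under the same assumption that every input carries a type annotation:

    # reject mixed-dtype inputs eagerly, in the Python layer
    dtypes = [d.type_annotation.dtype for d in data]
    if len(set(dtypes)) != 1:
        raise ValueError("relay.concatenate requires all inputs to have the same dtype.")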