relay.nn.conv2d asserts a negative output dimension with NHWC data layout and OHWI kernel layout

I am testing nn.conv2d with the input in "NHWC" layout and the filters in "OHWI" layout. The LLVM code that gets generated asserts a negative dimension for the output tensor, and graph_runtime.run() fails on that assertion.

Error

Traceback (most recent call last):

  File "test_conv2d.py", line 92, in <module>
    dilation=1)

  File "test_conv2d.py", line 80, in verify_conv2d_nhwc
    check_device(device)

  File "test_conv2d.py", line 74, in check_device
    graph_runtime.run()

  File "./incubator-tvm/python/tvm/contrib/graph_runtime.py", line 169, in run
    self._run()

  File "./incubator-tvm/python/tvm/_ffi/_ctypes/function.py", line 207, in __call__
    raise get_last_ffi_error()

tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (3) ./incubator-tvm/build/libtvm.so(TVMFuncCall+0x65) [0x7f347ed4dfe5]
  [bt] (2) ./incubator-tvm/build/libtvm.so(tvm::runtime::GraphRuntime::Run()+0x37) [0x7f347edc8477]
  [bt] (1) ./incubator-tvm/build/libtvm.so(+0xc0c3f7) [0x7f347edc83f7]
  [bt] (0) ./incubator-tvm/build/libtvm.so(+0xba4cc4) [0x7f347ed60cc4]
  File "./incubator-tvm/src/runtime/library_module.cc", line 91
TVMError: Check failed: ret == 0 (-1 vs. 0) : Assert fail: (-30 == int32(arg2.shape[1])), Argument arg2.shape[1] has an unsatisfied constraint

LLVM Assertions from "get_source()"

@__TVMAPISetLastError = linkonce dllexport local_unnamed_addr global void (i8*)* null, align 8
@__TVMBackendParallelLaunch = linkonce dllexport local_unnamed_addr global i32 (i32 (i32, %0*, i8*)*, i8*, i32)* null, align 8
@.str = private constant [68 x i8] c"Assert fail: (num_args == 3), fused_nn_conv2d: num_args should be 3\00", align 1
@.str.1 = private constant [143 x i8] c"Assert fail: ((((arg0.code == 3) || (arg0.code == 13)) || (arg0.code == 7)) || (arg0.code == 4)), fused_nn_conv2d: Expect arg[0] to be pointer\00", align 1
@.str.2 = private constant [143 x i8] c"Assert fail: ((((arg1.code == 3) || (arg1.code == 13)) || (arg1.code == 7)) || (arg1.code == 4)), fused_nn_conv2d: Expect arg[1] to be pointer\00", align 1
@.str.3 = private constant [143 x i8] c"Assert fail: ((((arg2.code == 3) || (arg2.code == 13)) || (arg2.code == 7)) || (arg2.code == 4)), fused_nn_conv2d: Expect arg[2] to be pointer\00", align 1
@.str.4 = private constant [55 x i8] c"Assert fail: (dev_type == 1), device_type need to be 1\00", align 1
@.str.5 = private constant [81 x i8] c"Assert fail: (4 == tvm_struct_get(arg0, 0, 4)), arg0.ndim is expected to equal 4\00", align 1
@.str.6 = private constant [186 x i8] c"Assert fail: (((tvm_struct_get(arg0, 0, 5) == (uint8)2) && (tvm_struct_get(arg0, 0, 6) == (uint8)32)) && (tvm_struct_get(arg0, 0, 7) == (uint16)1)), arg0.dtype is expected to be float32\00", align 1
@.str.7 = private constant [95 x i8] c"Assert fail: (1 == int32(arg0.shape[0])), Argument arg0.shape[0] has an unsatisfied constraint\00", align 1
@.str.8 = private constant [95 x i8] c"Assert fail: (1 == int32(arg0.shape[1])), Argument arg0.shape[1] has an unsatisfied constraint\00", align 1
@.str.9 = private constant [95 x i8] c"Assert fail: (1 == int32(arg0.shape[2])), Argument arg0.shape[2] has an unsatisfied constraint\00", align 1
@.str.10 = private constant [96 x i8] c"Assert fail: (32 == int32(arg0.shape[3])), Argument arg0.shape[3] has an unsatisfied constraint\00", align 1
@.str.11 = private constant [195 x i8] c"Assert fail: ((((1 == int32(arg0.strides[3])) && (32 == int32(arg0.strides[2]))) && (32 == int32(arg0.strides[1]))) && (32 == int32(arg0.strides[0]))), arg0.strides: expected to be compact array\00", align 1
@.str.12 = private constant [112 x i8] c"Assert fail: ((uint64)0 == tvm_struct_get(arg0, 0, 8)), Argument arg0.byte_offset has an unsatisfied constraint\00", align 1
@.str.13 = private constant [81 x i8] c"Assert fail: (4 == tvm_struct_get(arg1, 0, 4)), arg1.ndim is expected to equal 4\00", align 1
@.str.14 = private constant [186 x i8] c"Assert fail: (((tvm_struct_get(arg1, 0, 5) == (uint8)2) && (tvm_struct_get(arg1, 0, 6) == (uint8)32)) && (tvm_struct_get(arg1, 0, 7) == (uint16)1)), arg1.dtype is expected to be float32\00", align 1
@.str.15 = private constant [96 x i8] c"Assert fail: (32 == int32(arg1.shape[0])), Argument arg1.shape[0] has an unsatisfied constraint\00", align 1
@.str.16 = private constant [95 x i8] c"Assert fail: (1 == int32(arg1.shape[1])), Argument arg1.shape[1] has an unsatisfied constraint\00", align 1
@.str.17 = private constant [95 x i8] c"Assert fail: (1 == int32(arg1.shape[2])), Argument arg1.shape[2] has an unsatisfied constraint\00", align 1
@.str.18 = private constant [96 x i8] c"Assert fail: (32 == int32(arg1.shape[3])), Argument arg1.shape[3] has an unsatisfied constraint\00", align 1
@.str.19 = private constant [195 x i8] c"Assert fail: ((((1 == int32(arg1.strides[3])) && (32 == int32(arg1.strides[2]))) && (32 == int32(arg1.strides[1]))) && (32 == int32(arg1.strides[0]))), arg1.strides: expected to be compact array\00", align 1
@.str.20 = private constant [112 x i8] c"Assert fail: ((uint64)0 == tvm_struct_get(arg1, 0, 8)), Argument arg1.byte_offset has an unsatisfied constraint\00", align 1
@.str.21 = private constant [105 x i8] c"Assert fail: (1 == tvm_struct_get(arg1, 0, 10)), Argument arg1.device_type has an unsatisfied constraint\00", align 1
@.str.22 = private constant [107 x i8] c"Assert fail: (dev_id == tvm_struct_get(arg1, 0, 9)), Argument arg1.device_id has an unsatisfied constraint\00", align 1
@.str.23 = private constant [81 x i8] c"Assert fail: (4 == tvm_struct_get(arg2, 0, 4)), arg2.ndim is expected to equal 4\00", align 1
@.str.24 = private constant [186 x i8] c"Assert fail: (((tvm_struct_get(arg2, 0, 5) == (uint8)2) && (tvm_struct_get(arg2, 0, 6) == (uint8)32)) && (tvm_struct_get(arg2, 0, 7) == (uint16)1)), arg2.dtype is expected to be float32\00", align 1
@.str.25 = private constant [95 x i8] c"Assert fail: (1 == int32(arg2.shape[0])), Argument arg2.shape[0] has an unsatisfied constraint\00", align 1
@.str.26 = private constant [97 x i8] c"Assert fail: (-30 == int32(arg2.shape[1])), Argument arg2.shape[1] has an unsatisfied constraint\00", align 1
@.str.27 = private constant [95 x i8] c"Assert fail: (1 == int32(arg2.shape[2])), Argument arg2.shape[2] has an unsatisfied constraint\00", align 1
@.str.28 = private constant [96 x i8] c"Assert fail: (32 == int32(arg2.shape[3])), Argument arg2.shape[3] has an unsatisfied constraint\00", align 1
@.str.29 = private constant [197 x i8] c"Assert fail: ((((1 == int32(arg2.strides[3])) && (32 == int32(arg2.strides[2]))) && (32 == int32(arg2.strides[1]))) && (-960 == int32(arg2.strides[0]))), arg2.strides: expected to be compact array\00", align 1
@.str.30 = private constant [112 x i8] c"Assert fail: ((uint64)0 == tvm_struct_get(arg2, 0, 8)), Argument arg2.byte_offset has an unsatisfied constraint\00", align 1
@.str.31 = private constant [105 x i8] c"Assert fail: (1 == tvm_struct_get(arg2, 0, 10)), Argument arg2.device_type has an unsatisfied constraint\00", align 1
@.str.32 = private constant [107 x i8] c"Assert fail: (dev_id == tvm_struct_get(arg2, 0, 9)), Argument arg2.device_id has an unsatisfied constraint\00", align 1
@__tvm_main__ = weak local_unnamed_addr constant [16 x i8] c"fused_nn_conv2d\00", align 1

I would expect this line

@.str.26 = private constant [97 x i8] c"Assert fail: (-30 == int32(arg2.shape[1])), Argument arg2.shape[1] has an unsatisfied constraint\00", align 1

to instead be

@.str.26 = private constant [97 x i8] c"Assert fail: (1 == int32(arg2.shape[1])), Argument arg2.shape[1] has an unsatisfied constraint\00", align 1

because the output shape should be (1, 1, 1, 32).

Test script

import os
import tvm
import topi
from tvm import relay
from tvm.relay.expr_functor import ExprMutator
from tvm.relay.frontend.common import infer_type
import numpy as np

import topi.testing

np.random.seed(0)


def get_reference_data(A_shape, W_shape, dtype):
    """Create random input and weight tensors for the conv2d test.

    Both tensors are drawn with ``np.random.uniform`` (default range), cast
    to ``dtype`` and wrapped with ``tvm.nd.array``.

    Parameters
    ----------
    A_shape : shape of the input tensor, stored under key ``"A"``.
    W_shape : shape of the weight tensor, stored under key ``"W"``.
    dtype : dtype both arrays are cast to.

    Returns
    -------
    (params, inputs) : ``params`` is ``{"W": weight}`` and ``inputs`` is
        ``{"A": input}``.
    """
    # The input is drawn before the weight; keeping this order means a fixed
    # seed produces the same data as before.
    input_np = np.random.uniform(size=A_shape).astype(dtype)
    weight_np = np.random.uniform(size=W_shape).astype(dtype)

    return {"W": tvm.nd.array(weight_np)}, {"A": tvm.nd.array(input_np)}


def verify_conv2d_nhwc(target_list, batch,
                       in_channel,
                       in_size,
                       num_filter,
                       kernel,
                       stride,
                       padding,
                       dilation=1):
    """Build and run a single ``relay.nn.conv2d`` on each requested target.

    The convolution uses ``data_layout='NHWC'`` and ``kernel_layout='OHWI'``
    with float32 data. For every entry of ``target_list`` the module is built
    at opt_level 3, the Relay IR, graph JSON and generated source are printed,
    random inputs are fed in, and the shape of the output is printed.

    Parameters
    ----------
    target_list : iterable of target/device names (e.g. ``"llvm"``).
    batch, in_channel, in_size, num_filter, kernel : sizes used to form the
        input shape ``(batch, in_size, in_size, in_channel)`` and the weight
        shape ``(num_filter, kernel, kernel, in_channel)``.
    stride, padding, dilation : scalars applied to both spatial axes.

    Raises
    ------
    Exception
        If TVM reports that a requested device is not enabled.
    """
    def check_device(device):
        """Compile, run and report the convolution for one device."""
        if not tvm.module.enabled(device):
            raise Exception("TVM was not built with %s runtime enabled" % device)

        dtype = "float32"
        out_dtype = "float32"

        # Square spatial extent: height and width both come from in_size.
        in_height = in_width = in_size
        data_shape = (batch, in_height, in_width, in_channel)
        weight_shape = (num_filter, kernel, kernel, in_channel)

        # NOTE(review): with the arguments used in this script the failing
        # assertion expects -30 for output dim 1, which equals
        # in_size - num_filter + 1 (1 - 32 + 1). That would be consistent with
        # the first kernel axis being read as the kernel height -- confirm
        # against the NHWC conv2d compute implementation.
        data = relay.var("A", shape=data_shape, dtype=dtype)
        weight = relay.var("W", shape=weight_shape, dtype=dtype)
        conv_expr = relay.nn.conv2d(
            data,
            weight,
            strides=(stride, stride),
            padding=(padding, padding),
            dilation=(dilation, dilation),
            kernel_size=(kernel, kernel),
            data_layout='NHWC',
            kernel_layout='OHWI',
            out_dtype=out_dtype,
        )

        # Wrap the expression in a Relay module and compile it for the device.
        relay_ir = relay.Module.from_expr(conv_expr)
        ctx = tvm.context(device)

        with relay.build_config(opt_level=3):
            # The params returned by the build are not used; random ones are
            # supplied below instead.
            graph_json, lib, _ = relay.build_module.build(relay_ir, device)
            runtime = tvm.contrib.graph_runtime.create(graph_json, lib, ctx)

        # Dump every stage so the generated assertions can be inspected.
        print(relay_ir)
        print(graph_json)
        print(lib.get_source())

        params, inputs = get_reference_data(data_shape, weight_shape, dtype)

        runtime.set_input(**params)
        runtime.set_input(**inputs)
        runtime.run()

        result = runtime.get_output(0).asnumpy()
        print("Graph runtime output:", result.shape)

    for device in target_list:
        check_device(device)


if __name__ == '__main__':
    # Reproduction case: a 1x1 convolution over a 1x1 spatial input with
    # 32 input channels and 32 filters, run on the LLVM target only.
    target_list = ["llvm"]
    conv_config = {
        "batch": 1,
        "in_channel": 32,
        "in_size": 1,
        "num_filter": 32,
        "kernel": 1,
        "stride": 1,
        "padding": 0,
        "dilation": 1,
    }
    verify_conv2d_nhwc(target_list, **conv_config)

The output dimensions appear correct in graph_json, so the issue seems to be only in the generated code.