NMS compile fails for CUDA target, but works fine for LLVM target

Hi:

I am trying to compile a small TensorFlow graph that contains an NMS operator. It works fine with the LLVM target, but it fails with the CUDA target.

The output log:

Traceback (most recent call last):
  File "debug_nms.py", line 65, in <module>
    nms_lab()
  File "debug_nms.py", line 50, in nms_lab
    exe = relay.vm.compile(mod, target=target, params=params)
  File "/root/Codes/tvm_in_mac/python/tvm/relay/backend/vm.py", line 69, in compile
    compiler.lower(mod, target, target_host)
  File "/root/Codes/tvm_in_mac/python/tvm/relay/backend/vm.py", line 135, in lower
    self._lower(mod, target, target_host)
  File "/root/Codes/tvm_in_mac/python/tvm/_ffi/_ctypes/packed_func.py", line 225, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /root/Codes/tvm_in_mac/build/libtvm.so(tvm::relay::vm::VMFunctionCompiler::EmitInvokeTVMOp(tvm::relay::Function const&, tvm::RelayExpr const&, tvm::RelayExpr const&)+0x91d) [0x7f60fa807d9d]
  [bt] (7) /root/Codes/tvm_in_mac/build/libtvm.so(tvm::relay::CompileEngineImpl::Lower(tvm::relay::CCacheKey const&)+0x20) [0x7f60fa7b5b60]
  [bt] (6) /root/Codes/tvm_in_mac/build/libtvm.so(tvm::relay::CompileEngineImpl::LowerInternal(tvm::relay::CCacheKey const&)+0x75a) [0x7f60fa7b4c4a]
  [bt] (5) /root/Codes/tvm_in_mac/build/libtvm.so(tvm::relay::ScheduleGetter::Create(tvm::relay::Function const&)+0x94b) [0x7f60fa7b399b]
  [bt] (4) /root/Codes/tvm_in_mac/build/libtvm.so(tvm::relay::backend::MemoizedExprTranslator<tvm::runtime::Array<tvm::te::Tensor, void> >::VisitExpr(tvm::RelayExpr const&)+0xa6) [0x7f60fa7b7b66]
  [bt] (3) /root/Codes/tvm_in_mac/build/libtvm.so(tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)+0x89) [0x7f60fa7b78f9]
  [bt] (2) /root/Codes/tvm_in_mac/build/libtvm.so(tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*)#6}::_FUN(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*)+0x27) [0x7f60fa7aa107]
  [bt] (1) /root/Codes/tvm_in_mac/build/libtvm.so(tvm::relay::ScheduleGetter::VisitExpr_(tvm::relay::CallNode const*)+0x534) [0x7f60fa7b2084]
  [bt] (0) /root/Codes/tvm_in_mac/build/libtvm.so(+0x15bfa6b) [0x7f60fa924a6b]
  File "/root/Codes/tvm_in_mac/python/tvm/_ffi/_ctypes/packed_func.py", line 78, in cfun
    rv = local_pyfunc(*pyargs)
  File "/root/Codes/tvm_in_mac/python/tvm/relay/backend/compile_engine.py", line 263, in lower_call
    op, call.attrs, inputs, ret_type, target)
  File "/root/Codes/tvm_in_mac/python/tvm/relay/backend/compile_engine.py", line 199, in select_implementation
    outs = impl.compute(attrs, inputs, out_type)
  File "/root/Codes/tvm_in_mac/python/tvm/relay/op/op.py", line 89, in compute
    return _OpImplementationCompute(self, attrs, inputs, out_type)
  File "/root/Codes/tvm_in_mac/python/tvm/_ffi/_ctypes/packed_func.py", line 225, in __call__
    raise get_last_ffi_error()
  [bt] (4) /root/Codes/tvm_in_mac/build/libtvm.so(TVMFuncCall+0x61) [0x7f60fa928651]
  [bt] (3) /root/Codes/tvm_in_mac/build/libtvm.so(+0x15009cd) [0x7f60fa8659cd]
  [bt] (2) /root/Codes/tvm_in_mac/build/libtvm.so(tvm::relay::OpImplementation::Compute(tvm::Attrs const&, tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Type const&)+0xbc) [0x7f60fa8657ac]
  [bt] (1) /root/Codes/tvm_in_mac/build/libtvm.so(tvm::runtime::Array<tvm::te::Tensor, void> tvm::runtime::TVMPODValue_::AsObjectRef<tvm::runtime::Array<tvm::te::Tensor, void> >() const+0x3af) [0x7f60fa1290df]
  [bt] (0) /root/Codes/tvm_in_mac/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x67) [0x7f60f9f5c907]
  File "/root/Codes/tvm_in_mac/include/tvm/runtime/packed_func.h", line 1402
TVMError: Check failed: ObjectTypeChecker<TObjectRef>: :Check(ptr): Expect Array[Tensor] but get Tensor

The test script


import os

import numpy as np
import tensorflow as tf
import tvm
from tvm import relay
from tvm import te


# ---- Reproducer configuration ------------------------------------------
# Element type shared by the TensorFlow placeholders and the random inputs.
dtype = "float32"

# Shapes of the two inputs: 20 boxes with 4 coordinates each, one score per box.
bx_shape, score_shape = (20, 4), (20,)

# Arguments forwarded to tf.image.non_max_suppression.
out_size = 10
iou_threshold, score_threshold = 0.5, 0.6

# Random inputs (unseeded, so they differ from run to run). `boxes` is drawn
# before `scores`, in the same order as the original script.
boxes = np.random.uniform(0, 10, size=bx_shape).astype(dtype)
scores = np.random.uniform(size=score_shape).astype(dtype)

def get_nms_relay_ir_module():
    """Build a TensorFlow graph holding one NMS op and import it into Relay.

    The graph has two placeholders, "in_data_1" (shaped like the module-level
    `boxes`) and "in_data_2" (shaped like the module-level `scores`), feeding
    `tf.image.non_max_suppression` under the name "nms".

    Returns
    -------
    (mod, params)
        The pair produced by `relay.frontend.from_tensorflow` for the output
        node 'nms/NonMaxSuppressionV3:0'.
    """
    graph = tf.Graph()
    with graph.as_default():
        boxes_ph = tf.placeholder(dtype, boxes.shape, name="in_data_1")
        scores_ph = tf.placeholder(dtype, scores.shape, name="in_data_2")
        # The op is added to `graph` as a side effect; the returned tensor is
        # not needed because the output is looked up by node name below.
        tf.image.non_max_suppression(
            boxes=boxes_ph,
            scores=scores_ph,
            max_output_size=out_size,
            iou_threshold=iou_threshold,
            score_threshold=score_threshold,
            name="nms",
        )

    # Map each placeholder name to the shape of the array that will feed it.
    input_shapes = {
        "in_data_1": boxes.shape,
        "in_data_2": scores.shape,
    }

    relay_mod, relay_params = relay.frontend.from_tensorflow(
        graph.as_graph_def(),
        layout="NCHW",
        shape=input_shapes,
        outputs=["nms/NonMaxSuppressionV3:0"],
    )
    return relay_mod, relay_params


def nms_lab(*, use_gpu=True):
    """Compile the NMS module with the Relay VM, reload it, and run it once.

    Parameters
    ----------
    use_gpu : bool, optional
        When true (the default, matching the value that used to be hard-coded
        here) the module is compiled for the "cuda" target and run on
        `tvm.gpu()`; otherwise it is compiled for "llvm" and run on
        `tvm.cpu()`. Pass `use_gpu=False` to exercise the LLVM path.

    The function prints the type and value of the VM result and returns None.
    """
    if use_gpu:
        target = "cuda"
        context = tvm.gpu()
    else:
        target = "llvm"
        context = tvm.cpu()

    mod, params = get_nms_relay_ir_module()
    # This is the call that appears in the reported traceback.
    exe = relay.vm.compile(mod, target=target, params=params)

    # Round-trip the executable through save()/load_exec() before running it.
    code, lib = exe.save()
    des_exec = tvm.runtime.vm.Executable.load_exec(code, lib)
    des_vm = tvm.runtime.vm.VirtualMachine(des_exec)
    des_vm.init(context)

    # NOTE(review): inputs are passed positionally as (scores, boxes), exactly
    # as in the original script; presumably this matches the parameter order
    # of the converted Relay function - confirm before changing.
    args = (scores, boxes)

    print('\n****** Start the running **********')
    ret = des_vm.run(*args)
    print('ret type: ', type(ret))
    print('ret: ', ret)

# Run the reproducer only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    nms_lab()

I believe this bug happens because the non_max_suppression function in topi/python/topi/cuda/nms.py directly returns box_indices, which is a single tensor rather than an array (list) of tensors.

Besides, in the non_max_suppression function in topi/python/topi/vision/nms.py, when return_indices is true, a function called hybrid_rearrange_indices_out is called to generate two tensors: one with shape [batch_size, num_anchors] and another with shape [batch_size, 1].

Therefore, I think a similar function should be implemented for the non_max_suppression function in topi/python/topi/cuda/nms.py.

Here is a minimal test to reproduce the problem.

from tvm import relay

# Target that triggers the reported failure.
target = "cuda"

# Input variable for the NMS op.
x = relay.var("x", shape=(1, 100, 6))

# Constant second and third arguments to non_max_suppression, built in the
# same order as in the original one-line call.
valid_count = relay.const([100])
indices = relay.const([list(range(100))])

y = relay.vision.non_max_suppression(x, valid_count, indices)

# Only the first element of the op's result is used as the function body.
func = relay.Function([x], y[0])

relay.build(func, target)