Can't deploy SSD with GluonCV

I tried to deploy SSD with GluonCV by following the TVM docs. When I ran the tutorial script, I got the following error:

Traceback (most recent call last):
  File "deploy_ssd_gluoncv.py", line 105, in <module>
    graph, lib, params = compile(target)
  File "deploy_ssd_gluoncv.py", line 80, in compile
    net, params = relay.frontend.from_mxnet(block, {"data": dshape})
  File "/Users/cxt123/Reposities/tvm/python/tvm/relay/frontend/mxnet.py", line 612, in from_mxnet
    sym = _from_mxnet_impl(sym, shape, dtype)
  File "/Users/cxt123/Reposities/tvm/python/tvm/relay/frontend/mxnet.py", line 504, in _from_mxnet_impl
    jgraph = json.loads(symbol.tojson())

And here is the related code in mxnet.py:

    if isinstance(symbol, mx.sym.Symbol):
        params = {}
        arg_params = arg_params if arg_params else {}
        aux_params = aux_params if aux_params else {}
        for k, v in arg_params.items():
            params[k] = _nd.array(v.asnumpy())
        for k, v in aux_params.items():
            params[k] = _nd.array(v.asnumpy())
        shape, dtype = _update_shape_dtype(shape, dtype, params)
        sym = _from_mxnet_impl(symbol, shape, dtype)
    elif isinstance(symbol, mx.gluon.HybridBlock):
        if arg_params is not None or aux_params is not None:
            raise ValueError("arg_params and aux_params ae not used when importing HybridBlock")
        params = {}
        for k, v in symbol.collect_params().items():
            params[k] = _nd.array(v.data().asnumpy())
        data = mx.sym.Variable("data")
        sym = symbol(data)
        shape, dtype = _update_shape_dtype(shape, dtype, params)
        sym = _from_mxnet_impl(sym, shape, dtype)

def _from_mxnet_impl(symbol, shape_dict, dtype_info):
    assert symbol is not None
    jgraph = json.loads(symbol.tojson())
    jnodes = jgraph["nodes"]
    node_map = {}

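The root cause of this error is that `sym` is a tuple rather than a single `Symbol`. The SSD `HybridBlock` has several outputs (class IDs, scores, and bounding boxes), so calling it on a `Variable` returns a tuple of symbols, and a tuple has no `tojson()` method. It looks like a small mistake in how the frontend fetches the symbol from a hybrid block model, but I don't know how to fix it.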

My environment:

macOS: 10.14.3 
mxnet-mkl : 1.3
gluoncv: 0.3
tvm built from source commit id:  fd9fa4bf

I modified several lines in mxnet.py to fetch the symbol from the hybrid block, following the guide "Quick tip: converting Gluon models to symbolic format".
I replaced

        data = mx.sym.Variable("data")
        sym = symbol(data)

with

        import numpy as np
        import mxnet as mx
        model_name = "ssd_512_resnet50_v1_voc"
        symbol.hybridize()
        x = np.zeros([1, 3, 512, 512])
        x = mx.nd.array(x)
        symbol.forward(x)
        symbol.export(model_name)
        sym, _, _ = mx.model.load_checkpoint(model_name, 0)

This fixed the `tojson` exception, but another error was raised:

/Users/cxt123/Applications/anaconda3/envs/conda36_base/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
[09:46:18] src/operator/nn/mkldnn/mkldnn_base.cc:74: Allocate 147456 bytes with malloc directly
[09:46:18] src/operator/nn/mkldnn/mkldnn_base.cc:74: Allocate 589824 bytes with malloc directly
[09:46:18] src/operator/nn/mkldnn/mkldnn_base.cc:74: Allocate 2359296 bytes with malloc directly
[09:46:18] src/operator/nn/mkldnn/mkldnn_base.cc:74: Allocate 9437184 bytes with malloc directly
Traceback (most recent call last):
  File "deploy_ssd_gluoncv.py", line 94, in <module>
    graph, lib, params = compile(target)
  File "deploy_ssd_gluoncv.py", line 70, in compile
    net, params = relay.frontend.from_mxnet(block, {"data": dshape})
  File "/Users/cxt123/Reposities/tvm/python/tvm/relay/frontend/mxnet.py", line 625, in from_mxnet
    sym = _from_mxnet_impl(sym, shape, dtype)
  File "/Users/cxt123/Reposities/tvm/python/tvm/relay/frontend/mxnet.py", line 521, in _from_mxnet_impl
    res = _convert_map[op_name](children, attrs)
  File "/Users/cxt123/Reposities/tvm/python/tvm/relay/frontend/mxnet.py", line 199, in _mx_slice_axis
    shape = ir_pass.infer_type(inputs[0]).checked_type.shape
  File "/Users/cxt123/Reposities/tvm/python/tvm/relay/ir_pass.py", line 45, in infer_type
    return _ir_pass.infer_type(expr, mod)
  File "/Users/cxt123/Reposities/tvm/python/tvm/_ffi/_ctypes/function.py", line 185, in __call__
    ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
  File "/Users/cxt123/Reposities/tvm/python/tvm/_ffi/base.py", line 71, in check_call
    raise TVMError(py_str(_LIB.TVMGetLastError()))
tvm._ffi.base.TVMError: [09:46:19] /Users/cxt123/Reposities/tvm/src/relay/ir/error.cc:112:
Error(s) have occurred. We have annotated the program with them:

Here is the detailed error message:

/Users/cxt123/Applications/anaconda3/envs/conda36_base/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
[09:54:53] src/operator/nn/mkldnn/mkldnn_base.cc:74: Allocate 147456 bytes with malloc directly
[09:54:53] src/operator/nn/mkldnn/mkldnn_base.cc:74: Allocate 589824 bytes with malloc directly
[09:54:53] src/operator/nn/mkldnn/mkldnn_base.cc:74: Allocate 2359296 bytes with malloc directly
[09:54:53] src/operator/nn/mkldnn/mkldnn_base.cc:74: Allocate 9437184 bytes with malloc directly
Traceback (most recent call last):
  File "deploy_ssd_gluoncv.py", line 94, in <module>
    graph, lib, params = compile(target)
  File "deploy_ssd_gluoncv.py", line 70, in compile
    net, params = relay.frontend.from_mxnet(block, {"data": dshape})
  File "/Users/cxt123/Reposities/tvm/python/tvm/relay/frontend/mxnet.py", line 625, in from_mxnet
    sym = _from_mxnet_impl(sym, shape, dtype)
  File "/Users/cxt123/Reposities/tvm/python/tvm/relay/frontend/mxnet.py", line 521, in _from_mxnet_impl
    res = _convert_map[op_name](children, attrs)
  File "/Users/cxt123/Reposities/tvm/python/tvm/relay/frontend/mxnet.py", line 199, in _mx_slice_axis
    shape = ir_pass.infer_type(inputs[0]).checked_type.shape
  File "/Users/cxt123/Reposities/tvm/python/tvm/relay/ir_pass.py", line 45, in infer_type
    return _ir_pass.infer_type(expr, mod)
  File "/Users/cxt123/Reposities/tvm/python/tvm/_ffi/_ctypes/function.py", line 185, in __call__
    ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
  File "/Users/cxt123/Reposities/tvm/python/tvm/_ffi/base.py", line 71, in check_call
    raise TVMError(py_str(_LIB.TVMGetLastError()))
tvm._ffi.base.TVMError: [09:54:54] /Users/cxt123/Reposities/tvm/src/relay/ir/error.cc:112:
Error(s) have occurred. We have annotated the program with them:

In `main`:
fn () {
  free_var %data: Tensor[(1, 3, 512, 512), float32]
  free_var %ssd0_resnetv10_conv0_weight: Tensor[(64, 3, 7, 7), float32]
  %0 = nn.conv2d(%data, %ssd0_resnetv10_conv0_weight, strides=[2, 2], padding=[3, 3], channels=64, kernel_size=[7, 7]) #
  free_var %ssd0_resnetv10_batchnorm0_gamma: Tensor[(64,), float32]
  free_var %ssd0_resnetv10_batchnorm0_beta: Tensor[(64,), float32]
  free_var %ssd0_resnetv10_batchnorm0_running_mean: Tensor[(64,), float32]
  free_var %ssd0_resnetv10_batchnorm0_running_var: Tensor[(64,), float32]
  %1 = nn.batch_norm(%0, %ssd0_resnetv10_batchnorm0_gamma, %ssd0_resnetv10_batchnorm0_beta, %ssd0_resnetv10_batchnorm0_running_mean, %ssd0_resnetv10_batchnorm0_running_var) #
  %2 = %1.0
  %3 = nn.relu(%2) #
  %4 = nn.max_pool2d(%3, pool_size=[3, 3], strides=[2, 2], padding=[1, 1]) #
  free_var %ssd0_resnetv10_stage1_conv0_weight: Tensor[(64, 64, 1, 1), float32]
...
...
  %340 = add(%321, 18f) #
  %341 = add(%321, 19f) #
  %342 = (%322, %323, %324, %325, %326, %327, %328, %329, %330, %331, %332, %333, %334, %335, %336, %337, %338, %339, %340, %341)
  %343 = concatenate(%342, axis=-1) #
  %344 = ones_like(%343) #
  %345 = multiply(%344, -1f) #
  %346 = where(%319, %343, %345) # an internal invariant was violdated whiletypechecking your program[09:54:54] /Users/cxt123/Reposities/tvm/src/relay/op/tensor/transform.cc:965: Check failed: reporter->AssertEQ(cond_shape[i], x_shape[i]) Shape of condition [1, 6132, 21] must be either equal to x or has dimension of 1.

Stack trace returned 10 entries:
[bt] (0) 0   libtvm.dylib                        0x000000010cc45140 dmlc::StackTrace(unsigned long) + 464
[bt] (1) 1   libtvm.dylib                        0x000000010cc44e24 dmlc::LogMessageFatal::~LogMessageFatal() + 52
[bt] (2) 2   libtvm.dylib                        0x000000010d171965 tvm::relay::WhereRel(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&) + 2629
[bt] (3) 3   libtvm.dylib                        0x000000010d0800cf void tvm::runtime::detail::unpack_call_dispatcher<bool, 0, 4, bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::run<tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue>(bool (* const&)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&) + 95
[bt] (4) 4   libtvm.dylib                        0x000000010d080029 std::__1::__function::__func<void tvm::runtime::TypedPackedFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>(bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*), std::__1::allocator<void tvm::runtime::TypedPackedFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>(bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 137
[bt] (5) 5   libtvm.dylib                        0x000000010d28bd3d tvm::TypedEnvFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::operator()(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&) const + 365
[bt] (6) 6   libtvm.dylib                        0x000000010d28b665 tvm::relay::TypeSolver::Solve() + 1125
[bt] (7) 7   libtvm.dylib                        0x000000010d26e7a4 tvm::relay::TypeInferencer::Infer(tvm::relay::Expr) + 116
[bt] (8) 8   libtvm.dylib                        0x000000010d26f948 tvm::relay::InferType(tvm::relay::Function const&, tvm::relay::Module const&, tvm::relay::GlobalVar const&) + 472
[bt] (9) 9   libtvm.dylib                        0x000000010d036f08 tvm::relay::ModuleNode::Add(tvm::relay::GlobalVar const&, tvm::relay::Function const&, bool) + 152

;
  %346
}



Stack trace returned 10 entries:
[bt] (0) 0   libtvm.dylib                        0x000000010cc45140 dmlc::StackTrace(unsigned long) + 464
[bt] (1) 1   libtvm.dylib                        0x000000010cc44e24 dmlc::LogMessageFatal::~LogMessageFatal() + 52
[bt] (2) 2   libtvm.dylib                        0x000000010cffe32e tvm::relay::ErrorReporter::RenderErrors(tvm::relay::Module const&, bool) + 5406
[bt] (3) 3   libtvm.dylib                        0x000000010d26e7c0 tvm::relay::TypeInferencer::Infer(tvm::relay::Expr) + 144
[bt] (4) 4   libtvm.dylib                        0x000000010d26f948 tvm::relay::InferType(tvm::relay::Function const&, tvm::relay::Module const&, tvm::relay::GlobalVar const&) + 472
[bt] (5) 5   libtvm.dylib                        0x000000010d036f08 tvm::relay::ModuleNode::Add(tvm::relay::GlobalVar const&, tvm::relay::Function const&, bool) + 152
[bt] (6) 6   libtvm.dylib                        0x000000010d03870f tvm::relay::ModuleNode::FromExpr(tvm::relay::Expr const&, tvm::Map<tvm::relay::GlobalVar, tvm::relay::Function, void, void> const&) + 943
[bt] (7) 7   libtvm.dylib                        0x000000010d26f3c9 tvm::relay::InferType(tvm::relay::Expr const&, tvm::relay::Module const&) + 793
[bt] (8) 8   libtvm.dylib                        0x000000010d28954f std::__1::__function::__func<tvm::relay::$_1, std::__1::allocator<tvm::relay::$_1>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 159
[bt] (9) 9   libtvm.dylib                        0x000000010d4d5266 TVMFuncCall + 70

I cannot reproduce the error! I’m on mxnet-cu9.0==1.4

Could you paste the full exception message that you receive, not just the traceback?

There is no other message. Here is the full output when I run `python deploy_ssd_gluoncv.py`:

/Users/cxt123/Applications/anaconda3/envs/conda36_base/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
Traceback (most recent call last):
  File "deploy_ssd_gluoncv.py", line 104, in <module>
    graph, lib, params = compile(target)
  File "deploy_ssd_gluoncv.py", line 80, in compile
    net, params = relay.frontend.from_mxnet(block, {"data": dshape})
  File "/Users/cxt123/Reposities/tvm/python/tvm/relay/frontend/mxnet.py", line 612, in from_mxnet
    sym = _from_mxnet_impl(sym, shape, dtype)
  File "/Users/cxt123/Reposities/tvm/python/tvm/relay/frontend/mxnet.py", line 504, in _from_mxnet_impl
    jgraph = json.loads(symbol.tojson())
AttributeError: 'tuple' object has no attribute 'tojson'

Could my MXNet or GluonCV version be too old and cause this exception?

Maybe my MXNet version is too old? It's mxnet-mkl 1.3.0.

I reproduced this on an older TVM version; updating to the latest TVM fixes it.

Thanks for your suggestion. This bug is fixed by grouping the tuple of output symbols into a single symbol (e.g. with `mx.sym.Group`). But now there is another error:

/Users/cxt123/Applications/anaconda3/envs/conda36_base/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 3, 512, 512, 'float32'), (64, 3, 7, 7, 'float32'), (2, 2), (3, 3), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 64, 128, 128, 'float32'), (320, 64, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 64, 128, 128, 'float32'), (64, 64, 3, 3, 'float32'), (1, 1), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 64, 128, 128, 'float32'), (256, 64, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 256, 128, 128, 'float32'), (64, 256, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 256, 128, 128, 'float32'), (640, 256, 1, 1, 'float32'), (2, 2), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 128, 64, 64, 'float32'), (128, 128, 3, 3, 'float32'), (1, 1), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 128, 64, 64, 'float32'), (512, 128, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 512, 64, 64, 'float32'), (128, 512, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 512, 64, 64, 'float32'), (1280, 512, 1, 1, 'float32'), (2, 2), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 256, 32, 32, 'float32'), (256, 256, 3, 3, 'float32'), (1, 1), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 256, 32, 32, 'float32'), (1024, 256, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 1024, 32, 32, 'float32'), (256, 1024, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 1024, 32, 32, 'float32'), (100, 1024, 3, 3, 'float32'), (1, 1), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 1024, 32, 32, 'float32'), (2560, 1024, 1, 1, 'float32'), (2, 2), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 512, 16, 16, 'float32'), (512, 512, 3, 3, 'float32'), (1, 1), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 512, 16, 16, 'float32'), (2048, 512, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 2048, 16, 16, 'float32'), (512, 2048, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 2048, 16, 16, 'float32'), (150, 2048, 3, 3, 'float32'), (1, 1), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 512, 16, 16, 'float32'), (512, 512, 3, 3, 'float32'), (2, 2), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 512, 8, 8, 'float32'), (150, 512, 3, 3, 'float32'), (1, 1), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 512, 8, 8, 'float32'), (512, 512, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 512, 8, 8, 'float32'), (512, 512, 3, 3, 'float32'), (2, 2), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 512, 4, 4, 'float32'), (150, 512, 3, 3, 'float32'), (1, 1), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 512, 4, 4, 'float32'), (256, 512, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 256, 4, 4, 'float32'), (256, 256, 3, 3, 'float32'), (2, 2), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 256, 2, 2, 'float32'), (100, 256, 3, 3, 'float32'), (1, 1), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 256, 2, 2, 'float32'), (256, 256, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 256, 2, 2, 'float32'), (256, 256, 3, 3, 'float32'), (2, 2), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm, workload=('conv2d', (1, 256, 1, 1, 'float32'), (100, 256, 3, 3, 'float32'), (1, 1), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
[14:45:12] /Users/cxt123/Reposities/tvm/src/schedule/bound.cc:110: not in feed graph consumer = hybrid(hybrid_nms, 0x7f9a230f2b50)
[14:45:12] /Users/cxt123/Reposities/tvm/src/arithmetic/int_set.cc:540: cannot evaluate set type Load
[14:45:12] /Users/cxt123/Reposities/tvm/src/arithmetic/int_set.cc:540: cannot evaluate set type Load
[14:45:12] /Users/cxt123/Reposities/tvm/src/schedule/bound.cc:110: not in feed graph consumer = hybrid(hybrid_nms, 0x7f9a230f2b50)
[14:45:12] /Users/cxt123/Reposities/tvm/src/arithmetic/int_set.cc:540: cannot evaluate set type Load
[14:45:12] /Users/cxt123/Reposities/tvm/src/arithmetic/int_set.cc:540: cannot evaluate set type Load
Traceback (most recent call last):
  File "deploy_ssd_gluoncv.py", line 95, in <module>
    class_IDs, scores, bounding_boxs = run(graph, lib, params, ctx)
  File "deploy_ssd_gluoncv.py", line 85, in run
    m.run()
  File "/Users/cxt123/Reposities/tvm/python/tvm/contrib/graph_runtime.py", line 151, in run
    self._run()
  File "/Users/cxt123/Reposities/tvm/python/tvm/_ffi/_ctypes/function.py", line 185, in __call__
    ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
  File "/Users/cxt123/Reposities/tvm/python/tvm/_ffi/base.py", line 71, in check_call
    raise TVMError(py_str(_LIB.TVMGetLastError()))
tvm._ffi.base.TVMError: [14:45:23] /Users/cxt123/Reposities/tvm/src/runtime/module_util.cc:54: Check failed: ret == 0 (-1 vs. 0) [14:45:23] /Users/cxt123/Reposities/tvm/src/runtime/module.cc:92: Check failed: f != nullptr Cannot find function tvm.contrib.sort.argsort in the imported modules or global registry

Stack trace returned 6 entries:
[bt] (0) 0   libtvm.dylib                        0x00000001134f5250 dmlc::StackTrace(unsigned long) + 464
[bt] (1) 1   libtvm.dylib                        0x00000001134f4f34 dmlc::LogMessageFatal::~LogMessageFatal() + 52
[bt] (2) 2   libtvm.dylib                        0x0000000113e159f7 tvm::runtime::ModuleNode::GetFuncFromEnv(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&) + 535
[bt] (3) 3   libtvm.dylib                        0x0000000113e0cdc4 TVMBackendGetFuncFromEnv + 164
[bt] (4) 4   ???                                 0x0000001a44c5b18a 0x0 + 112822956426
[bt] (5) 5   ???                                 0x0000001a44c5aae9 0x0 + 112822954729



Stack trace returned 8 entries:
[bt] (0) 0   libtvm.dylib                        0x00000001134f5250 dmlc::StackTrace(unsigned long) + 464
[bt] (1) 1   libtvm.dylib                        0x00000001134f4f34 dmlc::LogMessageFatal::~LogMessageFatal() + 52
[bt] (2) 2   libtvm.dylib                        0x0000000113e19b21 std::__1::__function::__func<tvm::runtime::WrapPackedFunc(int (*)(void*, int*, int), std::__1::shared_ptr<tvm::runtime::ModuleNode> const&)::$_0, std::__1::allocator<tvm::runtime::WrapPackedFunc(int (*)(void*, int*, int), std::__1::shared_ptr<tvm::runtime::ModuleNode> const&)::$_0>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 257
[bt] (3) 3   libtvm.dylib                        0x0000000113e42081 std::__1::__function::__func<tvm::runtime::GraphRuntime::CreateTVMOp(tvm::runtime::TVMOpParam const&, std::__1::vector<DLTensor, std::__1::allocator<DLTensor> > const&, unsigned long)::$_3, std::__1::allocator<tvm::runtime::GraphRuntime::CreateTVMOp(tvm::runtime::TVMOpParam const&, std::__1::vector<DLTensor, std::__1::allocator<DLTensor> > const&, unsigned long)::$_3>, void ()>::operator()() + 81
[bt] (4) 4   libtvm.dylib                        0x0000000113e4304f std::__1::__function::__func<tvm::runtime::GraphRuntime::GetFunction(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&, std::__1::shared_ptr<tvm::runtime::ModuleNode> const&)::$_8, std::__1::allocator<tvm::runtime::GraphRuntime::GetFunction(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&, std::__1::shared_ptr<tvm::runtime::ModuleNode> const&)::$_8>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 79
[bt] (5) 5   libtvm.dylib                        0x0000000113e0d1c6 TVMFuncCall + 70
[bt] (6) 6   libffi.6.dylib                      0x000000010e5f7884 ffi_call_unix64 + 76
[bt] (7) 7   ???                                 0x00007ffee1ce1440 0x0 + 140732686799936


libc++abi.dylib: terminating with uncaught exception of type dmlc::Error
[1]    60317 abort      python deploy_ssd_gluoncv.py

Did I build TVM the wrong way? I'm not sure.

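You need to build TVM with the sort contrib library enabled (the `USE_SORT` option in config.cmake). That library provides `tvm.contrib.sort.argsort`, which the compiled NMS operator calls at runtime.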

Thanks, it works now. :grinning: