[Error] Expected type relay.Expr but get relay.Module

I am trying to run quantization in TVM, but I get the following error: "Expected type relay.Expr but get relay.Module".

My code:

import numpy as np
import tvm
import tvm.autotvm as autotvm
import tvm.relay as relay
import tvm.relay.testing
from tvm.contrib import graph_runtime
from common import get_network
import sys
import argparse

parser = argparse.ArgumentParser()

parser.add_argument('--log_file', type=str, default='logs/history_best_1080.log')
args = parser.parse_args()

ctx = tvm.gpu(0)
target = tvm.target.cuda()

def bench(name, batch):
    sym, data_shape = get_network(name, batch)
    data_shape = data_shape[0][1]
    sym, _ = relay.frontend.from_mxnet(sym, {'data': data_shape})
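    # Note: from_mxnet already returns a relay.Module here, while
    # create_workload below expects a relay.Expr; this mismatch raises the error.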
    sym, params = tvm.relay.testing.create_workload(sym)
    with relay.quantize.qconfig(skip_k_conv=0, round_for_shift=True):
        sym = relay.quantize.quantize(sym, params)

    with relay.build_module.build_config(opt_level=3):
        graph, lib, params = relay.build(sym, 'cuda', 'llvm', params=params)

    m = graph_runtime.create(graph, lib, ctx)
    x = np.random.uniform(size=data_shape)
    data_tvm = tvm.nd.array(x.astype('float32'))
    m.set_input("data", data_tvm)
    m.set_input(**{k:tvm.nd.array(v, ctx) for k, v in params.items()})
    m.run()
    e = m.module.time_evaluator("run", ctx, number=2000, repeat=3)
    t = e(data_tvm).results
    t = np.array(t) * 1000

    print('{} (batch={}): {} ms'.format(name, batch, t.mean()))


def main():
    with tvm.target.cuda():
        with autotvm.apply_history_best(args.log_file):
            for batch in [1, 16]:
                for name in ['vgg-19', 'resnet-50', 'resnext-50', 'inception_v3', 'drn-c-26', 'dcn-resnet-101']:
                    bench(name, batch)


if __name__ == '__main__':
    main()

Error information:

  File "run_tvm.py", line 53, in <module>
    main()
  File "run_tvm.py", line 49, in main
    bench(name, batch)
  File "run_tvm.py", line 24, in bench
    sym, params = tvm.relay.testing.create_workload(sym)
  File "/mnt/ebs0/lst/tvm/tvm/python/tvm/relay/testing/init.py", line 153, in create_workload
    mod = relay.Module.from_expr(net)
  File "/mnt/ebs0/lst/tvm/tvm/python/tvm/relay/module.py", line 183, in from_expr
    return _module.Module_FromExpr(expr)
  File "tvm/_ffi/_cython/./function.pxi", line 310, in tvm._ffi._cy3.core.FunctionBase.__call__
  File "tvm/_ffi/_cython/./function.pxi", line 245, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./function.pxi", line 234, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 170, in tvm._ffi._cy3.core.CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (3) /mnt/ebs0/lst/tvm/tvm/build/libtvm.so(TVMFuncCall+0x46) [0x2b5cf662c556]
  [bt] (2) /mnt/ebs0/lst/tvm/tvm/build/libtvm.so(+0xd98e03) [0x2b5cf634ce03]
  [bt] (1) /mnt/ebs0/lst/tvm/tvm/build/libtvm.so(tvm::relay::Expr tvm::runtime::TVMArgValue::AsNodeRef<tvm::relay::Expr>()const+0x254) [0x2b5cf62a2e74]
  [bt] (0) /mnt/ebs0/lst/tvm/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x22) [0x2b5cf5fbf0b2]
  File "/mnt/ebs0/lst/tvm/tvm/include/tvm/packed_func_ext.h", line 141
TVMError: Check failed: NodeTypeChecker<TNodeRef>::Check(sptr.get()): Expected type relay.Expr but get relay.Module

How did you solve this problem?

We've moved to the new module API; see the updated example: https://github.com/vinx13/tvm-cuda-int8-benchmark/commit/8ee2c4b1ba51c302332433417967ec66ab23c87e
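
A minimal sketch of the module-based flow under that API (get_network and the qconfig options are carried over from the script above; exact option names may differ across TVM versions):

import numpy as np
import tvm
import tvm.relay as relay
from tvm.contrib import graph_runtime
from common import get_network

def bench_new_api(name, batch):
    sym, data_shape = get_network(name, batch)
    data_shape = data_shape[0][1]
    # from_mxnet now returns a relay.Module together with the params dict,
    # so the create_workload step is no longer needed.
    mod, params = relay.frontend.from_mxnet(sym, {'data': data_shape})
    # Quantization rewrites the module's main function.
    with relay.quantize.qconfig(skip_k_conv=0, round_for_shift=True):
        mod['main'] = relay.quantize.quantize(mod['main'], params=params)
    with relay.build_module.build_config(opt_level=3):
        graph, lib, params = relay.build(mod, 'cuda', 'llvm', params=params)
    ctx = tvm.gpu(0)
    m = graph_runtime.create(graph, lib, ctx)
    x = np.random.uniform(size=data_shape).astype('float32')
    m.set_input('data', tvm.nd.array(x, ctx))
    m.set_input(**{k: tvm.nd.array(v, ctx) for k, v in params.items()})
    m.run()
    return m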

Okay, thanks. I'll try it.

Does TVM support int8 quantization on ARM CPUs?

I updated run_tvm.py, but some errors happened:
python run_tvm.py --log_file logs/history_best_1080.log

Cannot find config for target=cuda, workload=('dense', (1, 4096, 'float32'), (1000, 4096, 'float32'), 0, 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=cuda, workload=('dense', (1, 4096, 'float32'), (4096, 4096, 'float32'), 0, 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=cuda, workload=('dense', (1, 25088, 'float32'), (4096, 25088, 'float32'), 0, 'float32'). A fallback configuration is used, which may bring great performance regression.
vgg-19 (batch=1): 3.4733736873333334 ms
Traceback (most recent call last):
  File "run_tvm.py", line 65, in <module>
    main()
  File "run_tvm.py", line 61, in main
    bench(name, batch)
  File "run_tvm.py", line 37, in bench
    mod['main'] = relay.quantize.quantize(mod['main'], params=params)
  File "/data0/zzw/tvm/python/tvm/relay/quantize/quantize.py", line 366, in quantize
    mod = quantize_seq(mod)
  File "/data0/zzw/tvm/python/tvm/relay/transform.py", line 185, in __call__
    return _transform.RunPass(self, mod)
  File "/data0/zzw/tvm/python/tvm/_ffi/_ctypes/function.py", line 210, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /data0/zzw/tvm/build/libtvm.so(tvm::relay::ForwardRewriter::VisitExpr_(tvm::relay::CallNode const*)+0x2c4) [0x7fdf8a2682e4]
  [bt] (7) /data0/zzw/tvm/build/libtvm.so(tvm::relay::ForwardRewriter::GetTempExpr(tvm::relay::Expr const&)+0x15d) [0x7fdf8a26725d]
  [bt] (6) /data0/zzw/tvm/build/libtvm.so(tvm::relay::ExprMutator::VisitExpr(tvm::relay::Expr const&)+0x9e) [0x7fdf8a09a9fe]
  [bt] (5) /data0/zzw/tvm/build/libtvm.so(tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>::VisitExpr(tvm::relay::Expr const&)+0xd2) [0x7fdf8a0a1542]
  [bt] (4) /data0/zzw/tvm/build/libtvm.so(std::_Function_handler<tvm::relay::Expr (tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>*), tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>::InitVTable()::{lambda(tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>*)#6}>::_M_invoke(std::_Any_data const&, tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>*&&)+0x34) [0x7fdf8a09cd84]
  [bt] (3) /data0/zzw/tvm/build/libtvm.so(tvm::relay::ForwardRewriter::VisitExpr_(tvm::relay::CallNode const*)+0x5ec) [0x7fdf8a26860c]
  [bt] (2) /data0/zzw/tvm/build/libtvm.so(std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::relay::Expr (tvm::relay::Call const&, tvm::Array<tvm::relay::Expr, void> const&, tvm::NodeRef const&)>::AssignTypedLambda<tvm::relay::Expr (*)(tvm::relay::Call const&, tvm::Array<tvm::relay::Expr, void> const&, tvm::NodeRef const&)>(tvm::relay::Expr (*)(tvm::relay::Call const&, tvm::Array<tvm::relay::Expr, void> const&, tvm::NodeRef const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)+0xb0) [0x7fdf8a22db70]
  [bt] (1) /data0/zzw/tvm/build/libtvm.so(tvm::relay::quantize::MulRealize(tvm::relay::Call const&, tvm::Array<tvm::relay::Expr, void> const&, tvm::NodeRef const&)+0x292) [0x7fdf8a2bce72]
  [bt] (0) /data0/zzw/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x43) [0x7fdf89c6daf3]
  File "/data0/zzw/tvm/src/relay/pass/quantize.cc", line 344
TVMError: Check failed: lhs->dtype == dtype (int8 vs. int32) :

Try store_lowbit_output=False in qconfig; quantization is being refactored.
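
For example, with the qconfig from the script above (store_lowbit_output is the option named in this reply; the other options are unchanged):

# Keep outputs in higher precision to work around the
# "int8 vs. int32" check in MulRealize.
with relay.quantize.qconfig(skip_k_conv=0,
                            round_for_shift=True,
                            store_lowbit_output=False):
    mod['main'] = relay.quantize.quantize(mod['main'], params=params)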

It works, but another error occurs:

Cannot find config for target=cuda, workload=('dense', (1, 4096, 'float32'), (1000, 4096, 'float32'), 0, 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=cuda, workload=('dense', (1, 4096, 'float32'), (4096, 4096, 'float32'), 0, 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=cuda, workload=('dense', (1, 25088, 'float32'), (4096, 25088, 'float32'), 0, 'float32'). A fallback configuration is used, which may bring great performance regression.
vgg-19 (batch=1): 2.286234292 ms
Cannot find config for target=cuda, workload=('dense', (1, 2048, 'float32'), (1000, 2048, 'float32'), 0, 'float32'). A fallback configuration is used, which may bring great performance regression.
resnet-50 (batch=1): 0.9510388265 ms
resnext-50 (batch=1): 0.9545957925 ms
inception_v3 (batch=1): 1.8752471994999997 ms
drn-c-26 (batch=1): 1.5227248015 ms
/data0/zzw/tvm/python/tvm/relay/frontend/nnvm_common.py:28: UserWarning: use_global_stats is ignored in batch_norm.
  warnings.warn(err)
Traceback (most recent call last):
  File "run_tvm.py", line 65, in <module>
    main()
  File "run_tvm.py", line 61, in main
    bench(name, batch)
  File "run_tvm.py", line 37, in bench
    mod['main'] = relay.quantize.quantize(mod['main'], params=params)
  File "/data0/zzw/tvm/python/tvm/relay/quantize/quantize.py", line 366, in quantize
    mod = quantize_seq(mod)
  File "/data0/zzw/tvm/python/tvm/relay/transform.py", line 185, in __call__
    return _transform.RunPass(self, mod)
  File "/data0/zzw/tvm/python/tvm/_ffi/_ctypes/function.py", line 210, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /data0/zzw/tvm/build/libtvm.so(tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>::VisitExpr(tvm::relay::Expr const&)+0xd2) [0x7f75d70d6542]
  [bt] (7) /data0/zzw/tvm/build/libtvm.so(std::_Function_handler<tvm::relay::Expr (tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>*), tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>::InitVTable()::{lambda(tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>*)#6}>::_M_invoke(std::_Any_data const&, tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>*&&)+0x34) [0x7f75d70d1d84]
  [bt] (6) /data0/zzw/tvm/build/libtvm.so(tvm::relay::ForwardRewriter::VisitExpr_(tvm::relay::CallNode const*)+0x2c4) [0x7f75d729d2e4]
  [bt] (5) /data0/zzw/tvm/build/libtvm.so(tvm::relay::ForwardRewriter::GetTempExpr(tvm::relay::Expr const&)+0x42) [0x7f75d729c142]
  [bt] (4) /data0/zzw/tvm/build/libtvm.so(tvm::relay::ExprMutator::VisitExpr(tvm::relay::Expr const&)+0x9e) [0x7f75d70cf9fe]
  [bt] (3) /data0/zzw/tvm/build/libtvm.so(tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>::VisitExpr(tvm::relay::Expr const&)+0xd2) [0x7f75d70d6542]
  [bt] (2) /data0/zzw/tvm/build/libtvm.so(std::_Function_handler<tvm::relay::Expr (tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>*), tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>::InitVTable()::{lambda(tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>*)#6}>::_M_invoke(std::_Any_data const&, tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>*&&)+0x34) [0x7f75d70d1d84]
  [bt] (1) /data0/zzw/tvm/build/libtvm.so(tvm::relay::ForwardRewriter::VisitExpr_(tvm::relay::CallNode const*)+0x5ec) [0x7f75d729d60c]
  [bt] (0) /data0/zzw/tvm/build/libtvm.so(+0x21b34eb) [0x7f75d73c04eb]
  File "/data0/zzw/tvm/python/tvm/_ffi/_ctypes/function.py", line 72, in cfun
    rv = local_pyfunc(*pyargs)
  File "/data0/zzw/tvm/python/tvm/relay/quantize/_annotate.py", line 112, in frewrite_with_guard
    return func(ref_call, new_args, ctx)
  File "/data0/zzw/tvm/python/tvm/relay/quantize/_annotate.py", line 281, in add_rewrite
    raise ValueError()
TVMError: ValueError

See the thread "[VTA] Error when quantizing MxNet model".

Thanks for your reply!