Hi all. I ran into some problems when trying to convert a float16 Gluon model to TVM. A minimal example is below:
import tvm
from tvm import relay
import mxnet as mx
if __name__ == "__main__":
    # Minimal repro: cast a Gluon ResNet-18 v2 to float16, convert it to
    # Relay with the MXNet frontend, and compile it for CUDA.
    target = 'cuda'
    data_shape = (1, 3, 224, 224)
    dtype = 'float16'
    device = mx.gpu()

    model = mx.gluon.model_zoo.vision.get_model(
        'resnet18_v2', pretrained=False, ctx=device)

    # Cast the parameters to the requested precision, then initialize them.
    # NOTE(review): Gluon may keep BatchNorm parameters in float32 even after
    # cast('float16') — confirm; that would match the float16 vs float32
    # BatchNorm mismatch reported by relay.build below.
    if dtype == 'float16':
        model.cast(dtype)
    model.collect_params().initialize(ctx=device)

    # One forward pass on an all-zeros batch ("fake forward"); presumably
    # needed so any deferred parameter initialization completes before the
    # model is converted — TODO confirm.
    dummy_input = mx.nd.zeros(data_shape, dtype=dtype, ctx=device)
    _ = model(dummy_input)

    # Only the 'data' input is given an explicit shape and dtype here.
    input_shapes = {'data': data_shape}
    input_dtypes = {'data': dtype}
    relay_func, relay_params = relay.frontend.from_mxnet(
        model, shape=input_shapes, dtype=input_dtypes)

    # Reported to fail during Relay type inference inside BatchNormRel with
    # "Unable to unify parent types: TensorType([3], float16) and
    # TensorType([3], float32)".
    with relay.build_config(opt_level=3):
        graph_json, compiled_lib, compiled_params = relay.build(
            relay_func, target, params=relay_params)
This script crashes with the following error message:
[14:26:52] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:109: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
Traceback (most recent call last):
File "cache/test_fp16.py", line 40, in <module>
params=relay_params
File "/home/lhy/Documents/Lib/tvm/python/tvm/relay/build_module.py", line 262, in build
func = optimize(func, target, params)
File "/home/lhy/Documents/Lib/tvm/python/tvm/relay/build_module.py", line 161, in optimize
func = ir_pass.infer_type(func)
File "/home/lhy/Documents/Lib/tvm/python/tvm/relay/ir_pass.py", line 44, in infer_type
return _ir_pass.infer_type(expr, mod)
File "tvm/_ffi/_cython/./function.pxi", line 286, in core.FunctionBase.__call__
File "tvm/_ffi/_cython/./function.pxi", line 221, in core.FuncCall
File "tvm/_ffi/_cython/./function.pxi", line 210, in core.FuncCall3
File "tvm/_ffi/_cython/./base.pxi", line 143, in core.CALL
tvm._ffi.base.TVMError: [14:26:52] /home/lhy/Documents/Lib/tvm/src/relay/pass/type_solver.cc:92: Check failed: resolved.defined() Unable to unify parent types: TensorType([3], float16) and TensorType([3], float32)
Stack trace returned 10 entries:
[bt] (0) /home/lhy/Documents/Lib/tvm/build/libtvm.so(dmlc::StackTrace[abi:cxx11](unsigned long)+0x9d) [0x7f3c05ae5798]
[bt] (1) /home/lhy/Documents/Lib/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x2f) [0x7f3c05ae5ac1]
[bt] (2) /home/lhy/Documents/Lib/tvm/build/libtvm.so(tvm::relay::TypeSolver::Unifier::Unify(tvm::relay::Type const&, tvm::relay::Type const&)+0x404) [0x7f3c060d2bb8]
[bt] (3) /home/lhy/Documents/Lib/tvm/build/libtvm.so(tvm::relay::TypeSolver::Unify(tvm::relay::Type const&, tvm::relay::Type const&)+0x53) [0x7f3c060ce57d]
[bt] (4) /home/lhy/Documents/Lib/tvm/build/libtvm.so(tvm::relay::TypeSolver::Reporter::Assign(tvm::relay::Type const&, tvm::relay::Type const&)+0x3f) [0x7f3c060d247f]
[bt] (5) /home/lhy/Documents/Lib/tvm/build/libtvm.so(tvm::relay::BatchNormRel(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)+0x345) [0x7f3c05f51b02]
[bt] (6) /home/lhy/Documents/Lib/tvm/build/libtvm.so(void tvm::runtime::detail::unpack_call_dispatcher<bool, 0, 4, bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::run<tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue>(bool (* const&)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&)+0xc6) [0x7f3c05eefd15]
[bt] (7) /home/lhy/Documents/Lib/tvm/build/libtvm.so(void tvm::runtime::detail::unpack_call_dispatcher<bool, 1, 3, bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::run<tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue>(bool (* const&)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&)+0xa4) [0x7f3c05eef4e5]
[bt] (8) /home/lhy/Documents/Lib/tvm/build/libtvm.so(void tvm::runtime::detail::unpack_call_dispatcher<bool, 2, 2, bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::run<tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue>(bool (* const&)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&)+0x8a) [0x7f3c05eeef40]
[bt] (9) /home/lhy/Documents/Lib/tvm/build/libtvm.so(void tvm::runtime::detail::unpack_call_dispatcher<bool, 3, 1, bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::run<tvm::runtime::TVMArgValue>(bool (* const&)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&)+0x73) [0x7f3c05eee35f]
I couldn't find many tutorials about float16 support in the documentation. Any help is appreciated.