Failed to convert mxnet float16 model to tvm

Hi, guys. I ran into a few problems while trying to convert a float16 Gluon model to TVM. The example code is pasted below:

import tvm
from tvm import relay
import mxnet as mx


if __name__ == "__main__":
    target = 'cuda'
    data_shape = (1, 3, 224, 224)
    dtype = 'float16'
    ctx = mx.gpu()

    net = mx.gluon.model_zoo.vision.get_model(
        'resnet18_v2',
        pretrained=False,
        ctx=ctx
    )
    #  convert to float16
    if dtype == 'float16':
        net.cast(dtype)
    net.collect_params().initialize(ctx=ctx)
    #  fake forward
    y = net(mx.nd.zeros(
        data_shape,
        dtype=dtype,
        ctx=ctx
    ))

    relay_sym, relay_params = relay.frontend.from_mxnet(
        net,
        shape={'data': data_shape},
        dtype={'data': dtype}
    )
    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(
            relay_sym,
            target,
            params=relay_params
        )

This script crashes with the following error message:

[14:26:52] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:109: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
Traceback (most recent call last):
  File "cache/test_fp16.py", line 40, in <module>
    params=relay_params
  File "/home/lhy/Documents/Lib/tvm/python/tvm/relay/build_module.py", line 262, in build
    func = optimize(func, target, params)
  File "/home/lhy/Documents/Lib/tvm/python/tvm/relay/build_module.py", line 161, in optimize
    func = ir_pass.infer_type(func)
  File "/home/lhy/Documents/Lib/tvm/python/tvm/relay/ir_pass.py", line 44, in infer_type
    return _ir_pass.infer_type(expr, mod)
  File "tvm/_ffi/_cython/./function.pxi", line 286, in core.FunctionBase.__call__
  File "tvm/_ffi/_cython/./function.pxi", line 221, in core.FuncCall
  File "tvm/_ffi/_cython/./function.pxi", line 210, in core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 143, in core.CALL
tvm._ffi.base.TVMError: [14:26:52] /home/lhy/Documents/Lib/tvm/src/relay/pass/type_solver.cc:92: Check failed: resolved.defined() Unable to unify parent types: TensorType([3], float16) and TensorType([3], float32)

Stack trace returned 10 entries:
[bt] (0) /home/lhy/Documents/Lib/tvm/build/libtvm.so(dmlc::StackTrace[abi:cxx11](unsigned long)+0x9d) [0x7f3c05ae5798]
[bt] (1) /home/lhy/Documents/Lib/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x2f) [0x7f3c05ae5ac1]
[bt] (2) /home/lhy/Documents/Lib/tvm/build/libtvm.so(tvm::relay::TypeSolver::Unifier::Unify(tvm::relay::Type const&, tvm::relay::Type const&)+0x404) [0x7f3c060d2bb8]
[bt] (3) /home/lhy/Documents/Lib/tvm/build/libtvm.so(tvm::relay::TypeSolver::Unify(tvm::relay::Type const&, tvm::relay::Type const&)+0x53) [0x7f3c060ce57d]
[bt] (4) /home/lhy/Documents/Lib/tvm/build/libtvm.so(tvm::relay::TypeSolver::Reporter::Assign(tvm::relay::Type const&, tvm::relay::Type const&)+0x3f) [0x7f3c060d247f]
[bt] (5) /home/lhy/Documents/Lib/tvm/build/libtvm.so(tvm::relay::BatchNormRel(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)+0x345) [0x7f3c05f51b02]
[bt] (6) /home/lhy/Documents/Lib/tvm/build/libtvm.so(void tvm::runtime::detail::unpack_call_dispatcher<bool, 0, 4, bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::run<tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue>(bool (* const&)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&)+0xc6) [0x7f3c05eefd15]
[bt] (7) /home/lhy/Documents/Lib/tvm/build/libtvm.so(void tvm::runtime::detail::unpack_call_dispatcher<bool, 1, 3, bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::run<tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue>(bool (* const&)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&)+0xa4) [0x7f3c05eef4e5]
[bt] (8) /home/lhy/Documents/Lib/tvm/build/libtvm.so(void tvm::runtime::detail::unpack_call_dispatcher<bool, 2, 2, bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::run<tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue>(bool (* const&)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&)+0x8a) [0x7f3c05eeef40]
[bt] (9) /home/lhy/Documents/Lib/tvm/build/libtvm.so(void tvm::runtime::detail::unpack_call_dispatcher<bool, 3, 1, bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::run<tvm::runtime::TVMArgValue>(bool (* const&)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&)+0x73) [0x7f3c05eee35f]

I couldn’t find many tutorials about float16 support in the docs. Any help is appreciated.

I just noticed that the original Gluon model uses float32 mean, var, gamma, and beta parameters even though the model is cast to float16. I guess this is not supported in TVM right now, which would explain the float16 vs. float32 unification failure raised from BatchNormRel in the trace above.
So I changed “resnet18_v2” to “alexnet”, which has no BN layers. This time, the script produced a different error log … o(╯□╰)o
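
To make the BN observation above easier to check, here is a minimal sketch that lists which parameters are not in the expected dtype. The helper name `find_dtype_mismatches` and the parameter names in the table are hypothetical and only mimic a BatchNorm layer left in float32; this is not part of MXNet or TVM.

```python
def find_dtype_mismatches(param_dtypes, expected):
    """Return the sorted names of parameters whose dtype name differs from `expected`.

    `param_dtypes` maps a parameter name to a dtype name given as a string.
    """
    return sorted(name for name, dt in param_dtypes.items() if dt != expected)


# Hypothetical table: a conv weight that was cast, and BN parameters that were not.
param_dtypes = {
    "conv0_weight": "float16",
    "batchnorm0_gamma": "float32",
    "batchnorm0_beta": "float32",
    "batchnorm0_running_mean": "float32",
    "batchnorm0_running_var": "float32",
}

mismatched = find_dtype_mismatches(param_dtypes, "float16")
for name in mismatched:
    print("not float16:", name)
```

With a real Gluon block, the table could presumably be built from `net.collect_params()` by normalizing each parameter's dtype to its string name; that wiring is an assumption and is not tested here.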

$ python3 cache/test_fp16.py
[15:04:59] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:109: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
Cannot find config for target=cuda, workload=('conv2d', (1, 3, 224, 224, 'float16'), (64, 3, 11, 11, 'float16'), (4, 4), (2, 2), (1, 1), 'NCHW', 'float16'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=cuda, workload=('conv2d', (1, 64, 27, 27, 'float16'), (192, 64, 5, 5, 'float16'), (1, 1), (2, 2), (1, 1), 'NCHW', 'float16'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=cuda, workload=('conv2d', (1, 192, 13, 13, 'float16'), (384, 192, 3, 3, 'float16'), (1, 1), (1, 1), (1, 1), 'NCHW', 'float16'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=cuda, workload=('conv2d', (1, 384, 13, 13, 'float16'), (256, 384, 3, 3, 'float16'), (1, 1), (1, 1), (1, 1), 'NCHW', 'float16'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=cuda, workload=('conv2d', (1, 256, 13, 13, 'float16'), (256, 256, 3, 3, 'float16'), (1, 1), (1, 1), (1, 1), 'NCHW', 'float16'). A fallback configuration is used, which may bring great performance regression.
Traceback (most recent call last):
  File "cache/test_fp16.py", line 43, in <module>
    params=relay_params
  File "/home/lhy/Documents/Lib/tvm/python/tvm/relay/build_module.py", line 275, in build
    lowered_funcs, target=target, target_host=target_host)
  File "/home/lhy/Documents/Lib/tvm/python/tvm/build_module.py", line 586, in build
    fhost, mdev = _build_for_device(flist, tar, target_host)
  File "/home/lhy/Documents/Lib/tvm/python/tvm/build_module.py", line 453, in _build_for_device
    mdev = codegen.build_module(fdevice, str(target)) if fdevice else None
  File "/home/lhy/Documents/Lib/tvm/python/tvm/codegen.py", line 20, in build_module
    return _Build(lowered_func, target)
  File "tvm/_ffi/_cython/./function.pxi", line 286, in core.FunctionBase.__call__
  File "tvm/_ffi/_cython/./function.pxi", line 221, in core.FuncCall
  File "tvm/_ffi/_cython/./function.pxi", line 210, in core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 143, in core.CALL
tvm._ffi.base.TVMError: TVMCall CFunc Error:
Traceback (most recent call last):
  File "tvm/_ffi/_cython/./function.pxi", line 38, in core.tvm_callback
  File "/home/lhy/Documents/Lib/tvm/python/tvm/autotvm/measure/measure_methods.py", line 557, in tvm_callback_cuda_compile
    ptx = nvcc.compile_cuda(code, target="ptx", arch=AutotvmGlobalScope.current.cuda_target_arch)
  File "/home/lhy/Documents/Lib/tvm/python/tvm/contrib/nvcc.py", line 82, in compile_cuda
    raise RuntimeError(msg)
RuntimeError: Compilation error:
/tmp/tmpyp0ltjf9/my_kernel.cu(10): warning: attribute "__shared__" does not apply here

/tmp/tmpyp0ltjf9/my_kernel.cu(13): warning: attribute "__shared__" does not apply here

/tmp/tmpyp0ltjf9/my_kernel.cu(13): warning: attribute "__shared__" does not apply here

/tmp/tmpyp0ltjf9/my_kernel.cu(13): warning: attribute "__shared__" does not apply here

/tmp/tmpyp0ltjf9/my_kernel.cu(13): error: no operator "+" matches these operands
            operand types are: volatile half + volatile half

/tmp/tmpyp0ltjf9/my_kernel.cu(17): warning: attribute "__shared__" does not apply here

/tmp/tmpyp0ltjf9/my_kernel.cu(17): warning: attribute "__shared__" does not apply here

/tmp/tmpyp0ltjf9/my_kernel.cu(17): warning: attribute "__shared__" does not apply here

/tmp/tmpyp0ltjf9/my_kernel.cu(17): error: no operator "+" matches these operands
            operand types are: volatile half + volatile half

/tmp/tmpyp0ltjf9/my_kernel.cu(18): warning: attribute "__shared__" does not apply here

/tmp/tmpyp0ltjf9/my_kernel.cu(21): error: no operator "+" matches these operands
            operand types are: volatile half + volatile half

/tmp/tmpyp0ltjf9/my_kernel.cu(25): warning: attribute "__shared__" does not apply here

/tmp/tmpyp0ltjf9/my_kernel.cu(41): error: more than one instance of overloaded function "max" matches the argument list:
            function "max(int, int)"
            function "max(unsigned int, unsigned int)"
            function "max(int, unsigned int)"
            function "max(unsigned int, int)"
            function "max(long, long)"
            function "max(unsigned long, unsigned long)"
            function "max(long, unsigned long)"
            function "max(unsigned long, long)"
            function "max(long long, long long)"
            function "max(unsigned long long, unsigned long long)"
            function "max(long long, unsigned long long)"
            function "max(unsigned long long, long long)"
            function "max(float, float)"
            argument types are: (half, half)

/tmp/tmpyp0ltjf9/my_kernel.cu(136): error: more than one instance of overloaded function "max" matches the argument list:
            function "max(int, int)"
            function "max(unsigned int, unsigned int)"
            function "max(int, unsigned int)"
            function "max(unsigned int, int)"
            function "max(long, long)"
            function "max(unsigned long, unsigned long)"
            function "max(long, unsigned long)"
            function "max(unsigned long, long)"
            function "max(long long, long long)"
            function "max(unsigned long long, unsigned long long)"
            function "max(long long, unsigned long long)"
            function "max(unsigned long long, long long)"
            function "max(float, float)"
            argument types are: (__half, __half)

... ...
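
The log above mixes many warnings with what look like two different kinds of errors (the `+` on `volatile half` operands and the ambiguous `max`). A small sketch for grouping the error lines by message is shown below; the function name `summarize_nvcc_errors` is hypothetical, and the regular expression only assumes the `file.cu(line): error: message` shape seen in this log.

```python
import re
from collections import Counter

# Matches lines such as: /tmp/x/my_kernel.cu(13): error: <message>
ERROR_RE = re.compile(r'^\S+\.cu\((\d+)\): error: (.*)$')


def summarize_nvcc_errors(log_text):
    """Count nvcc error lines by message, ignoring warnings and detail lines."""
    counts = Counter()
    for line in log_text.splitlines():
        match = ERROR_RE.match(line.strip())
        if match:
            counts[match.group(2)] += 1
    return counts


# A few lines copied from the log above.
sample_log = """\
/tmp/tmpyp0ltjf9/my_kernel.cu(13): warning: attribute "__shared__" does not apply here
/tmp/tmpyp0ltjf9/my_kernel.cu(13): error: no operator "+" matches these operands
            operand types are: volatile half + volatile half
/tmp/tmpyp0ltjf9/my_kernel.cu(17): error: no operator "+" matches these operands
/tmp/tmpyp0ltjf9/my_kernel.cu(41): error: more than one instance of overloaded function "max" matches the argument list:
"""

summary = summarize_nvcc_errors(sample_log)
for message, count in summary.most_common():
    print(count, message)
```

Grouping this way suggests the `max` ambiguity and the `volatile half` addition are separate failures, so fixing one may not be enough for the full model to compile.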

Update: I finally traced the error to the max op in topi. The error above can be reproduced with a simple script like the one below:

import tvm
import topi

if __name__ == "__main__":
    ctx = tvm.gpu()
    dtype = 'float16'
    A = tvm.placeholder((1, 3, 10, 10), dtype=dtype, name='A')
    B = topi.nn.relu(A)
    with tvm.target.cuda():
        sg = topi.generic.schedule_elemwise(B)
        f = tvm.build(sg, [A, B], 'cuda')

Error

Traceback (most recent call last):
  File "cache/test_relu.py", line 15, in <module>
    f = tvm.build(sg, [A, B], 'cuda')
  File "/home/lhy/Documents/Lib/tvm/python/tvm/build_module.py", line 586, in build
    fhost, mdev = _build_for_device(flist, tar, target_host)
  File "/home/lhy/Documents/Lib/tvm/python/tvm/build_module.py", line 453, in _build_for_device
    mdev = codegen.build_module(fdevice, str(target)) if fdevice else None
  File "/home/lhy/Documents/Lib/tvm/python/tvm/codegen.py", line 20, in build_module
    return _Build(lowered_func, target)
  File "tvm/_ffi/_cython/./function.pxi", line 286, in core.FunctionBase.__call__
  File "tvm/_ffi/_cython/./function.pxi", line 221, in core.FuncCall
  File "tvm/_ffi/_cython/./function.pxi", line 210, in core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 143, in core.CALL
tvm._ffi.base.TVMError: TVMCall CFunc Error:
Traceback (most recent call last):
  File "tvm/_ffi/_cython/./function.pxi", line 38, in core.tvm_callback
  File "/home/lhy/Documents/Lib/tvm/python/tvm/autotvm/measure/measure_methods.py", line 557, in tvm_callback_cuda_compile
    ptx = nvcc.compile_cuda(code, target="ptx", arch=AutotvmGlobalScope.current.cuda_target_arch)
  File "/home/lhy/Documents/Lib/tvm/python/tvm/contrib/nvcc.py", line 82, in compile_cuda
    raise RuntimeError(msg)
RuntimeError: Compilation error:
/tmp/tmpui8ewz89/my_kernel.cu(4): error: more than one instance of overloaded function "max" matches the argument list:
            function "max(int, int)"
            function "max(unsigned int, unsigned int)"
            function "max(int, unsigned int)"
            function "max(unsigned int, int)"
            function "max(long, long)"
            function "max(unsigned long, unsigned long)"
            function "max(long, unsigned long)"
            function "max(unsigned long, long)"
            function "max(long long, long long)"
            function "max(unsigned long long, unsigned long long)"
            function "max(long long, unsigned long long)"
            function "max(unsigned long long, long long)"
            function "max(float, float)"
            argument types are: (half, __half)

1 error detected in the compilation of "/tmp/tmpxft_00003ed3_00000000-6_my_kernel.cpp1.ii".

BTW, changing dtype to ‘float32’ or ‘int8’ works fine in both cases.


I think I’m hitting the same problem:

nvcc --ptx -O3 -arch sm_60 -o /home/iliacher/tmp.kernel.compiled /home/iliacher/tmp.kernel.cu
/home/iliacher/tmp.kernel.cu(5): error: more than one instance of overloaded function "max" matches the argument list:
        function "max(int, int)"
        function "max(unsigned int, unsigned int)"
        function "max(int, unsigned int)"
        function "max(unsigned int, int)"
        function "max(long, long)"
        function "max(unsigned long, unsigned long)"
        function "max(long, unsigned long)"
        function "max(unsigned long, long)"
        function "max(long long, long long)"
        function "max(unsigned long long, unsigned long long)"
        function "max(long long, unsigned long long)"
        function "max(unsigned long long, long long)"
        function "max(float, float)"
        argument types are: (half, half)

Have you found a solution?

Please check the PR here.
