[ERROR] FP16 CUDA compilation error

In summary, here is what happened:
Since sm_53, CUDA provides arithmetic operator overloads for half, but it does not provide overloads for volatile half.
Assignment to volatile half has been supported since CUDA 9.2.
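Why `volatile` needs its own support can be illustrated with a host-only C++ analogue (plain C++, not CUDA). `HalfLike` and `store_to_volatile` are hypothetical names, and `HalfLike` merely stands in for `__half`. A class's ordinary copy assignment cannot be called on a `volatile` object, so storing into a volatile slot only compiles once a volatile-qualified `operator=` exists. That is presumably the kind of support the CUDA 9.2 note above refers to, but this sketch does not reproduce the actual CUDA header.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical host-side stand-in for CUDA's __half: just 16 raw bits.
struct HalfLike {
    std::uint16_t x;

    // Ordinary copy assignment, as the compiler would otherwise generate it.
    HalfLike& operator=(const HalfLike& other) = default;

    // Without this volatile-qualified overload, `slot = value;` on a
    // volatile HalfLike is ill-formed: the ordinary operator= above cannot
    // be called on a volatile object.
    volatile HalfLike& operator=(const HalfLike& other) volatile {
        x = other.x;
        return *this;
    }
};

// Stores `value` into a volatile slot (think: one element of a shared
// reduction buffer) and reads the raw bits back through the volatile lvalue.
std::uint16_t store_to_volatile(volatile HalfLike& slot, const HalfLike& value) {
    slot = value;   // overload resolution picks the volatile-qualified operator=
    return slot.x;  // loading a scalar member through a volatile lvalue is fine
}

// Usage: 0x3C00 is the IEEE 754 binary16 encoding of 1.0.
void demo_volatile_store() {
    const HalfLike one{0x3C00};
    volatile HalfLike slot{0};
    assert(store_to_volatile(slot, one) == 0x3C00);
}
```

Deleting the volatile-qualified `operator=` makes the `slot = value;` line fail to compile, which mirrors the situation before volatile assignment was available.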

Thanks for the clarification. However, it seems we still need volatile half overloads to pass a float16 ResNet test case:

```cpp
// Emit operator overloads for volatile __half into the generated CUDA source
decl_stream << "__device__ half operator+"
            << "(const volatile __half &a, const volatile __half &b)\n"
            << "{\n  return __hadd(a, b);\n}\n";
decl_stream << "__device__ half operator*"
            << "(const volatile __half &a, const volatile __half &b)\n"
            << "{\n  return __hmul(a, b);\n}\n";
```

Yes, the volatile overloads are needed.

@vinx13, after removing the overloads you mentioned and adding the ones @xyzhou mentioned, I still get the following error:

```
lm_long/diagonaled_mm_tvm.py:116: in _compile_function
    mm = tvm.build(s, [X, Y, Z], target=device, target_host=tgt_host, name='mm')
/usr/tvm/python/tvm/build_module.py:636: in build
    fhost, mdev = _build_for_device(flist, tar, target_host)
/usr/tvm/python/tvm/build_module.py:502: in _build_for_device
    mdev = codegen.build_module(fdevice, str(target)) if fdevice else None
/usr/tvm/python/tvm/codegen.py:36: in build_module
    return _Build(lowered_func, target)
/usr/tvm/python/tvm/_ffi/_ctypes/function.py:207: in __call__
    raise get_last_ffi_error()
E   tvm._ffi.base.TVMError: Traceback (most recent call last):
E     File "/usr/tvm/src/codegen/opt/build_cuda_on.cc", line 119
E   TVMError: Check failed: compile_res == NVRTC_SUCCESS (6 vs. 0) : default_program(16): error: class "__half_raw" has no suitable copy constructor
E   default_program(16): error: class "__half_raw" has no suitable copy constructor
E
E   default_program(20): error: class "__half_raw" has no suitable copy constructor
E
E   default_program(20): error: class "__half_raw" has no suitable copy constructor
E
E   4 errors detected in the compilation of "default_program".
```

The error I am getting is related to the volatile overloads above. Adding them results in the error `class "__half_raw" has no suitable copy constructor` that I mentioned earlier, even with a simple example like this:

```python
import tvm

# Copy a float16 vector of symbolic length n, one element per CUDA block
n = tvm.var('n')
X = tvm.placeholder((n,), name='X', dtype='float16')
Z = tvm.compute((n,), lambda i: X[i], name='Z')
s = tvm.create_schedule(Z.op)
s[Z].bind(Z.op.axis[0], tvm.thread_axis("blockIdx.x"))
f = tvm.build(s, [X, Z], target='cuda', target_host='llvm', name='f')
```
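The "no suitable copy constructor" message is consistent with a general C++ rule: a class's implicitly declared copy constructor takes `const T&`, which cannot bind to a `volatile` object. Passing a `const volatile __half &` to a function that takes `__half` by value (as `__hadd(a, b)` in the overloads presumably does) therefore needs a conversion the type may not provide. The sketch below is a host-only C++ analogue with hypothetical names (`RawLike`, `load_raw`, `add_volatile`); it loads the scalar field through the volatile reference instead of copying the whole object. Integer addition on the raw bits stands in for `__hadd` purely for illustration and is not half-precision arithmetic. This shows the language rule only, not the fix adopted in this thread.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical host-side stand-in for __half_raw: a plain struct whose only
// copy constructor is the implicit RawLike(const RawLike&).
struct RawLike {
    std::uint16_t x;
};

// Would NOT compile: the implicit copy constructor takes const RawLike&,
// which cannot bind to a volatile lvalue ("no suitable copy constructor").
//
//   RawLike copy_bad(const volatile RawLike& v) { return v; }

// Workaround: read each scalar member through the volatile reference and
// rebuild a non-volatile value from those loads.
RawLike load_raw(const volatile RawLike& v) {
    RawLike out{};
    out.x = v.x;
    return out;
}

// Analogue of the generated operator+(const volatile __half&, const volatile __half&).
// Integer addition on the raw bits stands in for __hadd, for illustration only.
RawLike add_volatile(const volatile RawLike& a, const volatile RawLike& b) {
    const RawLike ra = load_raw(a);
    const RawLike rb = load_raw(b);
    return RawLike{static_cast<std::uint16_t>(ra.x + rb.x)};
}

// Usage: both operands live in volatile storage.
void demo_volatile_add() {
    volatile RawLike a{2};
    volatile RawLike b{3};
    assert(add_volatile(a, b).x == 5);
}
```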

Can you try setting the target arch to see whether the error persists?

```python
from tvm import autotvm
autotvm.measure.measure_methods.set_cuda_target_arch("sm_70")
```


Adding `set_cuda_target_arch("sm_70")` fixed the `__half_raw` error. Thanks @yzhliu!

@vinx13, does this mean the current fp16 code only works for `__CUDA_ARCH__ >= 530` and CUDA version >= 9.2?

volatile half is needed for softmax, which uses cross-thread reduction (though we could avoid this by modifying the codegen).
For arch >= 530, there are some conflicts with overloading. We currently overload the half arithmetic operators, but CUDA already provides them since arch 530, and our overloads caused errors. This means the current code does work for CUDA arch >= 530, but it is easy to fix.

It doesn't actually work. It compiles successfully, but the output of the compiled function is wrong. I will debug this further and let you know, but if you have any thoughts on why this might be the case, that would be helpful.
@xyzhou, did you notice something similar?


I get the same compilation error as you. I haven't tried the workaround yet; I will try it after this week.

@vinx13, what's the purpose of adding `__device__ half operator<=(__half a, __half b)`?

It fails for the following piece of code:

```python
import tvm
from tvm import autotvm

autotvm.measure.measure_methods.set_cuda_target_arch("sm_70")

size = 1024
dtype = "float16"
A = tvm.placeholder((size, ), name='A', dtype=dtype)
B = tvm.placeholder((size, ), name='B', dtype=dtype)
# Elementwise comparison A[i] <= B[i] on float16 inputs
C = tvm.compute((size, ), lambda i: A[i] <= B[i], name='C')

s = tvm.create_schedule(C.op)

# Split into blocks of 64 threads each
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))

fadd = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="xxx")
```

because CUDA cannot decide which `operator<=` to use when converting `half` to `__half`:

```
RuntimeError: Compilation error:
/tmp/tmpfkvvzzy9/my_kernel.cu(277): error: more than one operator "<=" matches these operands:
            function "operator<=(const __half &, const __half &)"
            function "operator<=(__half, __half)"
            operand types are: half <= half
```
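Why the two candidates tie can be checked on the host with plain C++. For an lvalue argument, binding a `const H&` parameter and copying into an `H` parameter are both exact matches, and the standard tie-breaker that prefers one reference binding over another does not apply when one candidate takes its argument by value, so a by-value overload next to a by-const-reference one makes `a <= b` ambiguous. The sketch below uses hypothetical namespaces and a detection trait (`has_unambiguous_le`) that turns the ambiguity into a `false` value instead of a hard error; `H` merely stands in for `__half`.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Hypothetical stand-ins for __half, used only to study overload resolution.
namespace with_one_overload {
struct H { unsigned short x; };
// Shaped like the header's operator<=(const __half &, const __half &).
inline bool operator<=(const H& a, const H& b) { return a.x <= b.x; }
}  // namespace with_one_overload

namespace with_two_overloads {
struct H { unsigned short x; };
inline bool operator<=(const H& a, const H& b) { return a.x <= b.x; }
// Shaped like the extra operator<=(__half, __half): operands taken by value.
inline bool operator<=(H a, H b) { return a.x <= b.x; }
}  // namespace with_two_overloads

// True when `lhs <= rhs` on two lvalues of T resolves to exactly one overload.
// An ambiguous call is a substitution failure here, so it selects false.
template <typename T, typename = void>
struct has_unambiguous_le : std::false_type {};

template <typename T>
struct has_unambiguous_le<
    T, decltype(void(std::declval<T&>() <= std::declval<T&>()))>
    : std::true_type {};

static_assert(has_unambiguous_le<with_one_overload::H>::value,
              "a single operator<= resolves normally");
static_assert(!has_unambiguous_le<with_two_overloads::H>::value,
              "by-value plus by-const-reference overloads are ambiguous for lvalues");

// Usage: the type with a single overload compares as usual.
void demo_le() {
    with_one_overload::H lo{1}, hi{2};
    assert(lo <= hi);
    assert(!(hi <= lo));
}
```

Removing either overload from `with_two_overloads` makes the second `static_assert` fail, i.e. the comparison becomes unambiguous again.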

In fact, @vinx13 added the code:

```cpp
__device__ half operator<=(const volatile half &a, const volatile half &b)
```

It causes a compile error in NVRTC. Then tianqi tried to fix it by changing it to

```cpp
__device__ half operator<=(__half a, __half b)
```

but that causes a compile error in NVCC.

The problem is that we need the volatile overload in NVCC, but it fails in NVRTC. How to make it work with both NVCC and NVRTC is an interesting open question.

Can we remove the `<=` overload?

I'm afraid not. In some workloads the volatile overload is necessary; e.g., the volatile `+` overload is needed in fp16 ResNet.

Then is there any chance we can guard them with `#ifdef`?
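One way the `#ifdef` idea could look, as a minimal sketch that assumes the overloads are emitted as text like the `decl_stream` snippets earlier in the thread. NVRTC predefines the `__CUDACC_RTC__` macro, so wrapping the volatile overloads in `#ifndef __CUDACC_RTC__` would keep them for NVCC and hide them from NVRTC. `EmitVolatileHalfOverloads` and `DemoEmit` are hypothetical names, and the helper writes to any `std::ostream` rather than TVM's actual `decl_stream`.

```cpp
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>

// Hypothetical helper: appends the volatile half overloads to generated CUDA
// source, guarded so that only NVCC sees them. NVRTC predefines
// __CUDACC_RTC__, so under NVRTC the whole block is preprocessed away.
void EmitVolatileHalfOverloads(std::ostream& decl_stream) {
    decl_stream << "#ifndef __CUDACC_RTC__\n"
                << "__device__ half operator+"
                << "(const volatile __half &a, const volatile __half &b)\n"
                << "{\n  return __hadd(a, b);\n}\n"
                << "__device__ half operator*"
                << "(const volatile __half &a, const volatile __half &b)\n"
                << "{\n  return __hmul(a, b);\n}\n"
                << "#endif  // __CUDACC_RTC__\n";
}

// Usage: collect the emitted text and check that the guard wraps both overloads.
std::string DemoEmit() {
    std::ostringstream os;
    EmitVolatileHalfOverloads(os);
    const std::string src = os.str();
    assert(src.find("#ifndef __CUDACC_RTC__") < src.find("operator+"));
    assert(src.find("operator*") < src.find("#endif"));
    return src;
}
```

This would only avoid the NVRTC compile error; a kernel that genuinely needs volatile half arithmetic (such as the softmax reduction mentioned above) would presumably still fail to compile under NVRTC.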

If I understand correctly, volatile is required for functionality in specific cases such as softmax (not for performance), so the current `operator<=(__half a, __half b)` does not help at all. Why don't we remove `operator<=(__half a, __half b)` to unblock the case I posted, and add it back once we figure out a solution?

Correct me if my understanding is wrong.

Yes, `operator<=(__half a, __half b)` is useless as far as I know. But the previous solution also has a problem. I would prefer to solve this problem once and for all. But if it really blocks your work, it is OK to roll back these lines to the last commit.

@yzhliu @xyzhou
I think we should just remove these overloads. They were originally added to support volatile, and non-volatile overloads are unnecessary since CUDA already provides them.

I agree: just remove these overloads and revert the volatile overloading. This would only cause NVRTC errors, which should have minimal impact.

On the other hand, I wonder why we still need NVRTC. It generates slower programs in most cases. Can we just make NVCC the default (or the only) CUDA compiler? Then it would be safe to overload volatile.