Error: "expression must have bool type (or be convertible to bool)" when generating CUDA kernels

Hello, I’m trying to use TVM to generate a CUDA kernel, and I ran into this error during compilation. The generated line that triggers it is:

int4 _30 = _21 ? _24 : _29;

where _21 is a uint4 variable that is declared and filled in element by element as follows:

uint4 _21;
_21.x = (_11.x||_20.x);
_21.y = (_11.y||_20.y);
_21.z = (_11.z||_20.z);
_21.w = (_11.w||_20.w);

Full error log:

Traceback (most recent call last):

  File "schedule_test.py", line 483, in <module>
    verify(p, print_ir=True, print_src=True, save_data=False, export_code=False)

  File "schedule_test.py", line 472, in verify
    check_device(device)

  File "schedule_test.py", line 452, in check_device
    func = tvm.build(s, params, device, name=("kernel"))

  File "/home/moderato/Documents/incubator-tvm/python/tvm/build_module.py", line 638, in build
    fhost, mdev = _build_for_device(flist, tar, target_host)

  File "/home/moderato/Documents/incubator-tvm/python/tvm/build_module.py", line 504, in _build_for_device
    mdev = codegen.build_module(fdevice, str(target)) if fdevice else None

  File "/home/moderato/Documents/incubator-tvm/python/tvm/codegen.py", line 36, in build_module
    return _Build(lowered_func, target)

  File "/home/moderato/Documents/incubator-tvm/python/tvm/_ffi/_ctypes/function.py", line 207, in __call__
    raise get_last_ffi_error()

tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (5) /home/moderato/Documents/incubator-tvm/build/libtvm.so(TVMFuncCall+0x65) [0x7fd5b3924585]
  [bt] (4) /home/moderato/Documents/incubator-tvm/build/libtvm.so(+0x375194) [0x7fd5b3146194]
  [bt] (3) /home/moderato/Documents/incubator-tvm/build/libtvm.so(tvm::codegen::Build(tvm::Array<tvm::LoweredFunc, void> const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)+0xb9f) [0x7fd5b325f5cf]
  [bt] (2) /home/moderato/Documents/incubator-tvm/build/libtvm.so(std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), void tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::Array<tvm::LoweredFunc, void>)>::AssignTypedLambda<tvm::runtime::Module (*)(tvm::Array<tvm::LoweredFunc, void>)>(tvm::runtime::Module (*)(tvm::Array<tvm::LoweredFunc, void>))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)+0x51) [0x7fd5b328d341]
  [bt] (1) /home/moderato/Documents/incubator-tvm/build/libtvm.so(tvm::codegen::BuildCUDA(tvm::Array<tvm::LoweredFunc, void>)+0x463) [0x7fd5b38c4f93]
  [bt] (0) /home/moderato/Documents/incubator-tvm/build/libtvm.so(+0xb4eadb) [0x7fd5b391fadb]
  File "/home/moderato/Documents/incubator-tvm/python/tvm/_ffi/_ctypes/function.py", line 72, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/moderato/Documents/incubator-tvm/python/tvm/autotvm/measure/measure_methods.py", line 593, in tvm_callback_cuda_compile
    ptx = nvcc.compile_cuda(code, target=target, arch=AutotvmGlobalScope.current.cuda_target_arch)
  File "/home/moderato/Documents/incubator-tvm/python/tvm/contrib/nvcc.py", line 101, in compile_cuda
    raise RuntimeError(msg)
RuntimeError: Compilation error:
/tmp/tmpfxicdulv/my_kernel.cu(119): error: expression must have bool type (or be convertible to bool)

/tmp/tmpfxicdulv/my_kernel.cu(212): error: expression must have bool type (or be convertible to bool)

/tmp/tmpfxicdulv/my_kernel.cu(222): warning: attribute "__shared__" does not apply here

/tmp/tmpfxicdulv/my_kernel.cu(101): warning: variable "_24" was set but never used

/tmp/tmpfxicdulv/my_kernel.cu(114): warning: variable "_29" was set but never used

/tmp/tmpfxicdulv/my_kernel.cu(194): warning: variable "_57" was set but never used

/tmp/tmpfxicdulv/my_kernel.cu(207): warning: variable "_62" was set but never used

2 errors detected in the compilation of "/tmp/tmpxft_00002af4_00000000-6_my_kernel.cpp1.ii".

Can anyone help with this problem? I’m using Ubuntu 16.04, with gcc 6.5 and nvcc 10.0. Thanks!

Can you please provide your code?

I’m afraid I can only share some code snippets here. The problem appears to come from the code generated for vectorize(). The relevant Python snippet is:

    s[FS_1].compute_at(s[OL], xoicc)
    h1, w1, i1, o1 = s[FS_1].op.axis
    io = s[FS_1].fuse(i1, o1)
    io, iox = s[FS_1].split(io, factor=num_thread_x * 4)
    ioy, io = s[FS_1].split(io, nparts=num_thread_y)
    iox, io4 = s[FS_1].split(iox, factor=4)
    s[FS_1].reorder(h1, w1, io, ioy, iox, io4)
    s[FS_1].bind(iox, thread_x)
    s[FS_1].bind(ioy, thread_y)
    s[FS_1].vectorize(io4)

With this code, the corresponding CUDA snippet is

for (int ax2_ax3_fused_outer_inner = 0; ax2_ax3_fused_outer_inner < 4; ++ax2_ax3_fused_outer_inner) {
          float4 _1;
                int4 _2 = make_int4((((rc_outer * 16384) + (rc_inner_outer * 4096)) + (((int)threadIdx.y) * 1024)), (((rc_outer * 16384) + (rc_inner_outer * 4096)) + (((int)threadIdx.y) * 1024)), (((rc_outer * 16384) + (rc_inner_outer * 4096)) + (((int)threadIdx.y) * 1024)), (((rc_outer * 16384) + (rc_inner_outer * 4096)) + (((int)threadIdx.y) * 1024)));
                        int4 _3 = make_int4(256, 256, 256, 256);
                        int4 _4 = make_int4(0, 0, 0, 0);
                        uint4 _5;
                        _5.x = (_3.x>=_4.x);
                        _5.y = (_3.y>=_4.y);
                        _5.z = (_3.z>=_4.z);
                        _5.w = (_3.w>=_4.w);
                          int4 _6 = (make_int4)((((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*0), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*1), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*2), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*3));
                          int4 _7 = make_int4(256, 256, 256, 256);
                          int4 _8;
                          _8.x = (_6.x%_7.x);
                          _8.y = (_6.y%_7.y);
                          _8.z = (_6.z%_7.z);
                          _8.w = (_6.w%_7.w);
                        int4 _9 = make_int4(0, 0, 0, 0);
                        uint4 _10;
                        _10.x = (_8.x>=_9.x);
                        _10.y = (_8.y>=_9.y);
                        _10.z = (_8.z>=_9.z);
                        _10.w = (_8.w>=_9.w);
                      uint4 _11;
                      _11.x = (_5.x&&_10.x);
                      _11.y = (_5.y&&_10.y);
                      _11.z = (_5.z&&_10.z);
                      _11.w = (_5.w&&_10.w);
                        int4 _12 = make_int4(256, 256, 256, 256);
                        int4 _13 = make_int4(0, 0, 0, 0);
                        uint4 _14;
                        _14.x = (_12.x<_13.x);
                        _14.y = (_12.y<_13.y);
                        _14.z = (_12.z<_13.z);
                        _14.w = (_12.w<_13.w);
                          int4 _15 = (make_int4)((((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*0), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*1), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*2), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*3));
                          int4 _16 = make_int4(256, 256, 256, 256);
                          int4 _17;
                          _17.x = (_15.x%_16.x);
                          _17.y = (_15.y%_16.y);
                          _17.z = (_15.z%_16.z);
                          _17.w = (_15.w%_16.w);
                        int4 _18 = make_int4(0, 0, 0, 0);
                        uint4 _19;
                        _19.x = (_17.x<=_18.x);
                        _19.y = (_17.y<=_18.y);
                        _19.z = (_17.z<=_18.z);
                        _19.w = (_17.w<=_18.w);
                      uint4 _20;
                      _20.x = (_14.x&&_19.x);
                      _20.y = (_14.y&&_19.y);
                      _20.z = (_14.z&&_19.z);
                      _20.w = (_14.w&&_19.w);
                    uint4 _21;
                    _21.x = (_11.x||_20.x);
                    _21.y = (_11.y||_20.y);
                    _21.z = (_11.z||_20.z);
                    _21.w = (_11.w||_20.w);
                    int4 _22 = (make_int4)((((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*0), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*1), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*2), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*3));
                    int4 _23 = make_int4(256, 256, 256, 256);
                    int4 _24;
                    _24.x = (_22.x/_23.x);
                    _24.y = (_22.y/_23.y);
                    _24.z = (_22.z/_23.z);
                    _24.w = (_22.w/_23.w);
                      int4 _25 = (make_int4)((((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*0), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*1), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*2), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*3));
                      int4 _26 = make_int4(256, 256, 256, 256);
                      int4 _27;
                      _27.x = (_25.x/_26.x);
                      _27.y = (_25.y/_26.y);
                      _27.z = (_25.z/_26.z);
                      _27.w = (_25.w/_26.w);
                    int4 _28 = make_int4(1, 1, 1, 1);
                    int4 _29;
                    _29.x = (_27.x-_28.x);
                    _29.y = (_27.y-_28.y);
                    _29.z = (_27.z-_28.z);
                    _29.w = (_27.w-_28.w);
                  int4 _30 = _21 ? _24 : _29;
                  int4 _31 = make_int4(512, 512, 512, 512);
                  int4 _32;
                  _32.x = (_30.x*_31.x);
                  _32.y = (_30.y*_31.y);
                  _32.z = (_30.z*_31.z);
                  _32.w = (_30.w*_31.w);
                int4 _33;
                _33.x = (_2.x+_32.x);
                _33.y = (_2.y+_32.y);
                _33.z = (_2.z+_32.z);
                _33.w = (_2.w+_32.w);
              int4 _34 = make_int4(((((int)blockIdx.x) & 1) * 256), ((((int)blockIdx.x) & 1) * 256), ((((int)blockIdx.x) & 1) * 256), ((((int)blockIdx.x) & 1) * 256));
              int4 _35;
              _35.x = (_33.x+_34.x);
              _35.y = (_33.y+_34.y);
              _35.z = (_33.z+_34.z);
              _35.w = (_33.w+_34.w);
                  int4 _36 = make_int4(256, 256, 256, 256);
                  int4 _37 = make_int4(0, 0, 0, 0);
                  uint4 _38;
                  _38.x = (_36.x>=_37.x);
                  _38.y = (_36.y>=_37.y);
                  _38.z = (_36.z>=_37.z);
                  _38.w = (_36.w>=_37.w);
                    int4 _39 = (make_int4)((((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*0), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*1), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*2), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*3));
                    int4 _40 = make_int4(256, 256, 256, 256);
                    int4 _41;
                    _41.x = (_39.x%_40.x);
                    _41.y = (_39.y%_40.y);
                    _41.z = (_39.z%_40.z);
                    _41.w = (_39.w%_40.w);
                  int4 _42 = make_int4(0, 0, 0, 0);
                  uint4 _43;
                  _43.x = (_41.x>=_42.x);
                  _43.y = (_41.y>=_42.y);
                  _43.z = (_41.z>=_42.z);
                  _43.w = (_41.w>=_42.w);
                uint4 _44;
                _44.x = (_38.x&&_43.x);
                _44.y = (_38.y&&_43.y);
                _44.z = (_38.z&&_43.z);
                _44.w = (_38.w&&_43.w);
                  int4 _45 = make_int4(256, 256, 256, 256);
                  int4 _46 = make_int4(0, 0, 0, 0);
                  uint4 _47;
                  _47.x = (_45.x<_46.x);
                  _47.y = (_45.y<_46.y);
                  _47.z = (_45.z<_46.z);
                  _47.w = (_45.w<_46.w);
                    int4 _48 = (make_int4)((((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*0), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*1), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*2), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*3));
                    int4 _49 = make_int4(256, 256, 256, 256);
                    int4 _50;
                    _50.x = (_48.x%_49.x);
                    _50.y = (_48.y%_49.y);
                    _50.z = (_48.z%_49.z);
                    _50.w = (_48.w%_49.w);
                  int4 _51 = make_int4(0, 0, 0, 0);
                  uint4 _52;
                  _52.x = (_50.x<=_51.x);
                  _52.y = (_50.y<=_51.y);
                  _52.z = (_50.z<=_51.z);
                  _52.w = (_50.w<=_51.w);
                uint4 _53;
                _53.x = (_47.x&&_52.x);
                _53.y = (_47.y&&_52.y);
                _53.z = (_47.z&&_52.z);
                _53.w = (_47.w&&_52.w);
              uint4 _54;
              _54.x = (_44.x||_53.x);
              _54.y = (_44.y||_53.y);
              _54.z = (_44.z||_53.z);
              _54.w = (_44.w||_53.w);
              int4 _55 = (make_int4)((((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*0), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*1), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*2), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*3));
              int4 _56 = make_int4(256, 256, 256, 256);
              int4 _57;
              _57.x = (_55.x%_56.x);
              _57.y = (_55.y%_56.y);
              _57.z = (_55.z%_56.z);
              _57.w = (_55.w%_56.w);
                int4 _58 = (make_int4)((((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*0), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*1), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*2), (((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)))+(1*3));
                int4 _59 = make_int4(256, 256, 256, 256);
                int4 _60;
                _60.x = (_58.x%_59.x);
                _60.y = (_58.y%_59.y);
                _60.z = (_58.z%_59.z);
                _60.w = (_58.w%_59.w);
              int4 _61 = make_int4(256, 256, 256, 256);
              int4 _62;
              _62.x = (_60.x+_61.x);
              _62.y = (_60.y+_61.y);
              _62.z = (_60.z+_61.z);
              _62.w = (_60.w+_61.w);
            int4 _63 = _54 ? _57 : _62;
            int4 _64;
            _64.x = (_35.x+_63.x);
            _64.y = (_35.y+_63.y);
            _64.z = (_35.z+_63.z);
            _64.w = (_35.w+_63.w);
          _1.x = Conv2dFilter_1[_64.x];
          _1.y = Conv2dFilter_1[_64.y];
          _1.z = Conv2dFilter_1[_64.z];
          _1.w = Conv2dFilter_1[_64.w];
        ((__shared__ float4*)(Conv2dFilter_1_shared + (((((int)threadIdx.y) * 512) + (ax2_ax3_fused_outer_inner * 128)) + (((int)threadIdx.x) * 4))))[0] = _1;
      }

As mentioned above, the error occurs at two lines:

int4 _30 = _21 ? _24 : _29;

and

int4 _63 = _54 ? _57 : _62;

It’s unclear to me how a uint4 variable such as _21 or _54 is supposed to be interpreted as true/false here. Does the condition mean that all four elements are true, or that at least one of them is?

If I comment out vectorize(), the corresponding CUDA snippet becomes

for (int ax2_ax3_fused_outer_inner = 0; ax2_ax3_fused_outer_inner < 4; ++ax2_ax3_fused_outer_inner) {
        for (int ax2_ax3_fused_inner_inner = 0; ax2_ax3_fused_inner_inner < 4; ++ax2_ax3_fused_inner_inner) {
          Conv2dFilter_1_shared[((((((int)threadIdx.y) * 512) + (ax2_ax3_fused_outer_inner * 128)) + (((int)threadIdx.x) * 4)) + ax2_ax3_fused_inner_inner)] = Conv2dFilter_1[((((((rc_outer * 16384) + (rc_inner_outer * 4096)) + (((int)threadIdx.y) * 1024)) + (((((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)) + ax2_ax3_fused_inner_inner) >> 8) * 512)) + ((((int)blockIdx.x) & 1) * 256)) + ((((ax2_ax3_fused_outer_inner * 128) + (((int)threadIdx.x) * 4)) + ax2_ax3_fused_inner_inner) & 255))];
        }
      }

which looks quite normal.