Significant increase in generated CUDA code size after migrating indexdiv/mod to floordiv/mod

We have been using TVM to generate operator kernels that replace handcrafted ones in MXNet. We noticed a significant increase in the amount of generated CUDA code after indexdiv/mod was replaced by floordiv/mod in this PR. One consequence is that the time spent compiling just the TVM operators in MXNet grew from about 2 minutes to 30 minutes. For example, when the following script is used to compile an add kernel for 3D tensors, the CUDA file my_kernel.cu generated by the function compile_cuda() is much larger than before. The larger PTX code also leads to much longer JIT time at runtime. Is there a way to reduce the size of the generated CUDA code?
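To see why the switch to floor semantics can matter for code generation, here is a small pure-Python sketch (not TVM code) contrasting C-style truncating division, which is what CUDA's integer `/` and `%` compute, with floor division. The helper names are hypothetical. The two conventions agree when the dividend is non-negative and the divisor positive, and differ for mixed signs, so floordiv can only be emitted as a plain `/` when the compiler knows the signs.

```python
def truncdiv(a: int, b: int) -> int:
    """C/CUDA-style integer division: the quotient is rounded toward zero."""
    q = abs(a) // abs(b)
    return q if (a < 0) == (b < 0) else -q


def truncmod(a: int, b: int) -> int:
    """C/CUDA-style remainder: it takes the sign of the dividend a."""
    return a - truncdiv(a, b) * b


def floordiv(a: int, b: int) -> int:
    """Quotient rounded toward negative infinity (Python's // operator)."""
    return a // b


def floormod(a: int, b: int) -> int:
    """Remainder with the sign of the divisor b (Python's % operator)."""
    return a % b


if __name__ == "__main__":
    for a, b in [(7, 2), (-7, 2), (7, -2), (-7, -2)]:
        print(f"{a:3} / {b:2}: trunc=({truncdiv(a, b)}, {truncmod(a, b)}), "
              f"floor=({floordiv(a, b)}, {floormod(a, b)})")
```

Only the mixed-sign rows differ: for example, -7 / 2 gives -3 (remainder -1) with truncation but -4 (remainder 1) with floor division.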

@tqchen @yzhliu @Laurawly

39K  my_kernel_after.cu   # w/ PR4008
4.4K my_kernel_before.cu  # w/o PR4008
import time
import tvm

print(tvm)  # show which TVM installation is in use

ndim = 3
# Symbolic shapes: every tvm.var() is an unconstrained int32, so the
# simplifier cannot assume that any extent is non-negative.
a = tvm.placeholder([tvm.var() for _ in range(ndim)], dtype='float32', name='a')
b = tvm.placeholder([tvm.var() for _ in range(ndim)], dtype='float32', name='b')

c = tvm.compute([tvm.var() for _ in range(ndim)], lambda *idx: a[idx] + b[idx], name='c')
s = tvm.create_schedule(c.op)
# Fuse all axes into one flat loop, then split it into blocks of 64 threads.
axes = [axis for axis in c.op.axis]
fused = s[c].fuse(*axes)
bx, tx = s[c].split(fused, factor=64)
s[c].bind(bx, tvm.thread_axis('blockIdx.x'))
s[c].bind(tx, tvm.thread_axis('threadIdx.x'))

# auto_broadcast buffers map the index along any dimension of extent 1 to 0.
arg_binds = {arg: tvm.decl_buffer(arg.shape, arg.dtype, buffer_type='auto_broadcast')
             for arg in [a, b, c]}

start = time.time()
func_bin = tvm.build(s, [a, b, c], target='cuda', target_host='llvm', name='tvmop', binds=arg_binds)
elapse = time.time() - start
print('Build time cost: {} seconds'.format(elapse))
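The fuse-and-split schedule in the script means each thread must recover its (i, j, k) coordinates from a single flat index through a chain of div/mod operations. Below is a minimal pure-Python sketch of that arithmetic, assuming a row-major fusion order; the names `c_div`, `decompose` and `launch` are hypothetical, and the expressions TVM actually emits may be arranged differently. It shows that for non-negative flat indices and positive extents, floor and truncating division yield identical coordinates, while a negative index makes them diverge. That divergence is why the simplifier needs a sign fact before it can emit plain `/` and `%`.

```python
from itertools import product


def c_div(a: int, b: int) -> int:
    """Truncating division, like CUDA's integer '/'."""
    q = abs(a) // abs(b)
    return q if (a < 0) == (b < 0) else -q


def c_mod(a: int, b: int) -> int:
    """Remainder with the dividend's sign, like CUDA's integer '%'."""
    return a - c_div(a, b) * b


def floor_div(a: int, b: int) -> int:
    return a // b


def floor_mod(a: int, b: int) -> int:
    return a % b


def decompose(fused, n1, n2, div, mod):
    """Recover (i, j, k) from a row-major flat index over a shape (n0, n1, n2)."""
    i = div(fused, n1 * n2)
    j = mod(div(fused, n2), n1)
    k = mod(fused, n2)
    return i, j, k


def launch(shape, block=64, div=floor_div, mod=floor_mod):
    """Emulate the fused loop split into blockIdx.x / threadIdx.x by `block`."""
    n0, n1, n2 = shape
    total = n0 * n1 * n2
    num_blocks = (total + block - 1) // block
    coords = []
    for bx in range(num_blocks):
        for tx in range(block):
            fused = bx * block + tx
            if fused < total:  # tail guard: the split factor need not divide total
                coords.append(decompose(fused, n1, n2, div, mod))
    return coords


if __name__ == "__main__":
    shape = (3, 5, 7)
    print("covers every element once:",
          launch(shape) == list(product(*map(range, shape))))
    print("floor and trunc agree for non-negative indices:",
          launch(shape) == launch(shape, div=c_div, mod=c_mod))
    # With a negative flat index the two conventions give different coordinates.
    print(decompose(-1, 5, 7, floor_div, floor_mod), decompose(-1, 5, 7, c_div, c_mod))
```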

It seems that the constraint API for tvm.var is still under development.

I think @hgt312 is right. It's because floordiv requires bound information, i.e. whether the operands are known to be non-negative. Without it, the simplifier does not have enough information to simplify the index expressions, so it generates many if-conditions to keep the program safely guarded.
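The guard blow-up described above can be illustrated with a toy code generator. This is a pure-Python sketch, not TVM's actual lowering pass: the function names are hypothetical, the divisor is assumed to be known positive, and only the dividend's sign is treated as unknown (in the script above the extents are unconstrained too, which can only add more cases). When the dividend is known to be non-negative, floordiv maps directly onto C's truncating `/`. Otherwise the emitted expression needs a conditional (here a ternary select), and because the dividend appears three times in it, nesting div operations, as the fused-index decomposition does, roughly triples the expression size per level unless repeated subexpressions are shared.

```python
def c_div(a: int, b: int) -> int:
    """Truncating division, like CUDA's integer '/'."""
    q = abs(a) // abs(b)
    return q if (a < 0) == (b < 0) else -q


def lower_floordiv(a: str, b: str, a_nonneg: bool) -> str:
    """Emit C-like code for floordiv(a, b), assuming the divisor b is positive."""
    if a_nonneg:
        # A bound proof lets floordiv map straight onto C's truncating '/'.
        return f"({a} / {b})"
    # Unknown sign: a negative dividend needs a corrected quotient.
    return f"(({a} >= 0) ? ({a} / {b}) : (({a} - {b} + 1) / {b}))"


def guarded_floordiv(a: int, b: int) -> int:
    """Numeric twin of the guarded expression above (requires b > 0)."""
    return c_div(a, b) if a >= 0 else c_div(a - b + 1, b)


if __name__ == "__main__":
    for known in (True, False):
        inner = lower_floordiv("fused", "n2", known)
        outer = lower_floordiv(inner, "n1", known)
        print(f"non-negative known={known}: {len(inner)} -> {len(outer)} chars")
        print("  ", outer)
```

A real code generator can share repeated subexpressions, but each guard still costs extra comparisons and arithmetic; proving the dividend non-negative removes the guard altogether.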