Incorrect CUDA code generated when a 3D thread grid is used

Given the test case shown at the bottom, if cuda is set to False, the code runs fine and gives the correct result. If cuda is set to True, the execution result is incorrect. From the generated CUDA code, it appears that the threads corresponding to threadIdx.y == 1 and threadIdx.z == 1 (both extents are 2) are being predicated off.

extern "C" __global__ void transpose_kernel0( float* __restrict__ A,  float* __restrict__ C) {
  float B[32];
  for (int i_outer = 0; i_outer < 1; ++i_outer) {
    for (int j_outer = 0; j_outer < 1; ++j_outer) {
      for (int j_inner_inner = 0; j_inner_inner < 32; ++j_inner_inner) {
        if (((i_outer * 64) + (((int)threadIdx.z) * 32)) < (1 - ((int)threadIdx.x))) {
          if (((j_outer * 64) + (((int)threadIdx.y) * 32)) < (32 - j_inner_inner)) {
            if ((((((int)threadIdx.z) * 32) + (i_outer * 32)) + (((int)blockIdx.y) * 32)) < (512 - ((int)threadIdx.x))) {
              if ((((((int)threadIdx.y) * 64) + (j_outer * 64)) + (((int)blockIdx.x) * 64)) < (1024 - j_inner_inner)) {
                B[((((((i_outer * 2048) + (((int)threadIdx.z) * 1024)) + (j_outer * 64)) + (((int)threadIdx.x) * 32)) + (((int)threadIdx.y) * 32)) + j_inner_inner)] = (A[((((((((((int)threadIdx.z) * 65536) + (i_outer * 65536)) + (((int)blockIdx.y) * 65536)) + (((int)threadIdx.x) * 2048)) + (((int)threadIdx.y) * 64)) + (j_outer * 64)) + (((int)blockIdx.x) * 64)) + j_inner_inner)] * 3.140000e+00f);
              }
            }
          }
        }
      }
    }
  }
  for (int j_inner_inner1 = 0; j_inner_inner1 < 32; ++j_inner_inner1) {
    C[((((((((int)blockIdx.y) * 65536) + (((int)threadIdx.z) * 32768)) + (((int)threadIdx.x) * 1024)) + (((int)blockIdx.x) * 64)) + (((int)threadIdx.y) * 32)) + j_inner_inner1)] = (B[j_inner_inner1] * 2.170000e+00f);
  }
}

Is our use of the scheduling primitives incorrect, or are we hitting a known issue in TE?

import tvm
import numpy as np

cuda = True
#cuda = False

block_x = tvm.thread_axis("blockIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
thread_x = tvm.thread_axis("threadIdx.x")
thread_y = tvm.thread_axis("threadIdx.y")
thread_z = tvm.thread_axis("threadIdx.z")

def build_and_test(s, A, B, target, name, showcuda=False):
    ctx = tvm.context(target, 0)
    func = tvm.build(s, [A, B], target=target, name='transpose')
    assert func

    if showcuda:
        print(func.imported_modules[0].get_source())
        #print(func.get_source())
        
    # Randomly generated input; b is the output buffer and is overwritten by func
    a = tvm.nd.array(np.random.rand(A.shape[0].value, A.shape[1].value).astype("float32"), ctx)
    b = tvm.nd.array(np.random.rand(B.shape[0].value, B.shape[1].value).astype("float32"), ctx)
    
    func(a, b)
    answer = 3.14 * 2.17 * a.asnumpy()
    tvm.testing.assert_allclose(b.asnumpy(), answer, rtol=1e-5)

    evaluator = func.time_evaluator(func.entry_name, ctx, number=1)
    print(name+': %f ms' % (evaluator(a, b).mean * 1e3))

# Algorithm
M = 1024
N = 1024

A = tvm.placeholder((M, N), name='A')

B = tvm.compute((M, N), lambda i, j: A[i, j] * 3.14, name='B')
C = tvm.compute((M, N), lambda i, j: 2.17 * B[i, j], name='C')


# Schedule
s = tvm.create_schedule(C.op)
c_i, c_j = s[C].op.axis

c_i_outer, c_j_outer, c_i_inner, c_j_inner = s[C].tile(c_i, c_j, 64, 64)
c_i_inner_outer, c_j_inner_outer, c_i_inner_inner, c_j_inner_inner = s[C].tile(c_i_inner, c_j_inner, 32, 32)


b_i, b_j = s[B].op.axis
b_i_outer, b_j_outer, b_i_inner, b_j_inner = s[B].tile(b_i, b_j, 64, 64)
b_i_inner_outer, b_j_inner_outer, b_i_inner_inner, b_j_inner_inner = s[B].tile(b_i_inner, b_j_inner, 32, 32)

s[B].set_scope("local")

s[B].compute_at(s[C], c_i_inner_inner)

if cuda:
    s[C].bind(c_i_outer, block_y)
    s[C].bind(c_j_outer, block_x)
    s[C].bind(c_i_inner_outer, thread_z)
    s[C].bind(c_j_inner_outer, thread_y)
    s[C].bind(c_i_inner_inner, thread_x)

    #s[B].bind(b_i_outer, block_y)
    #s[B].bind(b_j_outer, block_x)
    s[B].bind(b_i_inner_outer, thread_z)
    s[B].bind(b_j_inner_outer, thread_y)
    s[B].bind(b_i_inner_inner, thread_x)

    build_and_test(s, A, C, "cuda", "gpu", showcuda=True)
else:
    build_and_test(s, A, C, "llvm", "cpu")

If you set the scope of B to local, each thread computes its own private copy of B, so you can't bind B's axes to thread axes.
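
For reference, a minimal sketch of the corrected binding section (the rest of the script above is unchanged): with B in local scope, only C's axes are bound, and B's remaining loops run sequentially inside each thread after the compute_at.

if cuda:
    s[C].bind(c_i_outer, block_y)
    s[C].bind(c_j_outer, block_x)
    s[C].bind(c_i_inner_outer, thread_z)
    s[C].bind(c_j_inner_outer, thread_y)
    s[C].bind(c_i_inner_inner, thread_x)

    # No s[B].bind(...) calls: B is thread-local, so after
    # s[B].compute_at(s[C], c_i_inner_inner) its remaining loops
    # execute sequentially within each thread.
    build_and_test(s, A, C, "cuda", "gpu", showcuda=True)
else:
    build_and_test(s, A, C, "llvm", "cpu")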

I see. It works w/o the binding. Thank you.