We are trying to build an operational model of the scheduling primitives in Tensor Expression (TE), especially the GPU-related ones.
Here is a simple test case that we'd expect to work, but it does not:
import tvm

block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")
N = 32*8
A = tvm.placeholder((N,), name='A')
B = tvm.compute((N,), lambda i: A[i] + 1, name='B')
C = tvm.compute((N,), lambda i: B[i] * 2, name='C')
Nblock = 32
s = tvm.create_schedule(C.op)
Ci_outer, Ci_inner = s[C].split(C.op.axis[0], Nblock)
s[B].compute_at(s[C], Ci_outer)
s[B].set_scope("shared")
s[C].bind(Ci_outer, block_x)
s[C].bind(Ci_inner, thread_x)
# the following line results in an error
s[B].bind(s[B].op.axis[0], thread_x)
TE complains (see https://github.com/apache/incubator-tvm/blob/master/src/op/op_util.cc#L158):
TVMError: Check failed: is_zero(dom->min):
The assertion checks that the domain of the IterVar being bound starts at 0. Here it does not: after compute_at, B's axis covers only the slice produced within one block of C, so it starts at blockIdx.x * 32.
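Here is how we currently model what happens (a hand-written Python sketch of the domain arithmetic, not TVM's actual internals; `split`, `bindable`, and the `(min, extent)` tuples are our illustrative names):

```python
# Hand-written sketch of the domain arithmetic involved; the names and
# the split() model here are ours, not TVM internals.

def split(dom, factor):
    """Model of Stage.split on a domain (min, extent): TVM rewrites the
    index as i = min + outer*factor + inner, so both new axes start at 0."""
    mn, ext = dom
    outer = (0, ext // factor)
    inner = (0, factor)
    return outer, inner

def bindable(dom):
    """The is_zero(dom->min) check from op_util.cc."""
    return dom[0] == 0

Nblock = 32
block_x = 3  # a concrete blockIdx.x value, for illustration

# After compute_at(s[C], Ci_outer), B's axis covers one block's slice:
b_axis = (block_x * Nblock, Nblock)  # min = blockIdx.x*32, extent = 32

print(bindable(b_axis))   # False -> bind(b_axis, thread_x) raises TVMError

# Splitting folds the non-zero min into the index expression, leaving an
# inner axis that starts at 0 and can be bound to threadIdx.x:
b_outer, b_inner = split(b_axis, Nblock)
print(bindable(b_inner))  # True
```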
Splitting B's axis into outer and inner parts makes the code work:
< s[B].bind(s[B].op.axis[0], thread_x)
---
> Bi_outer, Bi_inner = s[B].split(B.op.axis[0], Nblock)
> s[B].bind(Bi_inner, thread_x)
and it generates the desired CUDA code:
extern "C" __global__ void gemm_kernel0( float* __restrict__ A, float* __restrict__ C) {
__shared__ float B[32];
B[((int)threadIdx.x)] = (A[((((int)blockIdx.x) * 32) + ((int)threadIdx.x))] + 1.000000e+00f);
C[((((int)blockIdx.x) * 32) + ((int)threadIdx.x))] = (B[((int)threadIdx.x)] * 2.000000e+00f);
}
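As a sanity check of that kernel's semantics, here is a plain-Python simulation we wrote by hand (not TVM output; the loops over `block_x` and `thread_x` simply replay the CUDA grid sequentially):

```python
# Hand-written sequential replay of the generated kernel above.
N, Nblock = 32 * 8, 32
A = [float(i) for i in range(N)]
C = [0.0] * N

for block_x in range(N // Nblock):        # blockIdx.x
    B_shared = [0.0] * Nblock             # __shared__ float B[32]
    for thread_x in range(Nblock):        # threadIdx.x
        B_shared[thread_x] = A[block_x * Nblock + thread_x] + 1.0
    for thread_x in range(Nblock):
        C[block_x * Nblock + thread_x] = B_shared[thread_x] * 2.0

# Every element matches the fused computation C[i] = (A[i] + 1) * 2:
assert all(C[i] == (A[i] + 1.0) * 2.0 for i in range(N))
```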
Our questions are:
- How do we explain this to kernel developers who are using TE?
- Is this limitation fundamental to how TE works? Or is it a known issue in the current implementation?
Thanks.
Yuan