A GPU thread binding question

We are trying to build an operational model of the scheduling primitives in Tensor Expression, especially the GPU-related ones.

Here is a simple test case that we’d expect to work but does not.

import tvm

block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")
N = 32 * 8
A = tvm.placeholder((N,), name='A')
B = tvm.compute((N,), lambda i: A[i] + 1, name='B')
C = tvm.compute((N,), lambda i: B[i] * 2, name='C')
Nblock = 32
s = tvm.create_schedule(C.op)
Ci_outer, Ci_inner = s[C].split(C.op.axis[0], Nblock)
# compute B inside C's outer loop and stage it in shared memory
s[B].compute_at(s[C], Ci_outer)
s[B].set_scope("shared")
s[C].bind(Ci_outer, block_x)
s[C].bind(Ci_inner, thread_x)
# the following line results in an error
s[B].bind(s[B].op.axis[0], thread_x)

TE complains (https://github.com/apache/incubator-tvm/blob/master/src/op/op_util.cc#L158):

TVMError: Check failed: is_zero(dom->min):

The assertion checks whether the thread IterVar starts from 0. In this case, it starts from blockIdx.x*32.

Splitting B’s axis into two and binding the inner one makes the code work:

< s[B].bind(s[B].op.axis[0], thread_x)
---
> Bi_outer, Bi_inner = s[B].split(B.op.axis[0], Nblock)
> s[B].bind(Bi_inner, thread_x)

and it generates the desired CUDA code:

extern "C" __global__ void gemm_kernel0( float* __restrict__ A,  float* __restrict__ C) {
      __shared__ float B[32];
      B[((int)threadIdx.x)] = (A[((((int)blockIdx.x) * 32) + ((int)threadIdx.x))] + 1.000000e+00f);
      C[((((int)blockIdx.x) * 32) + ((int)threadIdx.x))] = (B[((int)threadIdx.x)] * 2.000000e+00f);
 }
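
For reference, the kernel above can be reproduced by building the corrected schedule for CUDA and dumping the device source. The snippet below is our own end-to-end sketch, assuming the same pre-0.7 TVM API as the listing at the top; the tvm.build and imported_modules[0].get_source() calls are standard TVM, not part of the original repro.

# Sketch (not from the original repro): the corrected schedule, built for CUDA.
s = tvm.create_schedule(C.op)
Ci_outer, Ci_inner = s[C].split(C.op.axis[0], Nblock)
s[B].compute_at(s[C], Ci_outer)
s[B].set_scope("shared")
s[C].bind(Ci_outer, block_x)
s[C].bind(Ci_inner, thread_x)
Bi_outer, Bi_inner = s[B].split(B.op.axis[0], Nblock)
s[B].bind(Bi_inner, thread_x)
func = tvm.build(s, [A, C], target="cuda", name="gemm")
print(func.imported_modules[0].get_source())  # dumps the generated CUDA kernel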

Our questions are:

  1. How do we explain this to kernel developers who are using TE?
  2. Is this limitation fundamental to how TE works? Or is it a known issue in the current implementation?

Thanks.

Yuan


It is really weird. Do you have any ideas about this problem? @vinx13 @yzhliu


Since B is attached to C, and both Ci_inner and B’s root IterVar are bound to threadIdx.x, they must have compatible ranges. However, after InferBound, Ci_inner’s range is [0, 32), while B’s root IterVar’s range is [blockIdx.x*32, (blockIdx.x+1)*32). These ranges are not compatible.

After splitting B’s axis and binding Bi_inner to threadIdx.x, Bi_inner’s range becomes [0, 32) as well, so the problem is avoided.
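
One way to see this directly is to ask TVM for the inferred bounds before adding the failing bind. This is a sketch assuming the pre-0.7 tvm.schedule.InferBound API; the comments state the ranges we expect, not quoted output.

# Sketch: inspect inferred ranges with everything bound except B's root axis.
s = tvm.create_schedule(C.op)
Ci_outer, Ci_inner = s[C].split(C.op.axis[0], Nblock)
s[B].compute_at(s[C], Ci_outer)
s[B].set_scope("shared")
s[C].bind(Ci_outer, block_x)
s[C].bind(Ci_inner, thread_x)
bounds = tvm.schedule.InferBound(s)
print(bounds[Ci_inner])       # expect min=0, extent=32
print(bounds[B.op.axis[0]])   # expect min=blockIdx.x*32, extent=32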

A rebase can offset B’s root IterVar’s range from [blockIdx.x*32, (blockIdx.x+1)*32) to [0, 32). I notice that bound IterVars are currently skipped from rebasing. The original code works with the following small change that allows rebasing bound IterVars:

--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -506,11 +506,6 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
     for (IterVar iv : root_iter_vars) {
       size_t idx = FindNodeRef(leaf_vars, iv);
       auto it  = s->iter_var_attrs.find(iv);
-      // don;t need to rebase path that are binded.
-      if (it != s->iter_var_attrs.end() &&
-          (*it).second->bind_thread.defined()) {
-        continue;
-      }
       if (idx < leaf_vars->data.size()) {
         // insert rebase
         IterVar rebased = IterVarNode::make(

Just found some old discussion on the same topic: https://github.com/apache/incubator-tvm/issues/756

At the end of that thread, Tianqi briefly explained why normalization doesn’t rebase bound paths. My understanding is that the workaround at that time was to ask the user to do an explicit rebase. However, I don’t see rebase exposed as a schedule primitive today. How can we rebase explicitly today, or solve this issue without rebase?

After almost two years, has the definition of thread binding semantics changed, so that the change shown in my previous comment might now be applicable?

@tqchen, @derisavi-huawei, can you shed some light upon this?


Although I can’t find an explicit rebase API, I believe a trivial split serves the same purpose today. Say I would like to do:

rebased = s[B].rebase(B.op.axis[0])

s[B].bind(rebased, tvm.thread_axis("threadIdx.x"))

Instead, I can do:

rebased, dontcare = s[B].split(B.op.axis[0], 1)

s[B].bind(rebased, tvm.thread_axis("threadIdx.x"))

TVM today unconditionally sets the inner and outer IterVars’ min to 0, which is exactly what rebase does. Setting these IterVars’ min to 0 doesn’t actually lose information, because TVM uses the parent’s min (i.e., the min of the parent node in the IterVar graph) to calculate the offset.
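
Putting those fragments together on the original example, the workaround looks like the sketch below (same pre-0.7 API assumed; the tvm.lower call is only there to confirm that the is_zero(dom->min) check now passes):

# Sketch: trivial split with factor 1 acting as an implicit rebase of B's root axis.
s = tvm.create_schedule(C.op)
Ci_outer, Ci_inner = s[C].split(C.op.axis[0], Nblock)
s[B].compute_at(s[C], Ci_outer)
s[B].set_scope("shared")
s[C].bind(Ci_outer, block_x)
s[C].bind(Ci_inner, thread_x)
rebased, dontcare = s[B].split(B.op.axis[0], 1)  # outer now starts at 0, like a rebase
s[B].bind(rebased, thread_x)
print(tvm.lower(s, [A, C], simple_mode=True))    # lowers without the min != 0 error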

This seems like a hack to me, as it relies on the assumption that split sets the mins to 0, and so on. I found this behavior in the code, but it is not documented as part of the semantics of tvm.schedule.Stage.split.

Can @tqchen, @derisavi-huawei, or someone comment on this?