Error when using split and compute_at with a uint32 split_factor variable

The problem occurs when I use split and compute_at to tile and merge stages. When I define the split factor as a uint32 variable, the IterVar domain of the producer stage seems wrong.

A simple demo of the problem:

import tvm

c = tvm.var("c")
h = tvm.var("h")
w = tvm.var("w")

data_shape = [c, h, w]

A = tvm.placeholder(data_shape, dtype="float32", name="tensor_A")
B = tvm.placeholder(A.shape, dtype="float32", name="tensor_B")

C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), "C")
D = tvm.compute(A.shape, lambda *i: C(*i) * 2, "D")

s = tvm.create_schedule(D.op)

factor = tvm.var("ft", dtype="uint32")
bo, bi = s[D].split(D.op.axis[-1], factor)

s[C].compute_at(s[D], bo)

print(tvm.lower(s, [D, A, B, c, h, w], simple_mode=True))

and the output is:

// attr [C] storage_scope = "global"
allocate C[float32 * w]
produce D {
  for (i0, 0, c) {
    for (i1, 0, h) {
      for (i2.outer, 0, floordiv((w + int32((ft - (uint32)1))), int32(ft))) {
        produce C {
          for (i2, 0, w) {
            C[i2] = (tensor_A[(((i0*stride) + (i1*stride)) + (i2*stride))] + tensor_B[(((i0*stride) + (i1*stride)) + (i2*stride))])
          }
        }
        for (i2.inner, 0, ft) {
          if (likely(((i2.inner + (i2.outer*int32(ft))) < w))) {
            if (likely(((i2.inner + (i2.outer*int32(ft))) < w))) {
              D[(((i0*stride) + (i1*stride)) + ((i2.inner + (i2.outer*int32(ft)))*stride))] = (C[(i2.inner + (i2.outer*int32(ft)))]*2f)
            }
          }
        }
      }
    }
  }
}

In the printed IR, the axis extent of stage C is not correct: the loop inside produce C runs over w (for (i2, 0, w)) instead of ft.
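To make the expectation concrete, here is a plain-Python sketch (not TVM code) of which elements of C each outer iteration of D should need. The helper names ceil_div and producer_ranges are made up for illustration, and the sketch assumes ft > 0. Each tile should cover at most ft elements, never the whole axis w.

```python
def ceil_div(a, b):
    """Integer ceiling division, assuming b > 0."""
    return (a + b - 1) // b


def producer_ranges(w, ft):
    """Return (start, extent) of the C elements read by each outer iteration.

    Mirrors the split of the last axis: i2 = i2.outer * ft + i2.inner,
    with the tail tile clipped so that i2 < w.
    """
    ranges = []
    for outer in range(ceil_div(w, ft)):
        start = outer * ft
        extent = min(ft, w - start)  # the last tile may be shorter than ft
        ranges.append((start, extent))
    return ranges


print(producer_ranges(10, 4))  # [(0, 4), (4, 4), (8, 2)]
```

With w = 10 and ft = 4, no tile needs more than 4 elements of C, which is the extent I would expect to see for the loop inside produce C.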

Some more information:

  • If I define factor as an int-typed var or a constant number, the result is correct.
  • I want to calculate split_factor dynamically from w and the free memory space, so a constant is not suitable here.
  • Using an int-typed variable would require extra code to handle the negative case, which seems redundant.
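As a rough illustration of the kind of computation I have in mind, here is a plain-Python sketch. The function choose_split_factor and its parameters (free_bytes, elem_bytes) are hypothetical, and in the real schedule this would be a symbolic expression in w rather than Python integers. The clamp to at least 1 is the sort of guard that a signed factor would need.

```python
def choose_split_factor(w, free_bytes, elem_bytes=4):
    """Hypothetical helper: pick the largest tile of the last axis that fits.

    w          -- extent of the axis being split (must be positive)
    free_bytes -- memory available for one tile of the producer stage
    elem_bytes -- size of one element (4 for float32)
    """
    if w <= 0 or elem_bytes <= 0:
        raise ValueError("w and elem_bytes must be positive")
    fit = free_bytes // elem_bytes  # how many elements fit in free memory
    # Never exceed the axis extent, and never return zero or a negative value.
    return max(1, min(w, fit))


print(choose_split_factor(1024, 1024))  # 256
```

With an unsigned factor the lower bound would hold by construction, which is why I would prefer uint32 here.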

Any hints on which pass I should dig into? Thanks a lot.

PS: Is there a tutorial on the lowering process, covering how the passes are arranged and related, and the design ideas behind them?