The problem happens when I use split and compute_at to tile and merge stages. When I define the split factor as a uint32 variable, the IterVar domain of the producer stage seems wrong.
A simple demo of the problem:
import tvm
c = tvm.var("c")
h = tvm.var("h")
w = tvm.var("w")
data_shape = [c, h, w]
A = tvm.placeholder(data_shape, dtype="float32", name="tensor_A")
B = tvm.placeholder(A.shape, dtype="float32", name="tensor_B")
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), "C")
D = tvm.compute(A.shape, lambda *i: C(*i) * 2, "D")
s = tvm.create_schedule(D.op)
factor = tvm.var("ft", dtype="uint32")
bo, bi = s[D].split(D.op.axis[-1], factor)
s[C].compute_at(s[D], bo)
print(tvm.lower(s, [D, A, B, c, h, w], simple_mode=True))
and the output is:
// attr [C] storage_scope = "global"
allocate C[float32 * w]
produce D {
for (i0, 0, c) {
for (i1, 0, h) {
for (i2.outer, 0, floordiv((w + int32((ft - (uint32)1))), int32(ft))) {
produce C {
for (i2, 0, w) {
C[i2] = (tensor_A[(((i0*stride) + (i1*stride)) + (i2*stride))] + tensor_B[(((i0*stride) + (i1*stride)) + (i2*stride))])
}
}
for (i2.inner, 0, ft) {
if (likely(((i2.inner + (i2.outer*int32(ft))) < w))) {
if (likely(((i2.inner + (i2.outer*int32(ft))) < w))) {
D[(((i0*stride) + (i1*stride)) + ((i2.inner + (i2.outer*int32(ft)))*stride))] = (C[(i2.inner + (i2.outer*int32(ft)))]*2f)
}
}
}
}
}
}
}
In the printed IR, the axis extent of stage C is incorrect: it is w when it should be ft.
Here is some more context:
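For reference, here is a minimal pure-Python sketch (no TVM needed; the concrete values of w and ft are assumptions for illustration) of the per-tile extent that stage C is expected to get after compute_at: at most ft elements per i2.outer iteration, never the full w.

```python
# Sketch of the tiling bounds the schedule should infer for stage C.
# `w` is the axis length, `ft` the split factor (both hypothetical here).
def tiled_producer_extents(w, ft):
    """Return the extent stage C should get for each i2.outer tile."""
    n_tiles = (w + ft - 1) // ft          # floordiv((w + ft - 1), ft), as in the IR
    extents = []
    for outer in range(n_tiles):
        # Correct per-tile extent: at most ft, clipped at the tail tile.
        extents.append(min(ft, w - outer * ft))
    return extents

# e.g. w = 10, ft = 4 gives tiles of 4, 4, 2 -- never the full w = 10
```

With an int-type or constant factor, bound inference produces exactly these clipped extents; with the uint32 factor it falls back to the full extent w.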
- If I define factor as an int-type var or a constant number, the result is correct.
- I want to compute split_factor dynamically from w and the free memory space, so a constant is not suitable here.
- Using an int-type variable requires extra code to handle the negative case, which feels redundant.
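To make the "dynamic factor" motivation concrete, here is a hypothetical sketch of how the split factor might be derived from the row width and a free-memory budget (the function name, the budget parameter, and the element size are all assumptions, not part of the original code):

```python
# Hypothetical host-side computation of the split factor before scheduling.
def choose_split_factor(w, free_bytes, elem_bytes=4):
    """Largest tile width that fits in `free_bytes`, capped at w."""
    budget = max(1, free_bytes // elem_bytes)   # keep at least one element
    return min(w, budget)

# The result is always >= 1, which is why an unsigned dtype for `ft`
# would be the natural choice if bound inference handled it correctly.
```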
Any hint on which pass I should dig into? Thanks a lot.
PS: Is there any tutorial on the lowering process, including the arrangement of and relations between the passes, and the design rationale?