Iter variables not in schedule after rfactor


Consider this snippet of code:

import tvm

N = 2048
x = tvm.placeholder((N,), name='x')
y = tvm.placeholder((N,), name='y')
k = tvm.reduce_axis((0,N), name='k')
z = tvm.compute((1,), lambda _: tvm.sum(x[k]*y[k], axis=k))
s = tvm.create_schedule(z.op)

ko, ki = s[z].split(k, factor=16)

# s[z].reorder(ki, ko) works here, before rfactor
r = s.rfactor(z, ko)
# s[r].reorder(ki, ko) fails from here on

print("ko:", ko, "ki", ki)
print("real_ko:", s[r].op.axis[0], "real_ki:", s[r].op.reduce_axis[0])

s[r].reorder(ki, ko)  # crash

# The following way works:
# s[r].reorder(s[r].op.reduce_axis[0], s[r].op.axis[0])

print(tvm.lower(s, [x, y], simple_mode=True))

It starts with a reduction, which is split into two iteration variables, ko and ki. At this point reordering them works fine. But after rfactor promotes ko to a data-parallel axis of the new stage, s[r].reorder(ki, ko) crashes. The same reorder works when the iteration variables are taken directly from s[r].op.

The output of this code is:

ko: iter_var(k.outer, ) ki iter_var(k.inner, )
real_ko: iter_var(k.outer, range(min=0, ext=128)) real_ki: iter_var(k.inner, range(min=0, ext=16))
[11:24:23] /w/src/dmlc/tvm/src/schedule/ Operate on iter var iter_var(k.inner, )that is not part of the schedule
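For intuition about the extents printed above (128 for k.outer, 16 for k.inner) and about what rfactor(z, ko) computes, here is a plain-Python sketch of the same dot product. It does not use TVM; the names FACTOR, NO and r are illustrative only. The factored stage keeps ko as a data-parallel axis and reduces over ki, and the original stage then reduces the partial sums over ko.

```python
import math
import random

random.seed(0)

N = 2048
FACTOR = 16            # split factor from s[z].split(k, factor=16)
NO = N // FACTOR       # extent of k.outer: 2048 // 16 == 128

x = [random.random() for _ in range(N)]
y = [random.random() for _ in range(N)]

# Original reduction: z[0] = sum over k of x[k] * y[k]
z_direct = sum(x[k] * y[k] for k in range(N))

# split: k == ko * FACTOR + ki, with ko in [0, NO) and ki in [0, FACTOR)
# rfactor on ko: the new stage keeps ko as a data-parallel axis and
# reduces only over ki, producing NO partial sums...
r = [
    sum(x[ko * FACTOR + ki] * y[ko * FACTOR + ki] for ki in range(FACTOR))
    for ko in range(NO)
]

# ...and the original stage then reduces the partial sums over ko.
z_factored = sum(r[ko] for ko in range(NO))

print(NO, FACTOR, len(r))                                   # 128 16 128
print(math.isclose(z_direct, z_factored, rel_tol=1e-9))    # True
```

Both results agree up to floating-point rounding, since the factored version only regroups the same sum.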

Is this a bug, or is it intentional behavior?


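When you run rfactor, it inserts a new stage (the one that performs the factored computation) with its own iteration variables, so the IterVar handles returned by the earlier split (ko, ki) become stale: they are not part of the new stage's schedule. You therefore need to re-extract them from the new op, e.g. s[r].op.axis[0] and s[r].op.reduce_axis[0].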
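A toy model may make the "stale handle" behavior clearer. This is not TVM's implementation: IterVar, Stage and rfactor below are hypothetical, heavily simplified stand-ins. The assumption being modeled is that the schedule looks up iteration variables by object identity (as we understand TVM's lookup to work), so a fresh object that merely reuses the name k.inner is not recognized.

```python
class IterVar:
    """Hypothetical stand-in for a TVM IterVar. Two IterVars count as the
    same only if they are the same object; a matching name is not enough."""

    def __init__(self, name, extent):
        self.name = name
        self.extent = extent

    def __repr__(self):
        return f"iter_var({self.name}, ext={self.extent})"


class Stage:
    """Hypothetical, heavily simplified schedule stage."""

    def __init__(self, axis, reduce_axis):
        self.axis = list(axis)
        self.reduce_axis = list(reduce_axis)
        self.leaf_iter_vars = self.axis + self.reduce_axis

    def _index(self, iv):
        # Identity lookup: a different object with the same name is rejected.
        for i, leaf in enumerate(self.leaf_iter_vars):
            if leaf is iv:
                return i
        raise ValueError(
            f"Operate on iter var {iv} that is not part of the schedule")

    def reorder(self, *order):
        # Put the given iter vars, in the given order, into the slots they
        # currently occupy; other leaf iter vars stay where they are.
        slots = sorted(self._index(iv) for iv in order)
        for slot, iv in zip(slots, order):
            self.leaf_iter_vars[slot] = iv


def rfactor(factor_axis, remaining_reduce_axis):
    """Hypothetical rfactor: build the factored stage with FRESH iter vars
    that only reuse the old names, which is what makes old handles stale."""
    new_outer = IterVar(factor_axis.name, factor_axis.extent)
    new_inner = IterVar(remaining_reduce_axis.name, remaining_reduce_axis.extent)
    return Stage(axis=[new_outer], reduce_axis=[new_inner])


ko, ki = IterVar("k.outer", 128), IterVar("k.inner", 16)  # handles from split
rf = rfactor(ko, ki)

try:
    rf.reorder(ki, ko)                      # stale handles: rejected
except ValueError as err:
    print("error:", err)

rf.reorder(rf.reduce_axis[0], rf.axis[0])   # re-extracted from the new stage
print([iv.name for iv in rf.leaf_iter_vars])  # ['k.inner', 'k.outer']
```

The names printed for the stale and the fresh variables are identical, which is why the printed ko/ki and real_ko/real_ki look alike; only the re-extracted objects belong to the new stage.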