Cross-thread reduction fails when reduce_axis length is variable

Hi, I want to do a cross-thread reduction on the GPU where the length of the reduce_axis is variable.

I modified the Reduction tutorial, and my code is below.

import tvm

n = tvm.var("n")
m = tvm.var("m")
A = tvm.placeholder((n, m), name='A')
# assume the values in R are smaller than m
R = tvm.placeholder((n,), name='R', dtype="int")

def f(i):
    # the reduction length R[i] varies from row to row
    k = tvm.reduce_axis((0, R[i]), "k")
    return tvm.sum(A[i, k], axis=k)

B = tvm.compute((n,), f, name="B")

s = tvm.create_schedule(B.op)
# split the reduction axis and factor the inner part (ki) out into stage BF
ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
BF = s.rfactor(B, ki)
# map rows to blockIdx.x / threadIdx.y
xo, xi = s[B].split(B.op.axis[0], factor=32)
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, tvm.thread_axis("threadIdx.y"))
# cross-thread reduction: bind the remaining reduce axis of B to threadIdx.x
tx = tvm.thread_axis("threadIdx.x")
s[B].bind(s[B].op.reduce_axis[0], tx)
# compute the partial sums BF inside each threadIdx.x
s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
# only thread 0 writes the final result
s[B].set_store_predicate(tx.var.equal(0))
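For checking results once the schedule builds, here is a small NumPy sketch (my own, not from the tutorial) of what B should compute, plus a two-stage version that roughly mirrors what split(factor=16) + rfactor(ki) do to the reduction. The helper names reference_sum and rfactor_sum are hypothetical, and the sketch assumes R[i] <= m, as the comment in the code says.

```python
import numpy as np

def reference_sum(A, R):
    """B[i] = sum of A[i, k] for k in range(R[i]): the variable-length
    row sum expressed by the tvm.compute definition."""
    return np.array([A[i, :R[i]].sum() for i in range(A.shape[0])], dtype=A.dtype)

def rfactor_sum(A, R, factor=16):
    """Two-stage version (hypothetical helper) loosely mirroring
    split(k, factor) followed by rfactor(ki):
      BF[i, ki] = sum over ko of A[i, ko * factor + ki], guarded by k < R[i]
      B[i]      = sum over ki of BF[i, ki]  (the part bound to threadIdx.x)
    """
    n = A.shape[0]
    BF = np.zeros((n, factor), dtype=A.dtype)
    for i in range(n):
        num_ko = (R[i] + factor - 1) // factor  # outer extent: ceil(R[i] / factor)
        for ki in range(factor):
            for ko in range(num_ko):
                k = ko * factor + ki
                if k < R[i]:  # tail guard, since R[i] need not divide evenly
                    BF[i, ki] += A[i, k]
    return BF.sum(axis=1)

# Usage: both versions should agree up to floating-point rounding.
rng = np.random.default_rng(0)
A_np = rng.random((64, 100))
R_np = rng.integers(0, 100, size=64)
np.testing.assert_allclose(rfactor_sum(A_np, R_np), reference_sum(A_np, R_np))
```

The per-row guard k < R[i] is what makes the variable bound work in the sketch; the rows with small R[i] simply leave most BF[i, ki] entries at zero.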

However, this code fails with the following error:

Traceback (most recent call last):

  File "reduction_var.py", line 93, in <module>
    xo, xi = s[B].split(B.op.axis[0], factor=32)

  File "/home/cyulin/tvm/python/tvm/schedule.py", line 383, in split
    outer, inner = _api_internal._StageSplitByFactor(self, parent, factor)

  File "/home/cyulin/tvm/python/tvm/_ffi/_ctypes/function.py", line 207, in __call__
    raise get_last_ffi_error()

tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (5) /home/cyulin/tvm/build/libtvm.so(TVMFuncCall+0x61) [0x7f0921b10ce1]
  [bt] (4) /home/cyulin/tvm/build/libtvm.so(+0x3dda44) [0x7f0921365a44]
  [bt] (3) /home/cyulin/tvm/build/libtvm.so(tvm::Stage::split(tvm::IterVar, tvm::Expr, tvm::IterVar*, tvm::IterVar*)+0x5f) [0x7f09216b32df]
  [bt] (2) /home/cyulin/tvm/build/libtvm.so(+0x72ae86) [0x7f09216b2e86]
  [bt] (1) /home/cyulin/tvm/build/libtvm.so(+0x724391) [0x7f09216ac391]
  [bt] (0) /home/cyulin/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x32) [0x7f092131b8a2]
  File "/home/cyulin/tvm/src/schedule/schedule_lang.cc", line 53
TVMError: Operate on iter var iter_var(i, range(min=0, ext=n))that is not part of the schedule

Any thoughts on why this happens and how to fix it?

Thanks!