I wrote the following DSL:
import tvm

shape = (2, 1)
reduce_end = tvm.placeholder(shape, name="reduce_end", dtype="int32")
fm = tvm.placeholder((5, 4, 16), name="fm")
reduce_2 = tvm.compute(shape, lambda *indice: reduce_end(*indice) * tvm.const(2, dtype="int32"), name="reduce_2")
bin_w = reduce_2[0, 0]
# bin_w = 4
var_reduce_axis = tvm.reduce_axis((0, bin_w), name="var_reduce_axis")
res = tvm.compute((5, 16), lambda i1, i2:
                  tvm.max(fm[i1, var_reduce_axis, i2], axis=var_reduce_axis), name="res")
s = tvm.create_schedule(res.op)
red_ub = s.cache_read(reduce_end, "local.UB", [reduce_2])
fm_ub = s.cache_read(fm, "local.UB", [res])
print(tvm.lower(s, [reduce_end, fm, res], simple_mode=True))
When I run it with Python, I get the following error:
raise TVMError(py_str(_LIB.TVMGetLastError()))
tvm._ffi.base.TVMError: [09:40:44] …/src/schedule/bound.cc:135: Check failed: it != rmap->end()
I think the root cause is that the upper bound of the reduce_axis comes from a tensor element, as shown here:
bin_w = reduce_2[0, 0]
# bin_w = 4
var_reduce_axis = tvm.reduce_axis((0, bin_w), name="var_reduce_axis")
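(With the constant `bin_w = 4` from the commented-out line, `tvm.lower` succeeds, so it really does seem to be the data-dependent extent that bound inference rejects.) One workaround I'm considering is to keep the reduction axis at its full static extent (4 here) and mask out the elements beyond the runtime `bin_w` with the dtype's minimum, so they cannot win the max. Below is the idea sketched in plain NumPy, not TVM, just to check the semantics; the names and the `bin_w = 2` value are made up for illustration:

```python
import numpy as np

# Static-shape input, like fm = tvm.placeholder((5, 4, 16)).
fm = np.arange(5 * 4 * 16, dtype=np.float32).reshape(5, 4, 16)

# Pretend this scalar arrived at runtime (it would be reduce_2[0, 0] in the DSL).
bin_w = 2

# Reduce over the FULL static axis of length 4, but replace every element with
# index >= bin_w by the most negative float, so it never affects the max.
k = np.arange(4).reshape(1, 4, 1)  # broadcastable index along the reduce axis
masked = np.where(k < bin_w, fm, np.float32(np.finfo(np.float32).min))
res = masked.max(axis=1)  # shape (5, 16)

# Equivalent to reducing only over the first bin_w slices.
assert np.array_equal(res, fm[:, :bin_w, :].max(axis=1))
```

In the DSL this would presumably become a `tvm.select`/`tvm.if_then_else` on the reduce-axis index inside the `tvm.max` body, with the axis extent fixed at 4, but I'm not sure that's the intended way to handle a data-dependent bound.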
Can anyone give me some advice?