Maybe a bug related to Scan


When using TVM's recurrent computation (tvm.scan), any further computation on the tensor returned by tvm.scan makes TVM crash during the lowering stage with a segmentation fault.

Script to reproduce the error (copied from the tutorial; the only change is a transpose applied to the scan result tensor):

import tvm
import topi

m = tvm.var("m")
n = tvm.var("n")
X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
s_scan = tvm.scan(s_init, s_update, s_state, inputs=[X])
res = topi.transpose(s_scan, (1,0))

s = tvm.create_schedule(res.op)
print(tvm.lower(s, [X, res], simple_mode=True))
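For reference, the result the script above should compute once lowering works can be written out by hand: the scan is a running sum over the first axis (row t equals X[0] + ... + X[t]), and res is its transpose. The sketch below is plain Python for checking expected values only, not TVM code; the helper names scan_reference and transpose are hypothetical.

```python
def scan_reference(X):
    """Hand-written reference for the scan: row t is the running sum of rows X[0..t]."""
    # s_init: the first state row is a copy of X[0]
    state = [list(X[0])]
    for t in range(1, len(X)):
        # s_update: state[t][i] = state[t-1][i] + X[t][i]
        state.append([a + b for a, b in zip(state[t - 1], X[t])])
    return state


def transpose(A):
    """Swap the two axes, mirroring the (1, 0) transpose applied to s_scan."""
    return [list(col) for col in zip(*A)]


X = [[1, 2, 3], [4, 5, 6]]
res = transpose(scan_reference(X))
print(res)  # [[1, 5], [2, 7], [3, 9]]
```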

And this is the information from the core dump; I don't know whether it's helpful:

Program terminated with signal SIGSEGV, Segmentation fault.
#0  tvm::NodePtr<tvm::Node>::NodePtr (other=..., this=0x258f290) at /home/jian/repositories/tvm/3rdparty/HalideIR/src/tvm/node/node_base.h:106
106           : NodePtr(other.data_) {