Can tvm.compute implement cumsum?



Hi, I have a question about how to compute a cumsum with tvm.compute. For example, define a tensor X as follows:

  X = tvm.placeholder((2, 3), name="X")
  cum = tvm.compute((2, 3), lambda i, j: X[i, j])  # ??? how do I turn this into a cumsum along axis 1

Desired cumsum along axis 1:

  [[1, 1, 1], [2, 2, 2]] --> [[1, 2, 3], [2, 4, 6]]
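To pin down the intended result, here is a plain-Python reference of a cumsum along axis 1. It is not TVM code, and `cumsum_axis1` is a hypothetical helper name used only for illustration.

```python
from itertools import accumulate


def cumsum_axis1(x):
    # Running sum inside each row: out[i][j] = x[i][0] + ... + x[i][j]
    return [list(accumulate(row)) for row in x]


print(cumsum_axis1([[1, 1, 1], [2, 2, 2]]))  # [[1, 2, 3], [2, 4, 6]]
```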

I want to compute the cumulative sum along the second axis (axis 1) of tensor X. How should I write this with tvm.compute?
I found tvm.scan in the docs:

The following code is equivalent to numpy.cumsum

  m = tvm.var("m")
  n = tvm.var("n")
  X = tvm.placeholder((m, n), name="X")
  s_state = tvm.placeholder((m, n))
  s_init = tvm.compute((1, n), lambda _, i: X[0, i])
  s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
  res = tvm.scan(s_init, s_update, s_state, X)

So for the input above, tvm.scan gives [[1, 1, 1], [2, 2, 2]] --> [[1, 1, 1], [3, 3, 3]], i.e. a cumsum along axis 0?
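The recurrence in the quoted doc example can be modeled in plain Python to check this reading of tvm.scan. `scan_axis0` is a hypothetical helper, not a TVM function; it only mirrors s_init (copy row 0 of X) and s_update (add the previous state row to row t of X).

```python
def scan_axis0(x):
    """Plain-Python model of the documented tvm.scan cumsum example.

    s_init   : s[0, i] = X[0, i]
    s_update : s[t, i] = s[t-1, i] + X[t, i]   for t >= 1
    """
    if not x:
        return []
    state = [list(x[0])]                  # s_init: first row copied as-is
    for t in range(1, len(x)):            # s_update runs over t = 1..m-1
        prev = state[t - 1]
        state.append([prev[i] + x[t][i] for i in range(len(x[t]))])
    return state


print(scan_axis0([[1, 1, 1], [2, 2, 2]]))  # [[1, 1, 1], [3, 3, 3]]
```

Because the scan index t is the first index of every tensor involved, the accumulation runs down the rows, which matches numpy.cumsum with axis=0.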

But it seems tvm.scan can only scan along the first axis. How can I do a scan or cumsum along the other axes?
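One possible workaround, offered as an assumption rather than a confirmed TVM recipe: since tvm.scan iterates over the first axis, move the target axis to the front with a transposing tvm.compute (for example `tvm.compute((n, m), lambda j, i: X[i, j])`), run the axis-0 scan, and transpose the result back. The plain-Python model below only checks that this index shuffle yields the axis-1 cumsum; it does not call TVM, and `transpose`, `scan_axis0`, and `cumsum_axis1_via_scan` are hypothetical helper names.

```python
def transpose(x):
    # Swap the two axes: out[j][i] = x[i][j]
    return [list(col) for col in zip(*x)]


def scan_axis0(x):
    # Same recurrence as the doc's tvm.scan example: row t = row t-1 + x[t]
    out = []
    for row in x:
        if out:
            out.append([p + v for p, v in zip(out[-1], row)])
        else:
            out.append(list(row))
    return out


def cumsum_axis1_via_scan(x):
    # Move axis 1 to the front, scan along axis 0, then move it back.
    return transpose(scan_axis0(transpose(x)))


print(cumsum_axis1_via_scan([[1, 1, 1], [2, 2, 2]]))  # [[1, 2, 3], [2, 4, 6]]
```

In TVM the two transposes would be extra compute stages, so whether they are cheap enough (or can be inlined by the schedule) is something to measure rather than assume.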