Seemingly incorrect codegen for scans with compute_at

Hi all,

I have a simple computation as follows:

import tvm
from tvm import te

m = 100
n = 256
X = te.placeholder((m, n), name="X")
s_state = te.placeholder((m, n))   # scan state
s_init = te.placeholder((1, n))    # scan init
s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i], name="update")
s_scan = te.scan(s_init, s_update, s_state, inputs=[X], name="scan")
c_out = te.compute((n,), lambda i: s_scan[m - 1, i] * 17)

s = te.create_schedule(c_out.op)
# compute the whole scan inside c_out's thread loop and keep it thread-local
s[s_scan].compute_at(s[c_out], s[c_out].op.axis[0])
s[s_scan].set_scope("local")
s[c_out].bind(s[c_out].op.axis[0], te.thread_axis("threadIdx.x"))

print(tvm.lower(s, [X, c_out], simple_mode=True))

This generates the following IR on lowering:

primfn(X_1: handle, compute_1: handle) -> ()
  attr = {"tir.noalias": True, "global_symbol": "main"}
  buffers = {X: Buffer(X_2: handle, float32, [100, 256], []),
             compute: Buffer(compute_2: handle, float32, [256], [])}
  buffer_map = {compute_1: compute, X_1: X} {
  attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 256;
  attr [scan: handle] "storage_scope" = "local";
  allocate(scan, float32, [100]) {
    for (scan.idx: int32, 0, 99) {
      for (i: int32, 0, 256) {
        scan[(((scan.idx + i) + 1) - threadIdx.x)] = ((float32*)scan[((scan.idx + i) - threadIdx.x)] + (float32*)X_2[(((scan.idx*256) + i) + 256)])
      }
    }
    compute_2[threadIdx.x] = ((float32*)scan[99]*17f32)
  }
}

I don’t understand why the inner i-loop is generated inside the scan: it seemingly makes every thread recompute the full scan redundantly. Shouldn’t i be bound to threadIdx.x in this case? If this is the intended behavior of compute_at, is there another way to generate IR similar to the following for the same computation? (The closest working alternative I’ve found is sketched after the desired IR below.) In any case, the generated IR also produces negative indices into the scan buffer (whenever threadIdx.x > scan.idx + i + 1), which is clearly incorrect.

primfn(X_1: handle, compute_1: handle) -> ()
  attr = {"tir.noalias": True, "global_symbol": "main"}
  buffers = {X: Buffer(X_2: handle, float32, [100, 256], []),
             compute: Buffer(compute_2: handle, float32, [256], [])}
  buffer_map = {compute_1: compute, X_1: X} {
  attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 256;
  attr [scan: handle] "storage_scope" = "local";
  allocate(scan, float32, [100]) {
    for (scan.idx: int32, 0, 99) {
        scan[(scan.idx + 1)] = ((float32*)scan[scan.idx] + (float32*)X_2[(((scan.idx*256) + threadIdx.x) + 256)])
    }
    compute_2[threadIdx.x] = ((float32*)scan[99]*17f32)
  }
}
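
For reference, the closest working pattern I’ve found (adapted from the scan tutorial in the TVM docs) binds the spatial axes of the init/update stages directly instead of using compute_at. That avoids the redundant inner loop, but it keeps scan in global memory rather than per-thread local storage, so it isn’t quite what I’m after. Note that in this sketch s_init is a compute stage (rather than a placeholder) so that it can be scheduled, and I haven’t checked the exact IR it produces:

import tvm
from tvm import te

m = 100
n = 256
X = te.placeholder((m, n), name="X")
s_state = te.placeholder((m, n))
# init as a compute stage so its axis can be bound as well
s_init = te.compute((1, n), lambda _, i: X[0, i], name="init")
s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i], name="update")
s_scan = te.scan(s_init, s_update, s_state, inputs=[X], name="scan")
c_out = te.compute((n,), lambda i: s_scan[m - 1, i] * 17)

s = te.create_schedule(c_out.op)
thread_x = te.thread_axis("threadIdx.x")
# bind the spatial axis of each stage to threadIdx.x instead of nesting the scan under c_out
s[s_init].bind(s_init.op.axis[1], thread_x)
s[s_update].bind(s_update.op.axis[1], thread_x)
s[c_out].bind(c_out.op.axis[0], thread_x)

print(tvm.lower(s, [X, c_out], simple_mode=True))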

I’m using TVM’s master branch at commit d8b185c674f56a71479cee3381c0655f1633d1d8. Thanks!