Replacing extent of realize bound with non-const expression

I’m experimenting with modifying how TVM generates code. What I’m trying to do is to initially generate the realize (in ScheduleOps) in such a way that the extents of its bounds are non-const expressions (i.e., each extent depends on the loop variables above the realize). Then I’d like to apply LoopPartition to partition a loop (say, loop j) outside this realize. The result of LoopPartition will be two (or more) loops, each with its own copy of the realize. With the narrower range of j in each partition, the non-const expression in each realize extent can then be simplified to a constant value.

As an example, consider this Python script:

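   import tvm
   N = 65  # tvm.var("N")
   A = tvm.placeholder((N, N), "float32", "A")
   # C[i, j] = A[i, j] + A[i+1, j] for interior rows, 0 otherwise
   C = tvm.compute((N, N), lambda i, j: tvm.select(tvm.all(i > 0, i < N - 1), A[i, j] + A[i + 1, j], tvm.const(0., "float32")), name='C')
   s = tvm.create_schedule(C.op)
   # Tile both axes by 16; 65 is not a multiple of 16, so the last tile is partial
   yo, xo, yi, xi = s[C].tile(C.op.axis[0], C.op.axis[1], 16, 16)
   # Cache A in a local buffer computed inside the outer tile loops
   AA = s.cache_read(A, "local", [C])
   s[AA].compute_at(s[C], xo)
   print(tvm.lower(s, [A, C], simple_mode=True))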

The output after phase 0 will be:

// attr [compute(C, 0x1818f00)] realize_scope = ""
realize C([0, 65], [0, 65]) {
  produce C {
    for (i.outer, 0, 5) {
      for (j.outer, 0, 5) {
        // attr [compute(A.local, 0x18f4850)] realize_scope = "local"
        realize A.local([(i.outer*16), 17], [(j.outer*16), 16]) {
          produce A.local {
            for (ax0, 0, 17) {
              for (ax1, 0, 16) {
                if (likely(((ax0 + (i.outer*16)) < 65))) {
                  if (likely(((ax1 + (j.outer*16)) < 65))) {
                    A.local((ax0 + (i.outer*16)), (ax1 + (j.outer*16))) =A((ax0 + (i.outer*16)), (ax1 + (j.outer*16)))
                  }
                }
              }
            }
          }
          for (i.inner, 0, 16) {
            for (j.inner, 0, 16) {
              if (likely(((i.inner + (i.outer*16)) < 65))) {
                if (likely(((i.inner + (i.outer*16)) < 65))) {
                  if (likely(((j.inner + (j.outer*16)) < 65))) {
                    if (likely(((j.inner + (j.outer*16)) < 65))) {
                      C((i.inner + (i.outer*16)), (j.inner + (j.outer*16))) =select((((i.inner + (i.outer*16)) > 0) && ((i.inner + (i.outer*16)) < 64)), (A.local((i.inner + (i.outer*16)), (j.inner + (j.outer*16))) + A.local(((i.inner + (i.outer*16)) + 1), (j.inner + (j.outer*16)))), 0.000000f)
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}

For realize of A.local, I would like to initially generate this:

        realize A.local([(i.outer*16), min(17,65-16*i.outer)], [(j.outer*16), min(16,65 - 16*j.outer)]) {
          produce A.local {
            for (ax0, 0, 17) {
              for (ax1, 0, 16) {
                if (likely(((ax0 + (i.outer*16)) < 65))) {
                  if (likely(((ax1 + (j.outer*16)) < 65))) {
                    A.local((ax0 + (i.outer*16)), (ax1 + (j.outer*16))) =A((ax0 + (i.outer*16)), (ax1 + (j.outer*16)))
                  }
                }
              }
            }

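Note the new extents of the realize; the rest is unchanged. The extents are derived from the conditions of the if statements: for example, ax0 + i.outer*16 < 65 limits ax0 to fewer than 65 - 16*i.outer values, hence the extent min(17, 65 - 16*i.outer).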
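To make the arithmetic concrete, here is a small plain-Python sketch (no TVM involved) that evaluates those min(...) extents for every outer iteration and groups consecutive iterations with the same extent. The helper names `clipped_extent` and `partition_by_extent` are made up for illustration; this only mimics the ranges I would hope to get from LoopPartition and is not how TVM computes them.

```python
def clipped_extent(tile_extent, outer, factor, bound):
    """Extent of the realize region after clipping to the tensor bound.

    The guard `ax + outer*factor < bound` holds for ax in
    [0, bound - outer*factor), so the usable extent is
    min(tile_extent, bound - outer*factor).
    """
    return min(tile_extent, bound - outer * factor)


def partition_by_extent(tile_extent, num_outer, factor, bound):
    """Group consecutive outer iterations that share one clipped extent.

    Returns a list of (start, stop, extent) tuples, with stop exclusive.
    """
    parts = []
    for outer in range(num_outer):
        ext = clipped_extent(tile_extent, outer, factor, bound)
        if parts and parts[-1][2] == ext:
            # Same extent as the previous iteration: grow that partition.
            start, _, _ = parts[-1]
            parts[-1] = (start, outer + 1, ext)
        else:
            parts.append((outer, outer + 1, ext))
    return parts


# Values taken from the example above: N = 65, tile factor 16, 5 outer iterations.
N, FACTOR, NUM_OUTER = 65, 16, 5
rows = partition_by_extent(17, NUM_OUTER, FACTOR, N)  # first realize dimension
cols = partition_by_extent(16, NUM_OUTER, FACTOR, N)  # second realize dimension
print(rows)
print(cols)
```

For this example the extent is constant over outer iterations 0 to 3 and drops to 1 in the last iteration, so partitioning each outer loop into [0, 4) and [4, 5) would leave every copy of the realize with a constant extent.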

Now my questions:

  1. When is a good time to build these realizes? Is it more appropriate to modify them after they are initially created (i.e., after ScheduleOps), or to generate them as shown above from the beginning (i.e., in BuildRealize)?
  2. I started to implement it in BuildRealize as follows: I call MakeLoopNest and MakeBoundCheck at the beginning of BuildRealize to get the list of predicates. I got stuck because the IterVars in ComputeOpNode.axis are not the same ones used in those predicates, so I can’t figure out which predicates apply to each axis of the realize. If there were a many-to-one map from the IterVars used in the predicates to the IterVars in ComputeOpNode.axis, I could finish the code, but I couldn’t find such a mapping and I am not even sure whether one is theoretically possible in general. Any advice is very much appreciated.

FYI, @xqdan @kun-zh

Cheers!

For the 2nd question: the IterVars used to build the predicates come from the values of value_map,

these values come from const std::unordered_map<IterVar, Range>& dom_map,

and dom_map comes from schedule.InferBound(sch):
Stmt ScheduleOps(
Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_trivial_loop) {