How to separate one stage into two stages in TE?

Following VTA’s convolution optimization tutorial, I get the following lowered IR after virtual threading:

// attr [data_buf] storage_scope = "global"
allocate data_buf[int8 * 65536]
// attr [kernel_buf] storage_scope = "global"
allocate kernel_buf[int8 * 589824]
// attr [res_conv] storage_scope = "global"
allocate res_conv[int32 * 25088]
produce data_buf {
  for (i1, 0, 16) {
    for (i2, 0, 16) {
      for (i3, 0, 16) {
        for (i5, 0, 16) {
          data_buf[((((i1*4096) + (i2*256)) + (i3*16)) + i5)] = tvm_if_then_else(((((1 <= i2) && (i2 < 15)) && (1 <= i3)) && (i3 < 15)), data[(((((i1*3136) + (i2*224)) + (i3*16)) + i5) - 240)], (int8)0)
        }
      }
    }
  }
}
produce kernel_buf {
  for (i0, 0, 16) {
    for (i1, 0, 16) {
      for (i2, 0, 3) {
        for (i3, 0, 3) {
          for (i4, 0, 16) {
            for (i5, 0, 16) {
              kernel_buf[((((((i0*36864) + (i1*2304)) + (i2*768)) + (i3*256)) + (i4*16)) + i5)] = kernel[((((((i0*36864) + (i1*2304)) + (i2*768)) + (i3*256)) + (i4*16)) + i5)]
            }
          }
        }
      }
    }
  }
}
produce res {
  for (i2.outer, 0, 2) {
    produce res_conv {
      for (co.init, 0, 8) {
        for (i.init, 0, 7) {
          for (j.init, 0, 14) {
            for (ci.init, 0, 16) {
              res_conv[((((co.init*1568) + (i.init*224)) + (j.init*16)) + ci.init)] = 0
              res_conv[(((((co.init*1568) + (i.init*224)) + (j.init*16)) + ci.init) + 12544)] = 0
            }
          }
        }
      }
      for (ic.outer, 0, 16) {
        for (co, 0, 8) {
          for (i, 0, 7) {
            for (dy, 0, 3) {
              for (dx, 0, 3) {
                for (j, 0, 14) {
                  for (ci, 0, 16) {
                    for (ic_tns, 0, 16) {
                      res_conv[((((co*1568) + (i*224)) + (j*16)) + ci)] = (res_conv[((((co*1568) + (i*224)) + (j*16)) + ci)] + (int32(data_buf[(((((((ic.outer*4096) + (i2.outer*1792)) + (i*256)) + (dy*256)) + (j*16)) + (dx*16)) + ic_tns)])*int32(kernel_buf[((((((co*36864) + (ic.outer*2304)) + (dy*768)) + (dx*256)) + (ci*16)) + ic_tns)])))
                      res_conv[(((((co*1568) + (i*224)) + (j*16)) + ci) + 12544)] = (res_conv[(((((co*1568) + (i*224)) + (j*16)) + ci) + 12544)] + (int32(data_buf[(((((((ic.outer*4096) + (i2.outer*1792)) + (i*256)) + (dy*256)) + (j*16)) + (dx*16)) + ic_tns)])*int32(kernel_buf[(((((((co*36864) + (ic.outer*2304)) + (dy*768)) + (dx*256)) + (ci*16)) + ic_tns) + 294912)])))
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
    produce res_shr {
      for (i1, 0, 8) {
        for (i2, 0, 7) {
          for (i3, 0, 14) {
            for (i5, 0, 16) {
              res_conv[((((i1*1568) + (i2*224)) + (i3*16)) + i5)] = shift_right(res_conv[((((i1*1568) + (i2*224)) + (i3*16)) + i5)], 8)
              res_conv[(((((i1*1568) + (i2*224)) + (i3*16)) + i5) + 12544)] = shift_right(res_conv[(((((i1*1568) + (i2*224)) + (i3*16)) + i5) + 12544)], 8)
            }
          }
        }
      }
    }
    produce res_max {
      for (i1, 0, 8) {
        for (i2, 0, 7) {
          for (i3, 0, 14) {
            for (i5, 0, 16) {
              res_conv[((((i1*1568) + (i2*224)) + (i3*16)) + i5)] = max(res_conv[((((i1*1568) + (i2*224)) + (i3*16)) + i5)], 0)
              res_conv[(((((i1*1568) + (i2*224)) + (i3*16)) + i5) + 12544)] = max(res_conv[(((((i1*1568) + (i2*224)) + (i3*16)) + i5) + 12544)], 0)
            }
          }
        }
      }
    }
    produce res_min {
      for (i1, 0, 8) {
        for (i2, 0, 7) {
          for (i3, 0, 14) {
            for (i5, 0, 16) {
              res_conv[((((i1*1568) + (i2*224)) + (i3*16)) + i5)] = min(res_conv[((((i1*1568) + (i2*224)) + (i3*16)) + i5)], 127)
              res_conv[(((((i1*1568) + (i2*224)) + (i3*16)) + i5) + 12544)] = min(res_conv[(((((i1*1568) + (i2*224)) + (i3*16)) + i5) + 12544)], 127)
            }
          }
        }
      }
    }
    for (i1.inner, 0, 8) {
      for (i2.inner, 0, 7) {
        for (i3.inner, 0, 14) {
          for (i5, 0, 16) {
            res[(((((i1.inner*3136) + (i2.outer*1568)) + (i2.inner*224)) + (i3.inner*16)) + i5)] = int8(res_conv[((((i1.inner*1568) + (i2.inner*224)) + (i3.inner*16)) + i5)])
            res[((((((i1.inner*3136) + (i2.outer*1568)) + (i2.inner*224)) + (i3.inner*16)) + i5) + 25088)] = int8(res_conv[(((((i1.inner*1568) + (i2.inner*224)) + (i3.inner*16)) + i5) + 12544)])
          }
        }
      }
    }
  }
}

I’d like to separate the last stage (the copy from res_conv into res) into two stages, as follows:

    for (i1.inner, 0, 8) {
      for (i2.inner, 0, 7) {
        for (i3.inner, 0, 14) {
          for (i5, 0, 16) {
            res[(((((i1.inner*3136) + (i2.outer*1568)) + (i2.inner*224)) + (i3.inner*16)) + i5)] = int8(res_conv[((((i1.inner*1568) + (i2.inner*224)) + (i3.inner*16)) + i5)])
          }
        }
      }
    }

    for (i1.inner, 0, 8) {
      for (i2.inner, 0, 7) {
        for (i3.inner, 0, 14) {
          for (i5, 0, 16) {
            res[((((((i1.inner*3136) + (i2.outer*1568)) + (i2.inner*224)) + (i3.inner*16)) + i5) + 25088)] = int8(res_conv[(((((i1.inner*1568) + (i2.inner*224)) + (i3.inner*16)) + i5) + 12544)])
          }
        }
      }
    }

Is there any way to perform this transformation using the TE (tensor expression) APIs?