Binding loops with different extents to one thread axis (more threads for the load than for the compute)

Hello,

I am trying to write a two-phase schedule:

In phase 1, I want to load the padded input into shared memory. In this phase I would like more threads to take part in the load than will take part in the computation in phase 2.

In phase 2, I compute on the padded data held in shared memory.
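To make the intended behaviour concrete, here is a minimal plain-Python simulation (not TVM code) of what the generated kernel would do if both phases shared one thread axis of extent 260. All 260 threads load one padded element each, and a guard masks off threads 256 to 259 during the compute. The 1-D layout, the halo of 2 on each side, the 5-point sum and the name `run_block` are hypothetical, chosen only so that 256 outputs need 260 padded inputs.

```python
# Plain-Python simulation (not TVM): one thread axis of extent 260 shared by
# a padded load phase and a 256-wide compute phase. Sizes and the 5-point
# sum are hypothetical, chosen so that 256 outputs need 260 padded inputs.
NUM_THREADS = 260   # extent of the shared thread axis (= load loop extent)
OUT_WIDTH = 256     # extent of the compute loop
HALO = 2            # hypothetical padding on each side: 256 + 2 * 2 = 260


def run_block(row, block_x):
    """Simulate one thread block producing OUT_WIDTH outputs."""
    shared = [0.0] * NUM_THREADS
    base = block_x * OUT_WIDTH - HALO  # global index that maps to shared[0]

    # Phase 1: every thread tx in [0, 260) loads one padded element.
    for tx in range(NUM_THREADS):
        g = base + tx
        # Out-of-range reads become zero padding (like the likely() guards).
        shared[tx] = row[g] if 0 <= g < len(row) else 0.0

    # On a GPU a barrier (__syncthreads) would separate the two phases here.

    # Phase 2: same thread axis, but threads tx >= 256 are masked off.
    out = [0.0] * OUT_WIDTH
    for tx in range(NUM_THREADS):
        if tx < OUT_WIDTH:  # guard instead of a second, smaller thread extent
            out[tx] = sum(shared[tx + k] for k in range(2 * HALO + 1))
    return out
```

For example, with `row = [1.0] * 512`, `run_block(row, 0)[0]` is 3.0 because the two leftmost entries of its window are padding.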

When I lower this schedule, all the loop extents are as I want them until I bind to a thread axis. The problem is that the two loops have different extents: the load loop includes the padding, while the compute loop covers only the valid range.

With some details omitted, the lowered IR looks like this:

  produce Apad.shared {
    for (ax2, 0, 260) {
      for (ax1, 0, 8) {
        if (likely((1 <= ((blockIdx.y*4) + ax1)))) {
          if (likely((1 <= ((blockIdx.x*256) + ax2)))) {
             //..... load into shared
          }
        }
      }
    }
  }
  for (i.inner, 0, 256) {
    for (j.inner, 0, 4) {
        //compute
    }
  }

I have a thread axis of extent 260 and would like to map both loops to it. But if I bind both ax2 and i.inner to it, I get the error "domain already inferred, cannot prove their extents are the same 256 vs 260". Is there a way to do this?
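For comparison, one commonly used way to keep a single extent is cooperative fetching: split the 260-iteration load loop by the compute thread count, so that each of 256 threads loads up to two elements and a guard covers the tail. This gives up the extra load threads, so it is not exactly what is asked here. The sketch below simulates that pattern in plain Python; it is not TVM code, and the sizes, the halo and the name `load_shared` are hypothetical.

```python
# Plain-Python simulation (not TVM) of cooperative fetching: keep one thread
# extent of 256 and let each thread load up to ceil(260 / 256) = 2 elements.
import math

NUM_THREADS = 256     # shared extent: equals the compute loop extent
PADDED_WIDTH = 260    # extent of the padded load loop
HALO = 2              # hypothetical padding on each side


def load_shared(row, block_x):
    """Fill a 260-element shared buffer using only 256 simulated threads."""
    shared = [0.0] * PADDED_WIDTH
    base = block_x * NUM_THREADS - HALO  # global index that maps to shared[0]
    passes = math.ceil(PADDED_WIDTH / NUM_THREADS)  # outer loop after split: 2
    for p in range(passes):
        for tx in range(NUM_THREADS):  # inner loop, bound to the thread axis
            idx = p * NUM_THREADS + tx
            if idx < PADDED_WIDTH:  # tail guard: only 4 threads work in pass 2
                g = base + idx
                shared[idx] = row[g] if 0 <= g < len(row) else 0.0
    return shared
```

Because the load and the compute now both bind a loop of extent 256, the two extents agree, at the cost of a second, mostly idle load pass.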

Thanks