Code generation when splitting by a non-factor value

I’ve noticed that when splitting/tiling by a non-factor, if statements get inserted in the innermost loops to handle the bounds checks. Take this example:

import tvm
from tvm import te, topi


def tile_nd(s, tensor, tile):
    outer_indices = []
    inner_indices = []
    for i, size in enumerate(tile):
        outer, inner = s[tensor].split(tensor.op.axis[i], size)
        outer_indices.append(outer)
        inner_indices.append(inner)

    s[tensor].reorder(*outer_indices, *inner_indices)
    return outer_indices, inner_indices


def small_graph():
    conv1_ifm = te.placeholder((1, 34, 34, 3), name='input')
    conv1_weights = te.placeholder((3, 3, 3, 32), name='weights_1')

    conv1 = topi.nn.conv2d_nhwc(conv1_ifm, conv1_weights, 1, 0, 1)
    return ([conv1],
            [conv1_ifm, conv1_weights])


def test_tile():
    graph, inputs = small_graph()
    out = graph[-1]
    s = te.create_schedule([out.op])
    tile_nd(s, graph[0], (1, 10, 10, 32))
    print(tvm.lower(s, inputs))

test_tile()

This returns

    for (yy.outer: int32, 0, 4) {
      for (xx.outer: int32, 0, 4) {
        for (yy.inner: int32, 0, 10) {
          for (xx.inner: int32, 0, 10) {
            for (ff.inner: int32, 0, 32) {
              if (((yy.outer*10) + yy.inner) < 32) {
                if (((xx.outer*10) + xx.inner) < 32) {
                  Conv2dOutput[(((((yy.outer*10240) + (yy.inner*1024)) + (xx.outer*320)) + (xx.inner*32)) + ff.inner)] = 0f32
                }
              }
              for (ry: int32, 0, 3) {
                for (rx: int32, 0, 3) {
                  for (rc: int32, 0, 3) {
                    if (((yy.outer*10) + yy.inner) < 32) {
                      if (((xx.outer*10) + xx.inner) < 32) {
                        Conv2dOutput[(((((yy.outer*10240) + (yy.inner*1024)) + (xx.outer*320)) + (xx.inner*32)) + ff.inner)] = ((float32*)Conv2dOutput[(((((yy.outer*10240) + (yy.inner*1024)) + (xx.outer*320)) + (xx.inner*32)) + ff.inner)] + ((float32*)PaddedInput[(((((((yy.outer*1020) + (yy.inner*102)) + (ry*102)) + (xx.outer*30)) + (xx.inner*3)) + (rx*3)) + rc)]*(float32*)weights_1_2[((((ry*288) + (rx*96)) + (rc*32)) + ff.inner)]))
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    }

Note that the if statements appear in the innermost reduction loop even though their conditions depend only on the x and y tile indices. This means they are checked far more often than they need to be. Is there a way to hoist the ifs to the outer loops to avoid all these unnecessary checks?
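To put a rough number on "far more often", here is a back-of-the-envelope count in plain Python using the loop extents from the lowered TIR above. It is only a sketch: the function names are made up for illustration, and it assumes the guards could be hoisted to just above the ff.inner loop, since their conditions do not depend on ff.inner or the reduction axes.

```python
# Loop extents copied from the lowered TIR in the post.
YY_OUTER, XX_OUTER, YY_INNER, XX_INNER, FF_INNER = 4, 4, 10, 10, 32
RY, RX, RC = 3, 3, 3


def guard_executions_innermost():
    """Times a guarded statement is reached with the ifs in the innermost loops."""
    spatial = YY_OUTER * XX_OUTER * YY_INNER * XX_INNER * FF_INNER
    # One guarded init plus one guarded update per reduction step.
    return spatial * (1 + RY * RX * RC)


def guard_executions_hoisted():
    """Times the guard would be reached if it sat just above the ff.inner loop
    (an assumed placement, not something TVM produced here)."""
    return YY_OUTER * XX_OUTER * YY_INNER * XX_INNER


print(guard_executions_innermost())  # 1433600
print(guard_executions_hoisted())    # 1600
```

Under these assumptions the guards are reached roughly 900 times more often than a hoisted version would require.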

Hello,

I just found that. Apparently the PR is close to being merged.

I think it won’t help in your case, because the comments in the PR seem to imply it only works for a single if statement, not sequential ones.

Is there a TIR pass to fuse those two if statements?

Thanks for pointing me at this. It looks like the right sort of thing but not yet sufficiently general. I’ll take a closer look at the code.

I’ve found an interesting alternative solution to this problem. There’s a TIR pass called LoopPartition which splits a loop nest into a set of new nests such that the if conditions are trivially satisfied in each one, allowing them to be removed. Initially I thought it wasn’t working properly, but it turns out that, to get it to work as expected, a build config parameter (set through the PassContext) needs to be set:

with tvm.transform.PassContext(config={
    "tir.LoopPartition": {"partition_const_loop": True}
}):
    # Lower inside the context, e.g. the call from test_tile above.
    print(tvm.lower(s, inputs))

By running with this I successfully removed all the ifs, which was sufficient for my use case.
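For anyone unfamiliar with the idea, here is a minimal plain-Python sketch of what loop partitioning achieves on a single axis. It is an illustration of the concept only, not TVM's implementation; the function names are hypothetical, and the extent 32 and tile size 10 are taken from the example above.

```python
EXTENT = 32  # valid output extent along one axis (from the example)
TILE = 10    # tile size, which does not divide EXTENT


def guarded(extent=EXTENT, tile=TILE):
    """Tiled loop with a bounds check on every iteration."""
    visited = []
    n_outer = (extent + tile - 1) // tile  # ceiling division
    for outer in range(n_outer):
        for inner in range(tile):
            idx = outer * tile + inner
            if idx < extent:  # evaluated on every iteration
                visited.append(idx)
    return visited


def partitioned(extent=EXTENT, tile=TILE):
    """Same iteration space, split into a body and a tail with no checks."""
    visited = []
    full_tiles = extent // tile
    # Body: every index in a full tile is in range, so no check is needed.
    for outer in range(full_tiles):
        for inner in range(tile):
            visited.append(outer * tile + inner)
    # Tail: a shorter loop covers the remainder, again without a check.
    tail = extent - full_tiles * tile
    for inner in range(tail):
        visited.append(full_tiles * tile + inner)
    return visited


assert guarded() == partitioned() == list(range(EXTENT))
```

The trade-off is code size: each partitioned axis duplicates the loop body, which is presumably why partitioning of constant-extent loops has to be enabled explicitly.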


You are right. I had previous experience with this pass but had forgotten about it. Thanks for reminding me.