I’ve noticed that when splitting/tiling an axis by a factor that does not evenly divide its extent, `if` statements get inserted in the innermost loops to handle the bounds checks. Take this example:
def tile_nd(s, tensor, tile):
    """Split the leading axes of ``tensor`` by ``tile`` and group the loops.

    For each position ``i`` in ``tile``, ``tensor.op.axis[i]`` is split by
    ``tile[i]`` on the stage ``s[tensor]``. The stage is then reordered so
    that every outer loop comes before every inner loop.

    Parameters
    ----------
    s
        Schedule-like object; ``s[tensor]`` must yield the stage to transform.
    tensor
        Tensor whose ``op.axis`` entries are split.
    tile
        Split factor for each axis, given in axis order.

    Returns
    -------
    tuple of (list, list)
        The outer loop variables and the inner loop variables, each in
        axis order.
    """
    stage = s[tensor]
    outer_indices = []
    inner_indices = []
    for i, size in enumerate(tile):
        outer, inner = stage.split(tensor.op.axis[i], size)
        outer_indices.append(outer)
        inner_indices.append(inner)
    # Reorder once, after every axis has been split: all outer loops first,
    # followed by all inner loops.
    stage.reorder(*outer_indices, *inner_indices)
    return outer_indices, inner_indices
def small_graph():
    """Build a single NHWC convolution to use as the tiling example.

    Returns
    -------
    tuple of (list, list)
        ``[conv1]``, the graph's tensors, and ``[conv1_ifm, conv1_weights]``,
        the placeholders that feed it.
    """
    conv1_ifm = te.placeholder((1, 34, 34, 3), name='input')
    conv1_weights = te.placeholder((3, 3, 3, 32), name='weights_1')
    # Per topi's conv2d_nhwc signature, the positional arguments after the
    # filter are stride, padding and dilation.
    conv1 = topi.nn.conv2d_nhwc(conv1_ifm, conv1_weights, 1, 0, 1)
    return ([conv1],
            [conv1_ifm, conv1_weights])
def test_tile():
    """Tile the example convolution by (1, 10, 10, 32) and print the lowered IR."""
    graph, inputs = small_graph()
    out = graph[-1]
    s = te.create_schedule([out.op])
    # Schedule the same tensor the schedule was created for; the graph holds
    # a single tensor, so this is the tensor previously spelled graph[0].
    tile_nd(s, out, (1, 10, 10, 32))
    print(tvm.lower(s, inputs))
# Run the reproduction so the lowered IR is printed.
test_tile()
This prints the following loop nest:
for (yy.outer: int32, 0, 4) {
for (xx.outer: int32, 0, 4) {
for (yy.inner: int32, 0, 10) {
for (xx.inner: int32, 0, 10) {
for (ff.inner: int32, 0, 32) {
if (((yy.outer*10) + yy.inner) < 32) {
if (((xx.outer*10) + xx.inner) < 32) {
Conv2dOutput[(((((yy.outer*10240) + (yy.inner*1024)) + (xx.outer*320)) + (xx.inner*32)) + ff.inner)] = 0f32
}
}
for (ry: int32, 0, 3) {
for (rx: int32, 0, 3) {
for (rc: int32, 0, 3) {
if (((yy.outer*10) + yy.inner) < 32) {
if (((xx.outer*10) + xx.inner) < 32) {
Conv2dOutput[(((((yy.outer*10240) + (yy.inner*1024)) + (xx.outer*320)) + (xx.inner*32)) + ff.inner)] = ((float32*)Conv2dOutput[(((((yy.outer*10240) + (yy.inner*1024)) + (xx.outer*320)) + (xx.inner*32)) + ff.inner)] + ((float32*)PaddedInput[(((((((yy.outer*1020) + (yy.inner*102)) + (ry*102)) + (xx.outer*30)) + (xx.inner*3)) + (rx*3)) + rc)]*(float32*)weights_1_2[((((ry*288) + (rx*96)) + (rc*32)) + ff.inner)]))
}
}
}
}
}
}
}
}
}
}
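Note that the `if` statements appear inside the innermost reduction loop even though their conditions depend only on the `yy` and `xx` loop variables (outer and inner). This means they are evaluated far more often than necessary: once per `ff.inner`, `ry`, `rx` and `rc` iteration rather than once per (`yy`, `xx`) position. Is there a way to hoist the `if`s out of the reduction loops, to just inside the `yy.inner`/`xx.inner` loops, to avoid all these redundant checks?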