I am facing a similar problem.
# NOTE(review): assumes `import tvm` / `from tvm import te` appear above — not shown.
A = te.placeholder((64, 64), name="A")  # input tensor
B = te.placeholder((3, 3), name="B")    # 3x3 weight tensor
# Reduction axes for C1.
r0 = te.reduce_axis((0, 3), name="r0")
r1 = te.reduce_axis((0, 3), name="r1")
# C1: 3x3 weighted window reduction (convolution-like) of A with B.
# NOTE(review): A[i + r0, j + r1] reaches index 65 when i == 63, r0 == 2 —
# out of bounds for a (64, 64) placeholder; presumably A should be (66, 66)
# (or the output (62, 62)) — confirm.
C1 = te.compute(
(64, 64),
lambda i, j: te.sum(A[i + r0, j + r1] * B[r0, r1], axis=[r0, r1]),
name="C1",
)
# Fresh reduce axes are rebound for C2 (the pair above belongs to C1).
r0 = te.reduce_axis((0, 3), name="r0")
r1 = te.reduce_axis((0, 3), name="r1")
# C2: plain 3x3 window sum of A over the same window as C1.
C2 = te.compute(
(64, 64),
lambda i, j: te.sum(A[i + r0, j + r1], axis=[r0, r1]),
name="C2",
)
# C3: elementwise combination of the two reductions.
C3 = te.compute(
(64, 64),
lambda i, j: C1[i, j] * 2 + C2[i, j],
name="C3"
)
s = te.create_schedule(C3.op)
# Anchor both producers at C3's innermost spatial axis (j), so each (i, j)
# iteration computes its own scalar C1/C2 values — yields the output below,
# where the two reduction nests are adjacent but still separate.
s[C1].compute_at(s[C3], C3.op.axis[1])
s[C2].compute_at(s[C3], C3.op.axis[1])
print(tvm.lower(s, [A, B, C3], simple_mode=True))
# Attempt to nest C1 inside C2's reduction loop (to fuse the nests) — rejected:
# s[C1].compute_at(s[C2], C2.op.reduce_axis[1])
# print(tvm.lower(s, [A, B, C3], simple_mode=True)) # fail
Out:
for (i: int32, 0, 64) {
for (j: int32, 0, 64) {
C1_1: Buffer(C1, float32, [1], [], align=4)[0] = 0f32
for (r0: int32, 0, 3) {
for (r1: int32, 0, 3) {
C1_1[0] = (C1_1[0] + (A_3: Buffer(A_2, float32, [4096], [])[((((i*64) + (r0*64)) + j) + r1)]*B_3: Buffer(B_2, float32, [9], [])[((r0*3) + r1)]))
}
}
C2_1: Buffer(C2, float32, [1], [], align=4)[0] = 0f32
for (r0_1: int32, 0, 3) {
for (r1_1: int32, 0, 3) {
C2_1[0] = (C2_1[0] + A_3[((((i*64) + (r0_1*64)) + j) + r1_1)])
}
}
C3_3: Buffer(C3_2, float32, [4096], [])[((i*64) + j)] = ((C1_1[0]*2f32) + C2_1[0])
}
}
I want to further optimize the schedule by combining the two reduction loops, like:
for (i: int32, 0, 64) {
for (j: int32, 0, 64) {
C1_1: Buffer(C1, float32, [1], [], align=4)[0] = 0f32
C2_1: Buffer(C2, float32, [1], [], align=4)[0] = 0f32
for (r0: int32, 0, 3) {
for (r1: int32, 0, 3) {
C1_1[0] = (C1_1[0] + (A_3: Buffer(A_2, float32, [4096], [])[((((i*64) + (r0*64)) + j) + r1)]*B_3: Buffer(B_2, float32, [9], [])[((r0*3) + r1)]))
C2_1[0] = (C2_1[0] + A_3[((((i*64) + (r0*64)) + j) + r1)])
}
}
C3_3: Buffer(C3_2, float32, [4096], [])[((i*64) + j)] = ((C1_1[0]*2f32) + C2_1[0])
}
}
How may I approach this?