# [SOLVED] Help! How to use fuse and split schedule at the same time?

#1

Hi all,
Here is my simple test code:

``````
import tvm

def test_relu():
    n = 128
    factor = 64
    A = tvm.placeholder((n, 2), name='A')
    # Element-wise ReLU: C = max(A, 0)
    C = tvm.compute(A.shape, lambda *i: tvm.max(A(*i), tvm.const(0, A.dtype)), name='C')
    s = tvm.create_schedule(C.op)

    print("1st source code")
    print(tvm.lower(s, [A, C], simple_mode=True))

    # Fuse both axes of C into a single axis of extent 128 * 2 = 256
    fused_axis = s[C].fuse(*list(s[C].op.axis))

    print("2nd source code")
    print(tvm.lower(s, [A, C], simple_mode=True))

    # Split the fused axis into an outer loop (extent 4) and an inner loop (extent 64)
    Co_iter, Ci_iter = s[C].split(fused_axis, factor=factor)

    print("3rd source code")
    print(tvm.lower(s, [A, C], simple_mode=True))

test_relu()
``````

Console output:

``````
1st source code
produce C {
  for (i0, 0, 128) {
    for (i1, 0, 2) {
      C[((i0*2) + i1)] = max(A[((i0*2) + i1)], 0.000000f)
    }
  }
}

2nd source code
produce C {
  for (i0.i1.fused, 0, 256) {
    C[i0.i1.fused] = max(A[i0.i1.fused], 0.000000f)
  }
}

3rd source code
produce C {
  for (i0.i1.fused.outer, 0, 4) {
    for (i0.i1.fused.inner, 0, 64) {
      C[((((i0.i1.fused.outer*32) + (i0.i1.fused.inner/2))*2) + (i0.i1.fused.inner % 2))] = max(A[((((i0.i1.fused.outer*32) + (i0.i1.fused.inner/2))*2) + (i0.i1.fused.inner % 2))], 0.000000f)
    }
  }
}

``````

I expected to get the following:

``````
.
.
.

3rd source code
produce C {
  for (i0.i1.fused.outer, 0, 4) {
    for (i0.i1.fused.inner, 0, 64) {
      C[((((i0.i1.fused.outer*64) + (i0.i1.fused.inner))))] = max(A[((((i0.i1.fused.outer*64) + (i0.i1.fused.inner))))], 0.000000f)
    }
  }
}

``````

I'm confused about why TVM produces the more complicated index expression `((((i0.i1.fused.outer*32) + (i0.i1.fused.inner/2))*2) + (i0.i1.fused.inner % 2))` instead.

Thanks!

#2

I think the expression simplification happens in two passes: Simplify and CanonicalSimplify. You can read the code for those passes to figure out the answer to your question.
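
For what it's worth, the two index expressions from post #1 compute the same offsets, so the lowered code looks correct and is just not fully simplified. Below is a small plain-Python check (no TVM needed). The function names are made up for illustration, and I'm assuming the `/` in the printed IR is integer division, which holds here since the loop variables are non-negative. The longer form looks like what you get when the fused index `outer*64 + inner` is mapped back to `i0 = fused / 2` and `i1 = fused % 2` and then flattened again as `i0*2 + i1`, but that reading is my interpretation of the output, not something I verified in the TVM source.

```python
def lowered_index(outer, inner):
    # Index expression copied from the "3rd source code" output,
    # with "/" written as Python integer division.
    return ((outer * 32) + (inner // 2)) * 2 + (inner % 2)


def expected_index(outer, inner):
    # Index expression the original poster expected.
    return outer * 64 + inner


# Compare both expressions over the full loop nest: outer in [0, 4), inner in [0, 64).
mismatches = [
    (o, i)
    for o in range(4)
    for i in range(64)
    if lowered_index(o, i) != expected_index(o, i)
]
print(mismatches)  # []
```

The check passes because `2*(inner // 2) + (inner % 2) == inner` for any non-negative integer, so the only open question is whether the simplifier recognizes that pattern.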

#3