[SOLVED] Help! How to use the fuse and split schedule primitives at the same time?

Hi all,
The following is my simple test code:

import tvm  # written against the legacy top-level API (tvm.placeholder, tvm.compute)

def test_relu():
    n = 128
    factor = 64
    A = tvm.placeholder((n, 2), name='A')
    # Element-wise ReLU: C = max(A, 0)
    C = tvm.compute(A.shape, lambda *i: tvm.max(A(*i), tvm.const(0, A.dtype)), name='C')
    s = tvm.create_schedule(C.op)

    print("1st source code")
    print(tvm.lower(s, [A, C], simple_mode=True))

    # Fuse the two axes (128 x 2) into a single axis of extent 256
    fused_axis = s[C].fuse(*list(s[C].op.axis))

    print("2nd source code")
    print(tvm.lower(s, [A, C], simple_mode=True))

    # Split the fused axis into outer (extent 4) and inner (extent 64)
    Co_iter, Ci_iter = s[C].split(fused_axis, factor=factor)

    print("3rd source code")
    print(tvm.lower(s, [A, C], simple_mode=True))

test_relu()

Console output:

1st source code
produce C {
  for (i0, 0, 128) {
    for (i1, 0, 2) {
      C[((i0*2) + i1)] = max(A[((i0*2) + i1)], 0.000000f)
    }
  }
}

2nd source code
produce C {
  for (i0.i1.fused, 0, 256) {
    C[i0.i1.fused] = max(A[i0.i1.fused], 0.000000f)
  }
}

3rd source code
produce C {
  for (i0.i1.fused.outer, 0, 4) {
    for (i0.i1.fused.inner, 0, 64) {
      C[((((i0.i1.fused.outer*32) + (i0.i1.fused.inner/2))*2) + (i0.i1.fused.inner % 2))] = max(A[((((i0.i1.fused.outer*32) + (i0.i1.fused.inner/2))*2) + (i0.i1.fused.inner % 2))], 0.000000f)
    }
  }
}

I expected to get the following for the 3rd output:

3rd source code
produce C {
  for (i0.i1.fused.outer, 0, 4) {
    for (i0.i1.fused.inner, 0, 64) {
      C[((((i0.i1.fused.outer*64) + (i0.i1.fused.inner))))] = max(A[((((i0.i1.fused.outer*64) + (i0.i1.fused.inner))))], 0.000000f)
    }
  }
}

I'm confused about why TVM generates the more complicated index expression (((i0.i1.fused.outer*32) + (i0.i1.fused.inner/2))*2) + (i0.i1.fused.inner % 2) instead.

Thanks!

I think the expression simplification is done in two passes: Simplify and CanonicalSimplify. You can read the code for those passes to figure out the answer to your question.
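For what it's worth, the two index expressions are the same thing, just not fully simplified. Here is a rough plain-Python sketch (it does not call TVM; the function names are made up for illustration) that rebuilds the index from the split and fuse relations and checks it against both the printed and the expected form over the whole iteration space. It assumes the `/` in the lowered output is integer division, which holds here since all loop variables are non-negative.

```python
# Plain-Python model of the index arithmetic; it does not call TVM.
# All function names below are hypothetical helpers for illustration.
N0, N1 = 128, 2      # extents of i0 and i1 from the test code
FACTOR = 64          # split factor from the test code


def lowered_index(outer, inner):
    """Index as printed in the 3rd output ('/' taken as integer division)."""
    return ((outer * 32) + (inner // 2)) * 2 + (inner % 2)


def expected_index(outer, inner):
    """Index the poster expected after full simplification."""
    return outer * FACTOR + inner


def derived_index(outer, inner):
    """Rebuild the index step by step: split, then un-fuse, then flatten."""
    fused = outer * FACTOR + inner      # split: fused = outer*64 + inner
    i0, i1 = fused // N1, fused % N1    # fuse:  i0 = fused/2, i1 = fused%2
    return i0 * N1 + i1                 # flat offset into the (128, 2) buffer


def all_indices_agree():
    """Check all three forms over every (outer, inner) pair."""
    n_outer = (N0 * N1) // FACTOR       # 256 / 64 = 4 outer iterations
    return all(
        lowered_index(o, i) == expected_index(o, i) == derived_index(o, i)
        for o in range(n_outer)
        for i in range(FACTOR)
    )


print(all_indices_agree())
```

So the printed form is what you get by substituting fused = outer*64 + inner into i0 = fused/2 and i1 = fused%2 and only partially simplifying; the generated code is still correct, it is just not reduced to outer*64 + inner.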

This has already been fixed. Please update your code to the latest version.
