How to schedule concatenate by splitting the concatenation axis

I’m working with a classic Wide & Deep network, which contains a concatenation like this:

```python
import tvm
import topi

n = tvm.var("n")
A = tvm.placeholder((n, 5), name='A')
B = tvm.placeholder((n, 2), name='B')
# axis defaults to 0 here; joining the 5 and 2 feature columns needs axis=1
result1 = topi.concatenate([A, B])

s = tvm.create_schedule(result1.op)
print(tvm.lower(s, [A, B, result1], simple_mode=True))
```

The lowered IR with the default schedule is:

```
produce compute {
  for (i0, 0, (n*2)) {
    for (i1, 0, 5) {
      compute[((i0*5) + i1)] = tvm_if_then_else((i0 < n), A[((i0*5) + i1)], B[(((i0*2) + i1) - (n*2))])
    }
  }
}
```
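The IR above is a single loop nest in which every output element evaluates a `tvm_if_then_else`. To make that pattern concrete for the column-wise (axis 1) case the question is about, here is a plain-Python model over flat row-major lists. It is only an illustration of the loop structure, not TVM code, and the helper name `concat_axis1_fused` is hypothetical.

```python
def concat_axis1_fused(a, b, n, wa, wb):
    """Concatenate row-major buffers a (n x wa) and b (n x wb) along axis 1
    with one loop nest and a per-element select, like tvm_if_then_else."""
    w = wa + wb
    out = [0] * (n * w)
    for i0 in range(n):
        for i1 in range(w):
            # The condition is checked for every single output element.
            out[i0 * w + i1] = a[i0 * wa + i1] if i1 < wa else b[i0 * wb + i1 - wa]
    return out


n = 3
a = list(range(n * 5))                # A: n x 5, values 0..14
b = [100 + v for v in range(n * 2)]   # B: n x 2, values 100..105
print(concat_axis1_fused(a, b, n, 5, 2))
# → [0, 1, 2, 3, 4, 100, 101, 5, 6, 7, 8, 9, 102, 103, 10, 11, 12, 13, 14, 104, 105]
```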

I would like the schedule to produce something like this instead:

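```
produce compute {
  for (i0, 0, n) {
    for (i1, 0, 5) {
      // copy A: compute[((i0*7) + i1)] = A[((i0*5) + i1)]
    }
    for (i1, 0, 2) {
      // copy B: compute[(((i0*7) + i1) + 5)] = B[((i0*2) + i1)]
    }
  }
}
```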
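The target loop structure can be modeled the same way: for each row, one branch-free loop copies A's 5 columns and a second branch-free loop copies B's 2 columns into the offset range. This is a plain-Python sketch of the desired loop nest only, assuming flat row-major buffers; the name `concat_axis1_split` is hypothetical, and the sketch does not show how to make TVM emit this schedule.

```python
def concat_axis1_split(a, b, n, wa, wb):
    """Concatenate row-major buffers a (n x wa) and b (n x wb) along axis 1,
    with the column loop split at the A/B boundary so no inner loop branches."""
    w = wa + wb
    out = [0] * (n * w)
    for i0 in range(n):
        for i1 in range(wa):
            # copy A into output columns [0, wa)
            out[i0 * w + i1] = a[i0 * wa + i1]
        for i1 in range(wb):
            # copy B into output columns [wa, wa + wb)
            out[i0 * w + wa + i1] = b[i0 * wb + i1]
    return out


n = 3
a = list(range(n * 5))                # A: n x 5
b = [100 + v for v in range(n * 2)]   # B: n x 2
out = concat_axis1_split(a, b, n, 5, 2)
# Cross-check: build each output row by joining the matching rows of A and B.
expected = [x for r in range(n) for x in a[r * 5:(r + 1) * 5] + b[r * 2:(r + 1) * 2]]
assert out == expected
print(out)
```

Both loops have fixed trip counts (5 and 2), so neither needs the per-element condition that the single fused loop carries.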

I want to split axis 1 at the boundary between the two input tensors (A and B), but the `split` API only accepts a `factor` (or `nparts`), not a list of sizes. Is there any way to do that? Thanks!
