Compute_at for multiple outputs

  • A = placeholder((10, 20, 30), name="A")
  • B = min(A, axis=(0, 1, 2), name="B")
  • C = broadcast(B, (10, 20, 30), name="C")
  • D = A + C
  • sch = tvm.create_schedule(D.op)
  • # cache_read / cache_write stages omitted
  • sch[A].compute_at(sch[D], sch[D].op.axis[2])
  • sch[B].compute_at(sch[C], sch[C].op.reduce_axis[1])
  • sch[C].compute_at(sch[D], sch[D].op.axis[2])
  • (this is only a rough outline of my code, sorry)
  • In my case there are 4 stages: A, B, C and D. A is just an input op, B is a reduce op, C is a broadcast op and D is an add op. I make A compute_at D (axis=2), B compute_at C (axis=2, reduce_axis), and C compute_at D (axis=2). But the buffer allocated for A (as read by B) has 10*20*30 elements, not 10 as I expected. How can I make A's buffer allocate only 10 elements?
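One plausible reading of the 10*20*30 allocation: a stage's buffer must cover the union of the regions that all of its consumers read during one iteration of the loop it is attached to. D reads A[i, j, k] elementwise, but B is a full reduction and reads A over every reduce axis, so the combined region spans the whole tensor. The sketch below is a toy Python model of that rule, not TVM's actual bound-inference code. `buffer_elems`, the variable names, and the conservative union (a dimension shrinks to 1 only if every consumer uses the same already-bound loop variable) are all hypothetical simplifications.

```python
from math import prod

# Toy model (hypothetical, NOT TVM's implementation) of how much of a
# producer buffer must be allocated at a compute_at point.
SHAPE = (10, 20, 30)  # shape of A in the question


def buffer_elems(accesses, fixed_vars, shape=SHAPE):
    """Return the number of elements the producer buffer must hold.

    accesses:   one index tuple per consumer, one loop variable per dimension,
                e.g. ("i", "j", "k") for D reading A[i, j, k].
    fixed_vars: set of loop variables already bound at the attach point
                (loops at or outside the compute_at axis).
    Simplified union rule: a dimension shrinks to 1 only if every consumer
    indexes it with the same fixed variable; otherwise the full extent is kept.
    """
    extents = []
    for dim, size in enumerate(shape):
        idx = {acc[dim] for acc in accesses}
        if len(idx) == 1 and idx <= fixed_vars:
            extents.append(1)
        else:
            extents.append(size)
    return prod(extents)


d_read = ("i", "j", "k")     # D = A + C reads A[i, j, k]
b_read = ("r0", "r1", "r2")  # B = min(A) reads A over all reduce axes
inner = {"i", "j", "k"}      # attached at D's innermost axis

print(buffer_elems([d_read], inner))          # only D consumes A: 1
print(buffer_elems([d_read, b_read], inner))  # D and B both consume A: 6000
print(buffer_elems([b_read], {"r0", "r1"}))   # only B, attached inside r0, r1: 30
```

Under this model, the region only shrinks when the stage feeding B is attached inside B's own reduction loops, so that the reduce variables are already bound at the attach point; attaching it at D's axis cannot help while B still needs all of A.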