Cache_write does not replace vars in reducer identity


#1

The IR mutator does not visit the combiner of a Reduce node: https://github.com/dmlc/tvm/blob/3516cbe0049c7e11ee58afbc668acddb1f110ece/src/pass/ir_mutator.cc#L397.
As a result, if the identity value of a comm_reducer depends on an outer axis, cache_write leaves the original variable in the identity, and tvm.build fails with an error like “Not all Vars are passed in api_args: ‘i’ does not appeared in api_args”.

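For example, the following script prints the lowered IR shown further down, and then tvm.build fails:

import tvm

N = 16

data = tvm.placeholder((N,))

def _f(i):
    # The identity element of the reducer depends on the outer axis i.
    min_value = lambda _: i.astype('float32')
    max_reducer = tvm.comm_reducer(lambda x, y: tvm.make._OpMax(x, y), min_value, name='max')
    r = tvm.reduce_axis((0, N))
    return max_reducer(data[r] + i, axis=[r])

c = tvm.compute((N,), _f)
s = tvm.create_schedule(c.op)
# cache_write renames i to i.c in the reduce body, but not in the identity element.
s.cache_write(c, 'local')

print(tvm.lower(s, [data, c], simple_mode=True))
# Fails with: Not all Vars are passed in api_args: 'i' does not appeared in api_args
tvm.build(s, [data, c], 'llvm')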

In the lowered IR, the initialization of compute.local still uses i, while the reduce body correctly uses i.c:

// attr [compute.local] storage_scope = "local"
allocate compute.local[float32 * 16]
produce compute.local {
  for (i.c, 0, 16) {
    compute.local[i.c] = float32(i)
    for (rv, 0, 16) {
      compute.local[i.c] = max(compute.local[i.c], (placeholder[rv] + float32(i.c)))
    }
  }
}
produce compute {
  for (i, 0, 16) {
    compute[i] = compute.local[i]
  }
}
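To illustrate the mechanism described above, here is a toy Python model of a variable-substitution pass over a Reduce node. It is not TVM's IRMutator (the real pass is the C++ code linked above); the classes Var, Cast, Add, Max, Load, Combiner, Reduce and the function substitute are hypothetical names used only for this sketch. With visit_combiner=False, the identity keeps float32(i) while the source gets float32(i.c), which mirrors the lowered IR above; visiting the combiner's identity (and result) as well fixes it.

```python
from dataclasses import dataclass
from typing import Dict, Union


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Cast:
    dtype: str
    value: "Expr"


@dataclass(frozen=True)
class Add:
    a: "Expr"
    b: "Expr"


@dataclass(frozen=True)
class Max:
    a: "Expr"
    b: "Expr"


@dataclass(frozen=True)
class Load:
    buffer: str
    index: "Expr"


@dataclass(frozen=True)
class Combiner:
    lhs: Var          # bound parameter of the combiner
    rhs: Var          # bound parameter of the combiner
    result: "Expr"    # e.g. max(lhs, rhs)
    identity: "Expr"  # may capture outer variables such as i


@dataclass(frozen=True)
class Reduce:
    combiner: Combiner
    source: "Expr"
    axis: Var


Expr = Union[Var, Cast, Add, Max, Load, Reduce]


def substitute(e: Expr, vmap: Dict[Var, Expr], visit_combiner: bool = True) -> Expr:
    """Replace variables according to vmap (hypothetical toy pass)."""
    def rec(x: Expr) -> Expr:
        return substitute(x, vmap, visit_combiner)

    if isinstance(e, Var):
        return vmap.get(e, e)
    if isinstance(e, Cast):
        return Cast(e.dtype, rec(e.value))
    if isinstance(e, (Add, Max)):
        return type(e)(rec(e.a), rec(e.b))
    if isinstance(e, Load):
        return Load(e.buffer, rec(e.index))
    if isinstance(e, Reduce):
        comb = e.combiner
        if visit_combiner:
            # lhs/rhs are the combiner's own parameters; only result and
            # identity can refer to outer variables, so rewrite those.
            comb = Combiner(comb.lhs, comb.rhs, rec(comb.result), rec(comb.identity))
        # With visit_combiner=False this mimics the reported bug:
        # the source is rewritten but the identity is left untouched.
        return Reduce(comb, rec(e.source), e.axis)
    raise TypeError(f"unsupported node: {e!r}")


i, ic, x, y, rv = Var("i"), Var("i.c"), Var("x"), Var("y"), Var("rv")
comb = Combiner(x, y, Max(x, y), Cast("float32", i))
body = Reduce(comb, Add(Load("placeholder", rv), Cast("float32", i)), rv)

buggy = substitute(body, {i: ic}, visit_combiner=False)
fixed = substitute(body, {i: ic})
print(buggy.combiner.identity)  # Cast(dtype='float32', value=Var(name='i'))
print(fixed.combiner.identity)  # Cast(dtype='float32', value=Var(name='i.c'))
```

In both cases the source is rewritten to use i.c; only the pass that also visits the combiner produces a consistent identity.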

cc @tqchen


#2

@vinx13 can you try to fix this problem?


#3

Follow-up discussion in https://github.com/dmlc/tvm/pull/2354