I found some cases where higher-order AD breaks on If or Tuple expressions.
```python
import tvm
import tvm.relay as relay

# Case 1: an If expression whose result feeds a call (log)
x = relay.var("x", shape=(1, 16, 64, 64))
y = relay.var("y", shape=(1, 16, 64, 64))
cond = relay.var("cond", shape=(), dtype='uint1')
net = relay.If(cond, x, y)
net = relay.log(net)
net = relay.ir_pass.infer_type(relay.Function(relay.ir_pass.free_vars(net), net))
# Fails: type inference rejects the If condition in the gradient function
back_func = relay.ir_pass.infer_type(relay.ir_pass.gradient(net, mode='higher_order'))
```
Another case:
```python
import tvm
import tvm.relay as relay

# Case 2: a function that returns a Tuple
x = relay.var("x", shape=(1, 16, 64, 64))
y = relay.var("y", shape=(1, 16, 64, 64))
net = relay.Tuple([x, y])
net = relay.ir_pass.infer_type(relay.Function(relay.ir_pass.free_vars(net), net))
# Also breaks under higher-order AD
back_func = relay.ir_pass.infer_type(relay.ir_pass.gradient(net, mode='higher_order'))
```
The problem is that `ReverseAD` transforms every expression into a (forward, ref_reverse) pair. For call expressions, there is an explicit `TupleGetItem` that extracts the first field of the pair to use as the argument. However, this is not handled in other cases such as `If` or `Tuple`. As a result, the whole pair is used as the condition of the `If`, which causes an error in type inference.
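To make the mechanism concrete, here is a toy pure-Python model. It is not TVM code, and every name in it (`Ref`, `Var`, `Log`, `If`, `reverse_ad`, `grad`, and the `project_cond` switch) is hypothetical. It mimics the (forward, ref_reverse) convention: the `Log` call projects the forward value of its argument (the analogue of `TupleGetItem(x, 0)`), while `If` either projects the condition's forward value (an assumed fix) or uses the whole pair (the reported bug). Because Python treats a non-empty tuple as truthy, the buggy variant silently takes the wrong branch instead of failing type inference as Relay does. Only the `If` case is modeled; the `Tuple` case is not.

```python
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple


class Ref:
    """Mutable cell standing in for the gradient reference in each pair."""
    def __init__(self, value: float = 0.0) -> None:
        self.value = value


@dataclass
class Var:
    name: str


@dataclass
class Log:
    arg: Any


@dataclass
class If:
    cond: Any
    then_branch: Any
    else_branch: Any


Pair = Tuple[Any, Ref]  # (forward, ref_reverse)


def reverse_ad(expr: Any, env: Dict[str, Pair], tape: List[Callable[[], None]],
               project_cond: bool = True) -> Pair:
    """Evaluate expr so that every result is a (forward, ref_reverse) pair."""
    if isinstance(expr, Var):
        return env[expr.name]
    if isinstance(expr, Log):
        x = reverse_ad(expr.arg, env, tape, project_cond)
        fwd = x[0]  # explicit projection, like TupleGetItem(x, 0) for calls
        out = (math.log(fwd), Ref())

        def backprop() -> None:
            # d/dx log(x) = 1/x, accumulated into the argument's reference
            x[1].value += out[1].value / fwd

        tape.append(backprop)
        return out
    if isinstance(expr, If):
        c = reverse_ad(expr.cond, env, tape, project_cond)
        # Fixed: branch on the forward value. Buggy: branch on the whole pair,
        # which Python always treats as truthy.
        cond_value = c[0] if project_cond else c
        branch = expr.then_branch if cond_value else expr.else_branch
        return reverse_ad(branch, env, tape, project_cond)
    raise TypeError(f"unsupported expression: {expr!r}")


def grad(expr: Any, inputs: Dict[str, Any], project_cond: bool = True):
    """Return (output, {name: gradient}) by replaying the tape in reverse."""
    env = {name: (value, Ref()) for name, value in inputs.items()}
    tape: List[Callable[[], None]] = []
    out = reverse_ad(expr, env, tape, project_cond)
    out[1].value = 1.0  # seed d(out)/d(out) = 1
    for step in reversed(tape):
        step()
    return out[0], {name: pair[1].value for name, pair in env.items()}


# log(if cond then x else y), evaluated with cond = False
net = Log(If(Var("cond"), Var("x"), Var("y")))
inputs = {"cond": False, "x": 2.0, "y": 4.0}
print(grad(net, inputs))                      # else-branch: gradient flows to y
print(grad(net, inputs, project_cond=False))  # wrongly takes the then-branch
```

With the projection, the gradient lands on `y` (1/4); without it, the pair is always truthy, so the model differentiates through `x` even though `cond` is false.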