[Relay] Higher order AD broken in some cases

I found some broken cases in higher-order AD when dealing with If or Tuple expressions.

import tvm
import tvm.relay as relay
x = relay.var("x", shape=(1, 16, 64, 64))
y = relay.var("y", shape=(1, 16, 64, 64))
cond = relay.var("cond", shape=(), dtype='uint1')
net = relay.If(cond, x, y)
net = relay.log(net)
net = relay.ir_pass.infer_type(relay.Function(relay.ir_pass.free_vars(net), net))
back_func = relay.ir_pass.infer_type(relay.ir_pass.gradient(net, mode='higher_order'))

Another case:

import tvm
import tvm.relay as relay
x = relay.var("x", shape=(1, 16, 64, 64))
y = relay.var("y", shape=(1, 16, 64, 64))
net = relay.Tuple([x, y])
net = relay.ir_pass.infer_type(relay.Function(relay.ir_pass.free_vars(net), net))
back_func = relay.ir_pass.infer_type(relay.ir_pass.gradient(net, mode='higher_order'))

The problem is that ReverseAD transforms every expression into a (forward, ref_reverse) pair. In call expressions there is an explicit TupleGetItem that extracts the first field of the pair to use as the argument, but this is not done in other cases such as If or Tuple. As a result, the whole pair is used as the condition of the If, which causes an error in type inference.
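As a minimal sketch of the issue (my own illustration, not the actual ReverseAD code; the variables fwd, rev, pair, and cond_pair are hypothetical stand-ins): the transform would have to unwrap the (forward, ref_reverse) pair with TupleGetItem before placing it inside an If, the same way call arguments are already unwrapped.

import tvm.relay as relay

fwd = relay.var("fwd", shape=(1, 16, 64, 64))
rev = relay.var("rev")                      # stands in for the reverse-mode reference
pair = relay.Tuple([fwd, rev])              # what ReverseAD produces for a sub-expression

cond_pair = relay.var("cond_pair")          # the transformed condition is also such a pair
fixed = relay.If(relay.TupleGetItem(cond_pair, 0),   # unwrap before use, as calls already do
                 relay.TupleGetItem(pair, 0),
                 relay.TupleGetItem(pair, 0))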

@MarisaKirisame

@tqchen could you help with this case?

I know how to fix it, and I will in a few days.

Sorry for the late reply, I forgot to check discuss.

It would be great if we could also document the cause and the proposed fix; that way everyone can learn and improve the codebase together.

I agree. However, I am rethinking the AD design to expose more optimization opportunities. I will do it once the new design is settled.


The zeroth case is an error, and I am fixing it right now.
The first case requires a bit more thought: I am thinking of making an extensible AD interface so we can support other kinds of data types, and we can use that to support tuples. In the original design the AD interface takes a whole program from a tuple of tensors to a tensor.
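To illustrate that restriction (my own sketch, reusing the same API as the repros above): gradient expects a function whose parameters and result are plain tensors, so a function returning a Tuple, as in the second repro, falls outside that interface.

import tvm
import tvm.relay as relay

x = relay.var("x", shape=(1, 16, 64, 64))
y = relay.var("y", shape=(1, 16, 64, 64))

# Tensor x Tensor -> Tensor: within the interface the AD pass was designed for
ok = relay.ir_pass.infer_type(relay.Function([x, y], relay.add(x, y)))
ok_grad = relay.ir_pass.infer_type(relay.ir_pass.gradient(ok, mode='higher_order'))
# Tensor x Tensor -> Tuple (as in the second repro above) is outside that interface.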

Here is another broken case:

import tvm
import tvm.relay as relay

acc = relay.Var("acc", relay.TensorType(shape=(1, 16)))
x = relay.Var("x", relay.TensorType(shape=(1, 16)))
w = relay.Var("w", relay.TensorType(shape=(1, 16)))
fn = relay.Function([acc, x], relay.add(acc, relay.multiply(x, w)))
m = relay.module.Module()
prelude = relay.prelude.Prelude(m)

xs = relay.Var("xs", prelude.l(relay.TensorType(shape=(1, 16))))
init = relay.zeros((1, 16), dtype='float32')
F = prelude.foldl(fn, init, xs)
F = relay.Function([w, xs], F)
main_ = relay.GlobalVar('main')
m[main_] = F
print(m[main_])

F = relay.ir_pass.gradient(F, m)
m[main_] = F
print(F)

I'm trying to fold over a list of data. The pass transforms xs: list[Tensor] into a list of (Tensor, ref) pairs, which causes an error in type inference.
In this case, only the gradients of the weights are needed, not those of the input data.

@jroesch Error annotation also seems buggy in this case. It prints "TVMError: Error(s) have occurred. We have annotated the program with them:" but nothing is annotated.

In this case, only the gradients of the weights are needed, not those of the input data.

If the gradients of the input data are not used, dead code elimination should remove them. The DCE pass right now does not respect effects, and I am trying to find time to fix that.
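A sketch of how that would be used, assuming an effect-aware DCE (with the current pass this may not actually prune anything; the pass names are the existing relay.ir_pass ones):

import tvm
import tvm.relay as relay

x = relay.var("x", shape=(1, 16))
w = relay.var("w", shape=(1, 16))
net = relay.ir_pass.infer_type(relay.Function([x, w], relay.multiply(x, w)))

back = relay.ir_pass.gradient(net, mode='higher_order')
# With an effect-aware DCE, unused gradient computations (e.g. for the input data)
# could be pruned at this point.
back = relay.ir_pass.dead_code_elimination(back)
back = relay.ir_pass.infer_type(back)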

I will look at the program when I have time.


@vinx13 it is not broken; the AD algorithm is designed to only take tensor arguments.

We might add support for other datatypes (ADTs/tuples) in the future. I have a rough idea of how to do so.
