Relay Gradients and Pattern Language

Hello,

I have been toying around with the Relay gradient transformation and wondered whether I am doing something wrong, because I get a rather elaborate gradient:

gets transformed into:

I must admit that that is a bit more than I had hoped for…
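For reference, the mechanics of what I am doing are roughly the following (a toy function with made-up names and shapes, not my actual graph; if I understand correctly, the broadcasting in the multiply is what later shows up as collapse_sum_like):

```python
# Toy example only -- ops, names, and shapes are made up for illustration.
import tvm
from tvm import relay

x = relay.var("x", shape=(3, 4), dtype="float32")
y = relay.var("y", shape=(1, 4), dtype="float32")  # broadcasts against x
f = relay.Function([x, y], relay.nn.relu(x * y))

mod = tvm.IRModule.from_expr(f)
mod = relay.transform.InferType()(mod)

# relay.transform.gradient returns a function computing the original
# result together with the gradients w.r.t. the inputs.
grad_f = relay.transform.gradient(mod["main"], mod=mod, mode="higher_order")
print(grad_f)
```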

Now I realize that symbolic differentiation is bound to create very complex graphs, but I can’t help but wonder whether I did something wrong. (And optimization passes might well disconnect and eliminate some of it later.)

I am grateful for any hint you might have.

So staring at this some more: the collapse_sum_like is, of course, the dual of broadcasting (which I should have recognized), and the zeros_like presumably come from “summing up all uses” of a variable (even when there is only one use to start with). But is there a ready-made optimization pass that eliminates the ones I don’t need (including the intermediate lets)? Edit: The pass for the latter is GraphNormalForm.
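For what it’s worth, what I am experimenting with now is roughly the following pass sequence (just a sketch; the pass is spelled ToGraphNormalForm in code, and I am not sure SimplifyExpr and DeadCodeElimination actually catch all of the zeros_like/collapse_sum_like noise):

```python
# Sketch of a cleanup sequence; the pass choice is my guess, not a recommendation.
import tvm
from tvm import relay

seq = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.ToGraphNormalForm(),    # removes the intermediate lets
        relay.transform.SimplifyExpr(),         # simplifies some identities
        relay.transform.FoldConstant(),
        relay.transform.DeadCodeElimination(),  # drops now-unreferenced parts
    ]
)

mod = tvm.IRModule.from_expr(grad_f)  # grad_f from the snippet above
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
print(mod["main"])
```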

Best regards

Thomas