[DISCUSS] pass for merging shape tensors

Hello,

I recently stumbled over the fact that reshape is typically hard for TVM's common subexpression elimination (CSE) pass to work with. This is because the target shape (which also comes in the attrs) can be a distinct, even if value-equal, tensor. In particular, when converting reshape from, say, PyTorch, every shape tensor is a separate constant. Fusing these second inputs to reshape (provided they are constant, on the same device(?), and have the same shape and values) lets CSE eliminate the duplicated reshapes.
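
For a minimal self-contained illustration (my sketch, using value-equal dense constants as stand-ins for the duplicated shape tensors), CSE leaves structurally identical calls alone when their constants are separate nodes:

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(3, 4))
# Two distinct Constant nodes with equal values, mimicking what a
# frontend produces when it emits a fresh constant per reshape.
c1 = relay.const(np.ones((3, 4), dtype="float32"))
c2 = relay.const(np.ones((3, 4), dtype="float32"))
y = relay.add(relay.add(x, c1), relay.add(x, c2))
mod = tvm.IRModule.from_expr(relay.Function([x], y))
# Both adds survive: the constants agree in value but are not the
# same node, so the two subexpressions are not recognized as common.
print(relay.transform.EliminateCommonSubexpr()(mod))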

My main use case is self-attention in transformers.

The pass I came up with in Python looks like this:

import tvm
import tvm.relay


class ShapeConstDedupMutator(tvm.relay.ExprMutator):
    """Replaces equal constant shape arguments of reshape calls with a
    single canonical Constant so that CSE can merge the calls."""

    def __init__(self):
        super().__init__()
        self.shape_consts = {}  # shape tuple -> canonical Constant

    def visit_call(self, call):
        if (isinstance(call.op, tvm.ir.Op) and call.op.name == "reshape"
                and isinstance(call.args[1], tvm.relay.Constant)):
            # The shape tensor input must agree with the newshape attr.
            assert list(call.attrs.newshape) == list(call.args[1].data.asnumpy())
            new_fn = self.visit(call.op)
            new_args = [self.visit(arg) for arg in call.args]
            const = new_args[1]
            # Only 1-d integer tensors qualify as shapes.
            assert const.data.dtype.startswith('int') and len(const.data.shape) == 1
            key = tuple(const.data.asnumpy())
            if key in self.shape_consts:
                # Reuse the canonical constant recorded for this value.
                new_args[1] = self.shape_consts[key]
            else:
                self.shape_consts[key] = new_args[1]
            return tvm.relay.Call(new_fn, new_args, call.attrs)
        return super().visit_call(call)


@tvm.relay.transform.function_pass(opt_level=1)
def ShapeConstDedup(fn, mod, ctx):
    return ShapeConstDedupMutator().visit(fn)

new_mod = ShapeConstDedup(new_mod)
new_mod = tvm.relay.transform.EliminateCommonSubexpr()(new_mod)
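
The two passes can also be chained with a Sequential; this is just a sketch assuming the decorated pass object from above:

seq = tvm.transform.Sequential(
    [ShapeConstDedup, tvm.relay.transform.EliminateCommonSubexpr()]
)
new_mod = seq(new_mod)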

Before I convert this to C++ and submit a PR: would this be of enough general interest to add to the TVM standard passes?

An alternative to doing this as a separate pass would be to adjust the common subexpression elimination logic to allow constant shape inputs with equal values to be merged. (Maybe that is even preferable; I would love to hear your input on it.)
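
Concretely, CSE would need a value-based equality check for constant arguments, roughly like this (consts_value_equal is a hypothetical helper, only sketching the idea in Python; the actual pass lives in C++):

def consts_value_equal(a, b):
    # Fall back to reference identity for anything that isn't a Constant.
    if not (isinstance(a, tvm.relay.Constant) and isinstance(b, tvm.relay.Constant)):
        return a.same_as(b)
    na, nb = a.data.asnumpy(), b.data.asnumpy()
    return na.dtype == nb.dtype and na.shape == nb.shape and bool((na == nb).all())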

One of the reasons I'm not proposing to merge all same-value constants is that, in my experience, things can get touchy in other parts if suddenly all constants "1" are the same object.

Best regards

Thomas

Given that we are having a systematic discussion about the A0/A1 approaches to dynamic shapes, perhaps we can revisit this case once we transition reshape to the new convention, hopefully soon. The current workaround looks good.

Indeed, I’d wait for that.