As discussed here, there are many implementations of the exact same core computation. To avoid writing a ton of patterns to cover all of these implementations, it would be nice to have a way to "always match" certain operators in a pattern, even if they aren't present in the user-defined pattern.
In other words, I propose a way to match the core computation without having to match all of the extraneous data-mutation operators like reshape and transpose.
As a concrete example, I am trying to match a transformer, using the HuggingFace transformer exported to ONNX. From the ONNX perspective, the start of the transformer looks like MatMul -> Add. However, after importing to TVM, the ONNX frontend performs a bunch of data mutation. This accounts for broadcasting and for the fact that TVM does matrix multiplication as (m, k) x (n, k), whereas ONNX does matrix multiplication as (m, k) x (k, n). As a result, the Relay expression becomes Reshape -> Reshape -> Transpose -> MatMul -> Reshape -> Add.
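To see why the frontend inserts a Transpose, here is a toy illustration in plain Python (no TVM or ONNX required; the function names are illustrative, not either library's API). ONNX-style MatMul multiplies (m, k) x (k, n), while TVM-style dense takes the second operand as (n, k) and contracts along its second axis, so the frontend must transpose the weight to preserve the result:

```python
def matmul(a, b):
    """ONNX-style matmul: a is (m, k), b is (k, n)."""
    m, k, n = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]

def dense(a, w):
    """TVM-style dense: a is (m, k), w is (n, k); computes a @ w^T."""
    k, n = len(a[0]), len(w)
    return [[sum(row[p] * w[j][p] for p in range(k)) for j in range(n)]
            for row in a]

def transpose(b):
    return [list(col) for col in zip(*b)]

a = [[1, 2], [3, 4]]          # (2, 2)
b = [[5, 6, 7], [8, 9, 10]]   # (2, 3)

# The frontend rewrites MatMul(a, b) as dense(a, transpose(b)):
assert matmul(a, b) == dense(a, transpose(b))
```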
Different frontends may handle the reshapes and transposes differently, but they will all perform the core computation MatMul -> Add. To avoid writing a separate pattern for each variant, I would like an option to treat these reshape and transpose operators as always matching and simply skip over them. The core MatMul -> Add pattern would then always match, and any skipped operators in between would be merged into the composite function.
The MergeComposite pass will take an extra parameter: a list of Relay ops to "always match". The ExtractPattern function will continue moving through the pattern "in lockstep" with the root. When the root and pattern nodes differ and the root's op is in the "always match" list, the root node will advance while the pattern node stays in place. ExtractPattern will return an empty expression only when the root and pattern nodes differ and the root node's op is not in the "always match" list.
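The lockstep-with-skip behavior described above can be sketched in a few lines of plain Python. This is a simplified model, not TVM's implementation: it walks flat lists of op names rather than real Relay expressions, and the names extract_pattern / ALWAYS_MATCH are illustrative only.

```python
ALWAYS_MATCH = {"reshape", "transpose"}

def extract_pattern(graph, pattern, always_match=ALWAYS_MATCH):
    """Walk graph and pattern in lockstep.

    When the current graph op differs from the current pattern op but
    is in `always_match`, only the graph cursor advances; the pattern
    cursor stays in place. Returns the matched span of graph ops
    (including the skipped ops, which would be merged into the
    composite function), or None if the match fails.
    """
    matched = []
    p = 0
    for op in graph:
        if p < len(pattern) and op == pattern[p]:
            matched.append(op)   # ops agree: both cursors advance
            p += 1
        elif op in always_match:
            matched.append(op)   # skip op, merged into the match
        else:
            return None          # mismatch outside the skip list
    return matched if p == len(pattern) else None

# The transformer example: the core MatMul -> Add pattern matches the
# frontend-mutated chain, with the reshapes and transpose merged in.
chain = ["reshape", "reshape", "transpose", "matmul", "reshape", "add"]
print(extract_pattern(chain, ["matmul", "add"]))
# -> ['reshape', 'reshape', 'transpose', 'matmul', 'reshape', 'add']
```

An op that is neither part of the pattern nor in the skip list (say, an exp between the MatMul and the Add) still causes the match to fail, so the skip list only loosens matching for the operators the user explicitly opts in.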
Please comment with other ideas and implementation suggestions.