[RFC] Wildcard Pattern Matching in MergeComposite Pass

Problem

As discussed here, there are many implementations for the exact same core computation. To avoid writing a ton of patterns to satisfy all of these implementations, it would be nice to have a way to “always match” certain operators in a pattern, even if they aren’t present in the user-defined pattern.

In other words, I propose a way to match the core computation, without having to match all of the extraneous data-mutation operators like reshape, transpose, cast, etc…

Example Use-Case

I am trying to match a transformer and am using a HuggingFace Transformer model exported to ONNX. The start of the transformer, from the ONNX perspective, looks like MatMul -> Add. However, after importing to TVM, the ONNX frontend does a bunch of data mutation. This is to account for broadcasting and the fact that TVM does matrix multiplication as (m,k) x (n,k), whereas ONNX does matrix multiplication as (m,k) x (k,n). This means that the Relay expression becomes Reshape -> Reshape -> Transpose -> MatMul -> Reshape -> Add.

Different frontends may handle the reshape / transposes differently, but they will all do the core computation of MatMul -> Add. To avoid writing a ton of patterns, I would like an option to always consider these reshape and transpose operators as a match and just skip over them. In this case, the core MatMul -> Add pattern will always match, and any operators in between will be merged into the composite function.
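For reference, a minimal sketch of what the core pattern might look like as a Relay expression for MergeComposite (the variable names and the pattern-table entry name are made up for illustration; the pattern-table shape follows the existing (name, pattern) convention):

from tvm import relay

def make_matmul_add_pattern():
    # Core computation only: batch_matmul followed by add.
    # The reshapes/transposes inserted by the frontend are deliberately
    # absent; the proposal is to skip over them during matching.
    a = relay.var("a")
    b = relay.var("b")
    bias = relay.var("bias")
    out = relay.nn.batch_matmul(a, b)
    return relay.add(out, bias)

# Hypothetical pattern-table entry; the composite name is arbitrary.
pattern_table = [("transformer.matmul_add", make_matmul_add_pattern())]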

Solution

I am only just starting to look at the implementation of MergeComposite, so I would look to @mbaret and @comaniac for suggestions. However, one solution could be the following:

The MergeComposite pass will take an extra parameter: a list of Relay ops to “always match”. The ExtractPattern function will continue moving the pattern “in lockstep” with the root. When the root and pattern differ, and the root’s op is in the “always match” list, the root node will move, while the pattern node will stay in place. The ExtractPattern function will only return an empty expression if the root and pattern nodes differ and the root node’s op is not in the “always match” list.
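To make this concrete, here is a rough Python-flavoured sketch of the proposed matching rule. This is not the actual ExtractPattern implementation (which works on Relay expressions rather than returning a bool, and handles more node kinds); is_call and is_var are hypothetical helpers, and the always_match parameter name is just a placeholder:

def extract_pattern(pattern, root, always_match):
    if is_var(pattern):
        # A pattern input: anything in the root can bind here.
        return True
    if is_call(pattern) and is_call(root) and pattern.op == root.op:
        # Ops agree: advance pattern and root together, as today.
        return all(extract_pattern(p, r, always_match)
                   for p, r in zip(pattern.args, root.args))
    if is_call(root) and root.op.name in always_match:
        # Ops differ, but the root op is in the "always match" list
        # (e.g. reshape, transpose): advance only the root node and
        # keep the pattern node in place.
        return extract_pattern(pattern, root.args[0], always_match)
    # Ops differ and the root op is not skippable: no match.
    return False

With something like this in place, the pass could be invoked as, say, MergeComposite(pattern_table, ["reshape", "transpose"]), though the exact parameter name and shape are open for discussion.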

Please comment with other ideas and implementation suggestions :slight_smile:

The first thing that occurs to me is that all the nodes before the MatMul shouldn’t matter, so I think we’re only interested in what happens between the MatMul and the Add (in this case a reshape). The thing I find a bit surprising is that a reshape is inserted between these two. Could you print out a Relay snippet to illustrate this case?

The reason I mention this is because I’m not sure it’s valid in the general case to ignore all reshapes/transposes. The accelerated operation you map to will presumably have a behaviour that expects a particular intermediate format between the MatMul and the Add. I don’t think it will be accurate to map to it no matter what layout transforms happen.

Thanks for the quick response, Matt. Here’s a Relay example:

// start of transformer layer 1
  %15 = reshape(%14, newshape=[-1, 128, 768]) /* ty=Tensor[(1, 128, 768), float32] */;
  %16 = reshape(%v1617, newshape=[-1, 768, 768]) /* ty=Tensor[(1, 768, 768), float32] */;
  %17 = transpose(%16, axes=[0, 2, 1]) /* ty=Tensor[(1, 768, 768), float32] */;
  %18 = nn.batch_matmul(%15, %17) /* ty=Tensor[(1, 128, 768), float32] */;
  %19 = reshape(%18, newshape=[1, 128, 768]) /* ty=Tensor[(1, 128, 768), float32] */;
  %20 = add(%19, %bert.encoder.layer.0.attention.self.query.bias) /* ty=Tensor[(1, 128, 768), float32] */;

This comes from this code in the MatMul op converter in the ONNX frontend:

if len(a_shape) > 2:
    b_shape = infer_shape(inputs[1])
    # Convert a and b into 3 dimensional tensors.
    a = _op.reshape(inputs[0], [-1, a_shape[-2], a_shape[-1]])
    b = _op.reshape(inputs[1], [-1, b_shape[-2], b_shape[-1]])
    # Broadcast b to match batch size of a
    new_b_shape = list(infer_shape(b))
    new_a_shape = infer_shape(a)
    if new_a_shape[0] > new_b_shape[0]:
        new_b_shape[0] = new_a_shape[0]
        b = _op.broadcast_to(b, new_b_shape)
    # Transpose matrix dimensions of b.
    b = _op.transpose(b, [0, 2, 1])
    # Perform a batch matmul.
    output = _op.nn.batch_matmul(a, b)
    # Reshape output to original dimensions.
    return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])

These transformations are required for the Relay model to be valid, but they don't really matter for understanding the computation.

Let me know if you want me to attach the .onnx file and the pattern I am working with!

I’ve just taken a quick look at the tensorflow frontend, and it appears to lower this to Relay in a similar way (with an output reshape). The mxnet frontend does create a matmul without a following reshape, but it appears to be implementing a different operator. I’m led to think that the reshape after the matmul is a fundamental part of the computation (i.e. if you removed it, the output would be different).

Can I ask what you’re mapping this pattern to? And additionally, would that mapping still be valid for an arbitrary series of reshapes/transposes between the matmul and the add, or does it imply a particular intermediate tensor layout?

If you take a look at the original BERT model in ONNX, there is no reshape. The reshape is added by Relay to ensure consistency.

Specifically, this pattern is the first part of a single layer of a transformer. I would like to use a pattern to plug in something like Nvidia’s FasterTransformer.

I feel that we should support the “always match” functionality and leave it up to the user to decide when these data-transformation operators are necessary. For example, in hand-written, optimized code, you don’t actually perform a reshape. Rather, you just use some pointer arithmetic to index the array where you want. The point is that the computation is the same.
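As a small NumPy illustration of that point (not TVM code): a reshape typically changes only how the same buffer is indexed, not the values being computed on.

import numpy as np

x = np.arange(1 * 128 * 768, dtype=np.float32)
y = x.reshape(1, 128, 768)       # a view in this case: no data is copied
print(np.shares_memory(x, y))    # True: same underlying buffer
print(x[200] == y[0, 0, 200])    # True: same element, new indexing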

Plus, the pattern for the transformer is so long and specific that I doubt we will accidentally match something that is not really a transformer.

@mbaret @comaniac I also found a bug during my transformer pattern testing. I found that if the same call node showed up multiple times in a pattern, it was getting newly processed every time, resulting in the composite function having duplicate nodes. A fix is here: https://github.com/apache/incubator-tvm/pull/4879

I might need a bit of help understanding this. For instance, a pattern I match is Conv2d -> Bias_Add -> Requantize because that maps to a single call to an ACL (Arm Compute Library) convolution operator. Now in theory I could throw a reshape in there so it’s Conv2d -> Reshape -> Bias_Add -> Requantize, but this is a fundamentally different computation and can’t just be cut and replaced by the ACL convolution operator. It’s true that all the information is still there and I could just do a different traversal, but ACL doesn’t know that.

The pattern you want to match is MatMul -> Add, but with both of those being the ONNX versions of those operators (not Relay). The semantics of the ONNX MatMul operator are represented in Relay with nn.batch_matmul plus some transforms; if those transforms weren’t there, it wouldn’t perform the correct operation.

If a different frontend implements MatMul without these transforms or in a different way, then fundamentally it’s not executing an equivalent operation to ONNX. This might be perfectly valid; it’s probably just the case that MatMul can have different behaviour in different frameworks. But I think that also makes it valid to pattern match differently for different frameworks. If you want to match patterns with different functional behaviour, those patterns should themselves be different.

It might be helpful if you can link the ONNX model so I can get a slightly higher level understanding of this problem, because it sounds like the pattern you’re matching is actually quite a bit larger than simply this MatMul -> Add block. And apologies if I’m completely missing the point here :slight_smile: