[Discussion] New IR pass proposal: CombineParallelDense

Currently, there is an IR pass to combine parallel Conv2D ops. This is a kind of fusion strategy that reduces the number of kernels launched, improves data locality for those kernels, and reduces latency.

Models like BERT start each layer with three parallel branches that all use the same input and have the same sequence of operations. These branches start with a matmul op. By fusing these matmuls, we can go from three sequential matrix multiplications of size (128,768) x (768,768) to one batch matrix multiplication of size (3,128,768) x (3,768,768). On GPU and multi-core CPU, this should provide a significant speedup.
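
For concreteness, here is a minimal NumPy sketch of that equivalence (shapes taken from the BERT example above; the random data and tolerance are just for illustration):

```python
import numpy as np

x = np.random.randn(128, 768).astype("float32")                  # shared input
weights = [np.random.randn(768, 768).astype("float32") for _ in range(3)]

# Three sequential matmuls, one per parallel branch.
seq_out = [x @ w for w in weights]                               # each (128, 768)

# One batched matmul over stacked operands.
batched_x = np.broadcast_to(x, (3, 128, 768))                    # (3, 128, 768)
batched_w = np.stack(weights, axis=0)                            # (3, 768, 768)
batched_out = batched_x @ batched_w                              # (3, 128, 768)

for i in range(3):
    assert np.allclose(seq_out[i], batched_out[i], atol=1e-3)
```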

I propose to make an IR pass called CombineParallelDense, which will combine parallel Dense ops into one BatchMatMul op. This will optionally be followed by a single batched Add op if the “units” parameter of Dense is not null. The combining can be done in a very similar way to CombineParallelConv2D, and can even fuse the element-wise operations that follow.
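
To make the intended rewrite concrete, here is a rough sketch in the Relay Python API (not the actual pass implementation; note that nn.dense(x, w) computes x·wᵀ and nn.batch_matmul(a, b) computes a·bᵀ per batch, so the stacked weights keep the same layout):

```python
import tvm
from tvm import relay

# Before: three parallel nn.dense ops sharing the same input.
x  = relay.var("x", shape=(128, 768), dtype="float32")
w0 = relay.var("w0", shape=(768, 768), dtype="float32")
w1 = relay.var("w1", shape=(768, 768), dtype="float32")
w2 = relay.var("w2", shape=(768, 768), dtype="float32")
branches = [relay.nn.dense(x, w) for w in (w0, w1, w2)]    # each (128, 768)

# After: one batch_matmul over stacked operands, then split back into branches.
stacked_x = relay.stack([x, x, x], axis=0)                 # (3, 128, 768)
stacked_w = relay.stack([w0, w1, w2], axis=0)              # (3, 768, 768)
combined = relay.nn.batch_matmul(stacked_x, stacked_w)     # (3, 128, 768)
outs = relay.split(combined, indices_or_sections=3, axis=0)
# Each outs[i] is (1, 128, 768); a squeeze on axis 0 restores the original shape.
```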

What do you think? How do you think the implementation should look so we don’t have to copy code between the two “combine” passes?

Here are two of the parallel branches that use the same input (the third branch is very far to the right in Netron, so it’s not shown here :slight_smile: )

@tqchen @jroesch do you have any thoughts?

perhaps @vinx13 can comment on this :slight_smile:

Also @haichen if you are interested

I like this idea; it can be helpful. We can probably reuse part of the CombineParallelConv2D code to combine the ops that follow.

CombineParallelConv2D doesn’t always bring a speedup, so maybe we should make this pass optional

Do you mean optional as in being part of optimization level 4?

Reusing the code should be doable. The tricky parts are:

  1. CombineParallelConv2D concatenates the inputs and then splits the output. CombineParallelDense will stack the inputs and then slice the output (see the sketch after this list).
  2. CombineParallelConv2D has some extra logic based on the size of the channel dimension. This doesn’t apply to CombineParallelDense, because the shapes of the matrix multiplications need to match exactly.
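
Here’s that contrast in NumPy terms (shapes are only illustrative):

```python
import numpy as np

# Conv2D case: branch outputs are concatenated along the channel axis and
# split back along that same axis, so channel counts may differ per branch.
out_a = np.zeros((1, 32, 56, 56), dtype="float32")    # NCHW, 32 channels
out_b = np.zeros((1, 64, 56, 56), dtype="float32")    # NCHW, 64 channels
combined = np.concatenate([out_a, out_b], axis=1)     # (1, 96, 56, 56)
split_a, split_b = np.split(combined, [32], axis=1)   # back to 32 / 64 channels

# Dense case: branch outputs are stacked along a new leading axis and sliced
# back out, so every branch must have exactly the same shape.
outs = [np.zeros((128, 768), dtype="float32") for _ in range(3)]
stacked = np.stack(outs, axis=0)                      # (3, 128, 768)
slices = [stacked[i] for i in range(3)]               # each (128, 768)
```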

Maybe we can make an abstract CombineParallelOp class and have its methods take an optional argument map, so that implementations like CombineParallelConv2D can accept the channel size as an argument.
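
Something like the following, purely hypothetical names just to illustrate the shape of the refactor (the real passes live in C++, so this is only a sketch of the class structure):

```python
from abc import ABC, abstractmethod


class CombineParallelOp(ABC):
    """Shared driver: find groups of parallel ops fed by the same input,
    combine the groups that qualify, and rewrite their consumers."""

    def run(self, func, combine_args=None):
        combine_args = combine_args or {}
        for group in self.find_parallel_groups(func):
            if self.can_combine(group, combine_args):
                func = self.combine(func, group, combine_args)
        return func

    def find_parallel_groups(self, func):
        # Graph matching elided: return groups of sibling ops that share
        # the same input and the same op type.
        raise NotImplementedError

    @abstractmethod
    def can_combine(self, group, combine_args):
        """Op-specific check; combine_args is the optional argument map."""

    @abstractmethod
    def combine(self, func, group, combine_args):
        """Op-specific rewrite of one group."""


class CombineParallelConv2D(CombineParallelOp):
    def can_combine(self, group, combine_args):
        # Conv2D-specific logic can read its thresholds from the argument map.
        return len(group) >= combine_args.get("min_num_branches", 3)

    def combine(self, func, group, combine_args):
        ...  # concatenate inputs along the channel axis, split the output


class CombineParallelDense(CombineParallelOp):
    def can_combine(self, group, combine_args):
        # Dense branches must have exactly matching matmul shapes.
        ...

    def combine(self, func, group, combine_args):
        ...  # stack inputs into a batch, batch_matmul, slice the output
```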

Yes, we can put it in level 4 and invoke it based on profiling results (like AutoTVM).
It would be great if we could refactor the CombineParallelConv2D pass and extract the parts common to the two passes.
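
For reference, gating the pass by opt level could look roughly like this (a sketch; the API shown is the present-day PassContext interface, and the registered pass name string is an assumption):

```python
import tvm
from tvm import relay

# A trivial module standing in for the real model.
x = relay.var("x", shape=(128, 768), dtype="float32")
w = relay.var("w", shape=(768, 768), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x, w], relay.nn.dense(x, w)))

# opt_level=4 turns the optional passes on; a profile-guided driver could
# instead list the pass in disabled_pass when it does not help.
with tvm.transform.PassContext(opt_level=4, disabled_pass=["CombineParallelDense"]):
    lib = relay.build(mod, target="llvm")
```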

Cool, I can start working on that :grinning:

You mentioned using AutoTVM to decide whether or not to run this pass. Is there an example I can go off of? I thought that CombineParallelConv2D was always invoked if you choose opt level 4.

Also, how do you think I should handle combining the element-wise ops at the end? For example, normally, the output of shape (128,768) would be added to a bias tensor of shape (768), and this 1D tensor would be broadcast. However, when the output tensor has shape (3,128,768), I don’t think we can properly broadcast-add a tensor of shape (3,768).
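
A quick NumPy check of that mismatch (broadcasting aligns trailing dimensions, so 768 lines up but 128 vs 3 does not):

```python
import numpy as np

out = np.zeros((3, 128, 768), dtype="float32")
bias = np.zeros((3, 768), dtype="float32")

try:
    out + bias
except ValueError as err:
    print(err)  # operands could not be broadcast together with shapes (3,128,768) (3,768)
```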

Unfortunately there are no examples right now. For the broadcast ops, we can pad a ‘1’ onto the broadcast dimension so that the lhs and rhs have the same number of dimensions, meaning that (768,) will be reshaped to (1, 768).

Sorry, would you be able to expand on that last part? I don’t quite understand. We basically have three different bias adds, one for each inner matrix in (3,128,768).

Do you mean the bias adds will be part of a new stacked tensor of shape (3,1,768)?

When combining the (768,) bias tensors, we first canonicalize the broadcast shapes to (1, 768) and then stack them into (3, 1, 768)
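
In NumPy terms (just to check the shapes; the data is random):

```python
import numpy as np

out = np.random.randn(3, 128, 768).astype("float32")      # combined matmul output
biases = [np.random.randn(768).astype("float32") for _ in range(3)]

# Canonicalize each (768,) bias to (1, 768), stack into (3, 1, 768),
# and let ordinary broadcasting handle the rest.
stacked_bias = np.stack([b.reshape(1, 768) for b in biases], axis=0)  # (3, 1, 768)
combined = out + stacked_bias                                         # (3, 128, 768)

# Matches the three separate bias adds.
for i in range(3):
    assert np.allclose(combined[i], out[i] + biases[i])
```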

I can understand the GPU part, but could you explain a bit why it helps so much in the multi-core CPU case? Thanks.

I believe that using a library like MKL’s batch matrix multiplication can be more efficient than running multiple matrix multiplications in sequence. This depends on the problem size too.

Yeah, exactly. If the matrices are already large, the benefit won’t be significant.

That’s one of the benefits of the pass system though :slight_smile: we can just turn this pass off if it doesn’t provide gains.

Yeah, I totally agree that this is useful.