Optimizing matrix multiplication for GPU

The following code, scheduled with s = topi.cuda.batch_matmul.schedule_batch_matmul(R), is 280x slower than PyTorch:

    import tvm
    import topi
    from tvm import te

    bsz = te.var('bsz')
    d1 = te.var('d1')
    d2 = te.var('d2')
    d3 = te.var('d3')
    A = te.placeholder((bsz, d1, d3), name='A', dtype='float32')  # first tensor
    B = te.placeholder((bsz, d2, d3), name='B', dtype='float32')  # second tensor
    R = topi.nn.batch_matmul(A, B)
    with tvm.target.cuda():
        s = topi.cuda.batch_matmul.schedule_batch_matmul(R)

Most of the ops so far are specialized for constant dimensions, so in this case we will need to put in constant dimensions (instead of te.var). Because batch_matmul itself is compute-intensive, we might also benefit from AutoTVM instead of using the default schedule.

Note that specialization is necessary for this type of kernel; this is also what cuDNN does (which PyTorch calls under the hood). The easiest way to make use of such kernels is to generate a few candidates and fall back to a normal kernel when the shape does not match. There are also ongoing efforts in the community on automatic bucketing, etc.
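
For illustration, that dispatch-plus-fallback idea could look roughly like the sketch below; the specialized kernels and the generic fallback are hypothetical placeholders, not anything TVM provides out of the box.

    # Minimal sketch of shape-based dispatch with a fallback (hypothetical names).
    # `specialized` maps the shapes we pre-compiled to their fast TVM functions;
    # `generic_matmul` is a kernel built with variable dimensions that works for
    # any shape, at lower performance.
    def dispatch_batch_matmul(A, B, specialized, generic_matmul):
        key = (A.shape[0], A.shape[1], B.shape[1])  # (bsz, d1, d2)
        kernel = specialized.get(key, generic_matmul)
        return kernel(A, B)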

Thanks @tqchen, this was helpful. Using constants as follows:

    bsz = 24
    d1 = 16384
    d2 = 512
    d3 = 64
    A = te.placeholder((bsz, d1, d3), name='A', dtype='float32')  # first tensor
    B = te.placeholder((bsz, d2, d3), name='B', dtype='float32')  # second tensor
    R = topi.nn.batch_matmul(A, B)
    s = topi.cuda.batch_matmul.schedule_batch_matmul(R)

makes it 5x faster, which is still 3x slower than PyTorch, but much better. So,

  1. How do I improve this further? Use AutoTVM? I have tried it before for custom schedules that I implemented, but I am not sure how to do that for schedule_batch_matmul.

  2. More importantly, in my application, d2 and d3 are constants, but bsz and d1 are not; bsz has 16 possible values and d1 has 32 possible values, making it a total of 512 configurations. Does that mean I need to compile 512 different versions of the code? There must be a better solution. Maybe the schedule can be updated to support a range of values instead of constants?

Others might be able to jump in to comment about running AutoTVM, which will certainly help.

The easiest way to make use of such kernels is to generate a few candidates and fall back to a normal kernel when the shape does not match. There are also ongoing efforts in the community on automatic bucketing, etc.

The above comment (about generating a few key ones and falling back to a normal kernel) is a quick workaround for the constant constraints. I agree that we will need better solutions for dynamic workloads; there are some ongoing efforts, see the dynamic workload presentations at https://tvmconf.org/#about-tvmconf

generating a few key ones and falling back to a normal kernel

The 16 values for bsz and the 32 values for d1 are the key ones; e.g., d1 ranges from 1 to 16384, but I can pad it to multiples of 512.
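
As an illustration of that bucketing, here is a small PyTorch sketch of rounding d1 up to the next multiple of 512 and zero-padding the input (the bucket size and the slicing afterwards are the only assumptions here):

    import torch
    import torch.nn.functional as F

    def pad_d1(A, multiple=512):
        """Round d1 up to the next multiple of `multiple` and zero-pad the tensor."""
        bsz, d1, d3 = A.shape
        d1_padded = -(-d1 // multiple) * multiple  # ceiling division
        if d1_padded == d1:
            return A
        return F.pad(A, (0, 0, 0, d1_padded - d1))  # pad only the d1 dimension

    # After running the padded batch_matmul, slice the result back to the
    # original d1 rows: R = R_padded[:, :d1, :]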

@comaniac, do you know about this? How would you optimize topi.cuda.batch_matmul.schedule_batch_matmul(R) using AutoTVM?

I just checked it out and unfortunately the batch_matmul schedule for CUDA does not have a tuning space (https://github.com/apache/incubator-tvm/blob/master/topi/python/topi/cuda/batch_matmul.py). A schedule with a tuning space should have a set of knobs defined in the schedule, such as this one. You are very welcome to add a tuning space to topi.cuda.batch_matmul.schedule_batch_matmul.
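
For reference, such a tuning space is usually a handful of cfg.define_split / cfg.define_knob calls whose results replace the hard-coded tile sizes. Below is a rough sketch of the pattern, modeled on other AutoTVM CUDA templates rather than the actual batch_matmul schedule (the knob names and split structure are assumptions):

    # Sketch only: `cfg` is the AutoTVM config object handed to a TOPI schedule
    # (e.g. via autotvm.register_topi_schedule); C is the batch_matmul output.
    def _schedule_with_knobs(cfg, s, C):
        b, y, x = s[C].op.axis        # batch, rows, cols of the output
        (k,) = s[C].op.reduce_axis    # reduction axis

        # Let AutoTVM search the tile sizes instead of hard-coding y_bn/x_bn.
        cfg.define_split("tile_y", y, num_outputs=3)
        cfg.define_split("tile_x", x, num_outputs=3)
        cfg.define_split("tile_k", k, num_outputs=2)

        by, ty, yi = cfg["tile_y"].apply(s, C, y)
        bx, tx, xi = cfg["tile_x"].apply(s, C, x)
        ko, ki = cfg["tile_k"].apply(s, C, k)
        # ... then bind by/bx to blockIdx and ty/tx to threadIdx, and place the
        # shared-memory stages with compute_at, as the existing schedule does.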

After a tuning space has been created, you could follow this tutorial to register the schedule function as an AutoTVM task and tune it to see the performance improvement.
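
Roughly, the tuning step then follows the tutorial's pattern; here is a sketch under the assumption that the template has been registered under a hypothetical name "batch_matmul_template" (the exact autotvm.task.create signature differs slightly across TVM versions):

    from tvm import autotvm

    bsz, d1, d2, d3 = 24, 16384, 512, 64  # the shapes from this thread
    task = autotvm.task.create("batch_matmul_template",
                               args=(bsz, d1, d2, d3), target='cuda')

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=10, timeout=4))

    tuner = autotvm.tuner.XGBTuner(task)
    tuner.tune(n_trial=500,
               measure_option=measure_option,
               callbacks=[autotvm.callback.log_to_file('batch_matmul.log')])

    # At build time, wrap the build in autotvm.apply_history_best('batch_matmul.log')
    # so the best configuration found during tuning is used.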


Thanks @comaniac.

Do you know which params in schedule_batch_matmul need tuning? Probably the 16s and the 64s?

Also, do you have a high-level description of that schedule, i.e., how it splits the workload?

Sorry, I didn’t dive into the batch_matmul schedule function, so I have no idea. You could do some experiments and file a PR once you have identified a tuning space.

I didn’t quite get your second question. Are you asking how define_split works?

Sorry, the question wasn’t clear. I was wondering if there’s documentation or a high-level description of the schedule (there are no comments in the code), but I guess the answer is no.

Sorry, as you expected, we don’t have such documentation. All schedules were developed and committed by contributors. Since end users are not supposed to read the schedule functions, how a schedule function was developed is not documented. The closest you can find are tutorials like this one.


@comaniac, hopefully a last question: do you know how to save multiple modules into one .so file? I want to save multiple versions of the function as @tqchen suggested, but tvm.runtime.export_library saves only one, and tvm.cc.create_shared doesn’t link the host code with the target code the way tvm.runtime.export_library does. I am sure this can be done with gcc or something, but I was wondering if TVM already has a solution for this, given that it is a common use case.

AFAIK you can only keep one module in one .so file. If you want to keep multiple versions to deal with dynamic workloads, it’s better to wait for the corresponding support by @haichen. Otherwise, you can only integrate all versions into a single schedule function for now.

Given that it is not easy to get multiple versions of the function compiled together, I tried to see if schedule_batch_matmul can be changed to work for dynamic input sizes (at the expense of being slightly less efficient). It seems that this line and the one above can be changed to constants y_bn = x_bn = 64. However, when I do so, I get the following error:

TVMError: Check failed: condition_counter() == 0 (1 vs. 0) : Cannot insert syncs inside condition

The full codegen and the error are here. Any suggestions on what could be wrong? My guess is that some of the scheduling primitives are still relying on constant input sizes, but I couldn’t pinpoint which one. It is probably compute_at, but I am sure I have used compute_at with variable input sizes before and it worked nicely. Any ideas?

It is actually possible to get multiple versions of a function compiled together. It just needs a bit of extra effort.

One way to achieve that now is to call tvm.lower (instead of build) to get a List[LoweredFunc] for each of the functions you care about (give them different names), then concatenate these lists into a single list and feed it to tvm.build, which gives you a single runtime module containing all the functions you need. Then we can call mod.export_library to export a single .so file that contains all the function variants. This is the approach we use to export multiple functions in a neural network into a single shared library.

Also, we might have another way (by importing other modules into a single one) which @FrozenGene recently enabled.


This works nicely, and it is trivial to implement. Here’s a working example for reference:

import torch
import topi
import tvm
from tvm import te
from tvm.contrib import dlpack

def _codegen_function(d1, d2, name):
    bsz = te.var('bsz') # bsz and d3 can be variables without impact on performance 
    d3 = te.var('d3')   # but d1 and d2 should be constants for `schedule_batch_matmul` to work
    A = te.placeholder((bsz, d1, d3), name='A', dtype='float32')
    B = te.placeholder((bsz, d2, d3), name='B', dtype='float32')
    R = topi.nn.batch_matmul(A, B)
    s = topi.cuda.batch_matmul.schedule_batch_matmul(R)
    return tvm.lower(s, [A, B, R], name=name)

if __name__ == "__main__":
    bsz = 12
    d11 = 2048
    d12 = 1024
    d2 = 512
    d3 = 64

    # 2 different versions of the same function
    bmm1 = _codegen_function(d11, d2, 'bmm1')
    bmm2 = _codegen_function(d12, d2, 'bmm2')

    # build both functions into one module
    module = tvm.build([bmm1, bmm2], target='cuda', target_host='llvm')

    module.export_library('libbmm.so')  # save the module into a .so file
    module = tvm.runtime.load_module('libbmm.so')  # load it back
    # get each function then package it as a pytorch function
    bmm1_pytorch = dlpack.to_pytorch_func(module['bmm1'])
    bmm2_pytorch = dlpack.to_pytorch_func(module['bmm2'])

    A1 = torch.randn(bsz, d11, d3, device='cuda')
    A2 = torch.randn(bsz, d12, d3, device='cuda')
    B = torch.randn(bsz, d2, d3, device='cuda')
    R1 = B.new_empty(bsz, d11, d2)  # allocate memory for the result tensor
    R2 = B.new_empty(bsz, d12, d2)  # allocate memory for the result tensor

    bmm1_pytorch(A1, B, R1)
    print(R1.sum())

    bmm2_pytorch(A2, B, R2)
    print(R2.sum())
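
One thing worth noting when checking the results: topi.nn.batch_matmul treats the second operand as transposed (R has shape (bsz, d1, d2) for A of shape (bsz, d1, d3) and B of shape (bsz, d2, d3)), so the PyTorch reference is bmm against B.transpose(1, 2). A quick check that can be appended to the script above:

    # topi.nn.batch_matmul computes A @ B^T, so compare against bmm with B
    # transposed on its last two dims; tolerances are loose for float32.
    assert torch.allclose(R1, torch.bmm(A1, B.transpose(1, 2)), rtol=1e-3, atol=1e-3)
    assert torch.allclose(R2, torch.bmm(A2, B.transpose(1, 2)), rtol=1e-3, atol=1e-3)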

What was the final speed comparison with pytorch here?

It turned out that specializing with d1 and d2 as constants was the key part? (no need for autotvm?)

topi.cuda.batch_matmul.schedule_batch_matmul with constant d1 and d2 gave the best performance. It is still 3x slower than PyTorch, though.

topi.cuda.batch_matmul.schedule_batch_matmul is not instrumented with AutoTVM knobs; I tried changing the few constants it has, but that didn’t help.

Oh :frowning: yeah, I was also able to get to around 3x with a lot of work, but that’s still really far from being practical.

I got very close to matching PyTorch’s bmm on Vega 20 (Radeon VII) and to about 1.5x on a 1080Ti for the 1024 example (with fixed dims). I’ll look into sending a PR for TOPI.

One of the limiting things on the path ahead is, of course, the “-1” issue in the output configurations.

Best regards

Thomas