This works nicely, and it is trivial to implement. Here’s a working example for reference:
```python
import torch
import topi
import tvm
from tvm import te
from tvm.contrib import dlpack


def _codegen_function(d1, d2, name):
    bsz = te.var('bsz')  # bsz and d3 can be variables without impact on performance
    d3 = te.var('d3')    # but d1 and d2 should be constants for `schedule_batch_matmul` to work
    A = te.placeholder((bsz, d1, d3), name='A', dtype='float32')
    B = te.placeholder((bsz, d2, d3), name='B', dtype='float32')
    R = topi.nn.batch_matmul(A, B)
    s = topi.cuda.batch_matmul.schedule_batch_matmul(R)
    return tvm.lower(s, [A, B, R], name=name)


if __name__ == "__main__":
    bsz = 12
    d11 = 2048
    d12 = 1024
    d2 = 512
    d3 = 64

    # 2 different versions of the same function
    bmm1 = _codegen_function(d11, d2, 'bmm1')
    bmm2 = _codegen_function(d12, d2, 'bmm2')

    # build both functions into one module
    module = tvm.build([bmm1, bmm2], target='cuda', target_host='llvm')
    module.export_library('libbmm.so')  # save the module into a .so file
    module = tvm.runtime.load_module('libbmm.so')  # load it back

    # get each function, then package it as a pytorch function
    bmm1_pytorch = dlpack.to_pytorch_func(module['bmm1'])
    bmm2_pytorch = dlpack.to_pytorch_func(module['bmm2'])

    A1 = torch.randn(bsz, d11, d3, device='cuda')
    A2 = torch.randn(bsz, d12, d3, device='cuda')
    B = torch.randn(bsz, d2, d3, device='cuda')
    R1 = B.new_empty(bsz, d11, d2)  # allocate memory for the result tensor
    R2 = B.new_empty(bsz, d12, d2)  # allocate memory for the result tensor

    bmm1_pytorch(A1, B, R1)
    print(R1.sum())
    bmm2_pytorch(A2, B, R2)
    print(R2.sum())
```
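As a quick sanity check (not part of the original snippet), the TVM results can be compared against PyTorch's own batched matmul. Note that `topi.nn.batch_matmul` multiplies `A` of shape `(bsz, M, K)` with `B` of shape `(bsz, N, K)`, i.e. it computes `A @ B^T` per batch, so the PyTorch reference needs `B` transposed on its last two dimensions. A minimal sketch, assuming it is appended to the end of the `__main__` block above:

```python
    # Reference results with torch.bmm; B is (bsz, d2, d3), so transpose its last two dims
    ref1 = torch.bmm(A1, B.transpose(1, 2))
    ref2 = torch.bmm(A2, B.transpose(1, 2))

    # Compare up to float32 accumulation tolerance; both should print True
    print(torch.allclose(R1, ref1, rtol=1e-3, atol=1e-3))
    print(torch.allclose(R2, ref2, rtol=1e-3, atol=1e-3))
```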