Generating transformer code with TVM

I am interested in applying TVM to optimize BERT (especially the transformer part).

Thanks to the blog post on TVM’s homepage (https://tvm.ai/2018/03/23/nmt-transformer-optimize), I could easily apply TVM to generate CUDA code for the batch matmul operation.

I ran the reference code from the blog (https://github.com/Orion34C/tvm-batch-matmul-example/blob/master/tvm_batch_matmul_transpose_m1_kX.py) and obtained the CUDA code below.

extern "C" __global__ void batch_matmul_transpose_1_16_384_384_64_0213_kernel0( float* __restrict__ A,  float* __restrict__ B,  float* __restrict__ C) {
  float C_local[1];
  C_local[0] = 0.000000e+00f;
  for (int k = 0; k < 64; ++k) {
    C_local[0] = (C_local[0] + (A[((((((((int)blockIdx.y) * 8) + ((int)threadIdx.y)) % 16) * 24576) + ((((((int)blockIdx.y) * 8) + ((int)threadIdx.y)) / 16) * 64)) + k)] * B[(((((((((int)blockIdx.y) * 8) + ((int)threadIdx.y)) % 16) * 24576) + (k * 384)) + (((int)blockIdx.x) * 32)) + ((int)threadIdx.x))]));
  }
  C[((((((int)blockIdx.y) * 3072) + (((int)threadIdx.y) * 384)) + (((int)blockIdx.x) * 32)) + ((int)threadIdx.x))] = C_local[0];
}
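
For reference, the compute definition behind a kernel like this is just a batched matmul in TVM’s tensor expression API. Here is a minimal sketch (assuming the 0.x-era API the blog uses; the simple fuse/split/bind schedule below only approximates the one in the script, which also handles the transposed 0213 layout):

import tvm

# Shapes taken from the kernel name: batch=16, M=384, N=384, K=64
batch, M, N, K = 16, 384, 384, 64
A = tvm.placeholder((batch, M, K), name="A")
B = tvm.placeholder((batch, K, N), name="B")
k = tvm.reduce_axis((0, K), name="k")
C = tvm.compute(
    (batch, M, N),
    lambda b, i, j: tvm.sum(A[b, i, k] * B[b, k, j], axis=k),
    name="C")

# Deliberately simple schedule: fuse batch with rows, split both output
# axes, and bind them to the CUDA grid; this yields indexing similar to
# the blockIdx/threadIdx arithmetic in the kernel above.
s = tvm.create_schedule(C.op)
b_ax, i_ax, j_ax = s[C].op.axis
fused = s[C].fuse(b_ax, i_ax)
by, ty = s[C].split(fused, factor=8)
bx, tx = s[C].split(j_ax, factor=32)
s[C].bind(by, tvm.thread_axis("blockIdx.y"))
s[C].bind(ty, tvm.thread_axis("threadIdx.y"))
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))

func = tvm.build(s, [A, B, C], target="cuda")
print(func.imported_modules[0].get_source())  # dump the generated CUDA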

But the generated code does not seem well optimized; for example, simple optimizations such as loop unrolling are not applied. It is actually about 10x slower than the batch matmul used in TensorFlow (1.14). I’m sure we are doing something wrong here.

Could you give us pointers on how to apply optimizations (such as machine-learning-guided code generation) to this code? That is, how do we make TVM generate optimized code?

Thanks!

That sample code is quite outdated. If you are running on a Volta or Turing GPU, it is highly recommended to try out the TensorCore codegen instead. Here is the sample code for batch matmul: https://github.com/minminsun/tvm/blob/17a08bb50f6874d50790fefa369d60f2a4b22d81/tutorials/autotvm/tune_tensor_core_batch_matmul.py

For your case, the command line is

python tune_tensor_core_batch_matmul.py 16 384 384 64
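
That tutorial uses AutoTVM, which is exactly the machine-learning-guided code generation you asked about: an XGBoost cost model is fit to measured candidates and steers the search over the TensorCore schedule space. Condensed, the tuning loop in the script looks roughly like this (batch_matmul_template stands for the @autotvm.template schedule function defined there, and the log file name is illustrative):

from tvm import autotvm

# Create a tuning task for the shapes in question (16, 384, 384, 64).
# batch_matmul_template: the @autotvm.template function from the script.
task = autotvm.task.create(batch_matmul_template,
                           args=(16, 384, 384, 64),
                           target="cuda")

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(number=5, repeat=3, min_repeat_ms=100))

# XGBTuner is the ML-guided part: it predicts which schedule candidates
# are worth measuring next based on an XGBoost cost model.
tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(n_trial=1000,
           measure_option=measure_option,
           callbacks=[autotvm.callback.log_to_file("batch_matmul.log")])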

The generated kernel runs in about 32 us on a V100. If you are running on CUDA 10, that can be further reduced to about 29 us by changing warp_tile_m from 16 to 32:

@@ -69,7 +69,7 @@ def test_gemm_nn(batch, N, L, M, dtype, layout):
     TY = 1
     tile_x = bx * TX
     tile_y = by * TY
-    WX = min(16, tile_x)
+    WX = min(32, tile_x)
     tile_k = 16
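
Once tuning has finished, a number like the 32 us above can be reproduced with TVM’s time evaluator. A sketch, where func stands for the module built from the best config in the tuning log, and the float16-in / float32-out dtypes follow the tutorial’s TensorCore setup:

import numpy as np
import tvm

batch, M, N, K = 16, 384, 384, 64
ctx = tvm.gpu(0)
a = tvm.nd.array(np.random.uniform(size=(batch, M, K)).astype("float16"), ctx)
b = tvm.nd.array(np.random.uniform(size=(batch, K, N)).astype("float16"), ctx)
c = tvm.nd.array(np.zeros((batch, M, N), dtype="float32"), ctx)

# func: module built from the best tuned config (assumed available).
func(a, b, c)

# Average over many runs for a stable microsecond-level measurement.
evaluator = func.time_evaluator(func.entry_name, ctx, number=200)
print("batch matmul: %.1f us" % (evaluator(a, b, c).mean * 1e6))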