About tensor core tutorials

I tried to use TVM to generate tensor core code, following https://tvm.apache.org/docs/tutorials/optimize/opt_matmul_auto_tensorcore.html.

I am using the latest version of TVM on an NVIDIA V100 with CUDA 10.1.

However, the generated kernel code does not use the tensor core wmma instructions.

I changed N, M, L and the layout. A rough sketch of my setup and the full output follow.
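This is roughly how I declared the workload (a minimal sketch from memory, assuming the tutorial's matmul definition; only the sizes and the layout parameter differ from the tutorial defaults, and the transposed-B indexing shown here is just illustrative). The schedule and autotvm tuning are taken from the tutorial unchanged.

# Sketch of the workload declaration I used (sizes changed to 4096; the
# B[j, k] layout here is illustrative, since I also changed the layout).
import tvm
from tvm import te

M, N, L = 4096, 4096, 4096          # changed from the tutorial defaults
in_dtype = 'float16'
out_dtype = 'float32'

A = te.placeholder((M, L), name='A', dtype=in_dtype)
B = te.placeholder((N, L), name='B', dtype=in_dtype)
k = te.reduce_axis((0, L), name='k')
compute = te.compute(
    (M, N),
    lambda i, j: te.sum(A[i, k].astype(out_dtype) * B[j, k].astype(out_dtype), axis=k),
    name='compute')
# The schedule, the search space (the bx/by/step_k/v knobs seen in the
# best config above), and the pragma_tensor_core hint on the reduction
# axis are exactly as in the tutorial.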

Best config:
[('bx', 2), ('by', 64), ('step_k', 4), ('v', 8)],None,105
Finish loading 1376 records
primfn(A_1: handle, B_1: handle, compute_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {compute: Buffer(compute_2: handle, float32, [4096, 4096], []),
             B: Buffer(B_2: handle, float16, [4096, 4096], []),
             A: Buffer(A_2: handle, float16, [4096, 4096], [])}
  buffer_map = {A_1: A, B_1: B, compute_1: compute} {
  attr [IterVar(blockIdx.y: int32, (nullptr), "ThreadIndex", "blockIdx.y")] "thread_extent" = 64;
  attr [compute.local: handle] "storage_scope" = "local";
  allocate(compute.local, float32, [8]);
  attr [A.shared: handle] "storage_scope" = "shared";
  allocate(A.shared, float16, [4608]);
  attr [B.shared: handle] "storage_scope" = "shared";
  allocate(B.shared, float16, [1024]);
  attr [A.shared.local: handle] "storage_scope" = "local";
  allocate(A.shared.local, float16, [16]);
  attr [B.shared.local: handle] "storage_scope" = "local";
  allocate(B.shared.local, float16, [128]);
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 256;
  attr [IterVar(threadIdx.z: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
  attr [IterVar(threadIdx.y: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 64;
  attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 2 {
    for (j.c.init: int32, 0, 8) {
      compute.local[j.c.init] = 0f32
    }
    attr [IterVar(k.outer: int32, (nullptr), "CommReduce", "")] "pragma_tensor_core" = 1;
    for (k.outer, 0, 64) {
      for (ax0.ax1.outer.fused.outer: int32, 0, 4) {
        attr [IterVar(threadIdx.y_1: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 64;
        attr [IterVar(threadIdx.z_1: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
        attr [IterVar(threadIdx.x_1: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 2;
        A.shared[ramp(((((ax0.ax1.outer.fused.outer*1152) + (floordiv(threadIdx.y_1, 4)*72)) + (floormod(threadIdx.y_1, 4)*16)) + (threadIdx.x_1*8)), 1, 8)] = (float16x8*)A_2[ramp(((((((blockIdx.y*262144) + (ax0.ax1.outer.fused.outer*65536)) + (floordiv(threadIdx.y_1, 4)*4096)) + (k.outer*64)) + (floormod(threadIdx.y_1, 4)*16)) + (threadIdx.x_1*8)), 1, 8)])
      }
      attr [IterVar(threadIdx.y_2: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 64;
      attr [IterVar(threadIdx.z_2: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
      attr [IterVar(threadIdx.x_2: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 2;
      B.shared[ramp(((threadIdx.y_2*16) + (threadIdx.x_2*8)), 1, 8)] = (float16x8*)B_2[ramp((((((blockIdx.x*65536) + (floordiv(threadIdx.y_2, 4)*4096)) + (k.outer*64)) + (floormod(threadIdx.y_2, 4)*16)) + (threadIdx.x_2*8)), 1, 8)])
      for (k.inner.outer: int32, 0, 4) {
        for (ax1: int32, 0, 16) {
          A.shared.local[ax1] = (float16*)A.shared[(((threadIdx.y*72) + (k.inner.outer*16)) + ax1)])
        }
        for (ax0: int32, 0, 8) {
          for (ax1_1: int32, 0, 16) {
            B.shared.local[((ax0*16) + ax1_1)] = (float16*)B.shared[((((threadIdx.x*512) + (ax0*64)) + (k.inner.outer*16)) + ax1_1)])
          }
        }
        for (k.inner.inner: int32, 0, 16) {
          for (j.c: int32, 0, 8) {
            compute.local[j.c] = ((float32*)compute.local[j.c]) + (cast(float32, (float16*)A.shared.local[k.inner.inner]))*cast(float32, (float16*)B.shared.local[((j.c*16) + k.inner.inner)]))))
          }
        }
      }
    }
    for (j.inner.inner.inner: int32, 0, 8) {
      compute_2[(((((blockIdx.y*262144) + (threadIdx.y*4096)) + (blockIdx.x*16)) + (threadIdx.x*8)) + j.inner.inner.inner)] = (float32*)compute.local[j.inner.inner.inner])
    }
  }
}

// meta data omitted. you can use show_meta_data=True to include meta data
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
#include <cuda_fp16.h>
__device__ half max(half a, half b)
{
  return __hgt(__half(a), __half(b)) ? a : b;
}
__device__ half min(half a, half b)
{
  return __hlt(__half(a), __half(b)) ? a : b;
}
#else

typedef unsigned short uint16_t;
typedef unsigned char uint8_t;
typedef signed char int8_t;
typedef int int32_t;
typedef unsigned long long uint64_t;
typedef unsigned int uint32_t;

#define TVM_FORCE_INLINE inline __attribute__((always_inline))
#define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__
#define TVM_ALIGNED(x) __attribute__ ((aligned(x)))
#define TVM_HALF_OPERATOR(RTYPE, OP)                              \
  TVM_XINLINE RTYPE operator OP (half a, half b) {                \
    return RTYPE(float(a) OP float(b));                           \
  }                                                               \
  template<typename T>                                            \
  TVM_XINLINE RTYPE operator OP (half a, T b) {                   \
    return RTYPE(float(a) OP float(b));                           \
  }                                                               \
  template<typename T>                                            \
  TVM_XINLINE RTYPE operator OP (T a, half b) {                   \
    return RTYPE(float(a) OP float(b));                           \
  }

#define TVM_HALF_ASSIGNOP(AOP, OP)                                \
  template<typename T>                                            \
  TVM_XINLINE half operator AOP (const T& a) {                    \
    return *this = half(float(*this) OP float(a));                \
  }                                                               \
  template<typename T>                                            \
  TVM_XINLINE half operator AOP (const volatile T& a) volatile {  \
    return *this = half(float(*this) OP float(a));                \
  }

class TVM_ALIGNED(2) half {
 public:
  uint16_t half_;

  static TVM_XINLINE half Binary(uint16_t value) {
    half res;
    res.half_ = value;
    return res;
  }

  TVM_XINLINE half() {}

  TVM_XINLINE half(const float& value) { constructor(value); }
  TVM_XINLINE explicit half(const double& value) { constructor(value); }
  TVM_XINLINE explicit half(const int8_t& value) { constructor(value); }
  TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }
  TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }
  TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }
  TVM_XINLINE explicit half(const long long& value) { constructor(value); }
  TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }

  TVM_XINLINE operator float() const {                          \
    return float(half2float(half_));                            \
  }                                                             \
  TVM_XINLINE operator float() const volatile {                 \
    return float(half2float(half_));                            \
  }


  TVM_HALF_ASSIGNOP(+=, +)
  TVM_HALF_ASSIGNOP(-=, -)
  TVM_HALF_ASSIGNOP(*=, *)
  TVM_HALF_ASSIGNOP(/=, /)

  TVM_XINLINE half operator+() {
    return *this;
  }

  TVM_XINLINE half operator-() {
    return half(-float(*this));
  }

  TVM_XINLINE half operator=(const half& a) {
    half_ = a.half_;
    return a;
  }

  template<typename T>
  TVM_XINLINE half operator=(const T& a) {
    return *this = half(a);
  }

  TVM_XINLINE half operator=(const half& a) volatile {
    half_ = a.half_;
    return a;
  }

  template<typename T>
  TVM_XINLINE half operator=(const T& a) volatile {
    return *this = half(a);
  }

 private:
  union Bits {
    float f;
    int32_t si;
    uint32_t ui;
  };

  static int const fp16FractionBits = 10;
  static int const fp32FractionBits = 23;
  static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits);   // == 0x7fffff
  static int32_t const fp32HiddenBit = 1 << fp32FractionBits;   // == 0x800000
  static int const shift = fp32FractionBits - fp16FractionBits;   // == 13
  static int const shiftSign = 16;
  static int32_t const expAdjust = 127 - 15;   // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)

  static int32_t const infN = 0x7F800000;   // flt32 infinity
  static int32_t const maxN = 0x477FFFFF;   // max flt32 that's a flt16 normal after >> by shift
  static int32_t const minN = 0x38800000;   // min flt16 normal as a flt32
  static int32_t const maxZ = 0x33000000;   // max fp32 number that's still rounded to zero in fp16
  static int32_t const signN = 0x80000000;  // flt32 sign bit

  static int32_t const infC = infN >> shift;
  static int32_t const nanN = (infC + 1) << shift;   // minimum flt16 nan as a flt32
  static int32_t const maxC = maxN >> shift;
  static int32_t const minC = minN >> shift;
  static int32_t const signC = signN >> shiftSign;  // flt16 sign bit

  static int32_t const mulN = 0x52000000;  // (1 << 23) / minN
  static int32_t const mulC = 0x33800000;  // minN / (1 << (23 - shift))

  static int32_t const subC = 0x003FF;  // max flt32 subnormal down shifted
  static int32_t const norC = 0x00400;  // min flt32 normal down shifted

  static int32_t const maxD = infC - maxC - 1;
  static int32_t const minD = minC - subC - 1;

  TVM_XINLINE uint16_t float2half(const float& value) const {
    Bits v;
    v.f = value;
    uint32_t sign = v.si & signN;    // grab sign bit
    v.si ^= sign;                    // clear sign bit from v
    sign >>= shiftSign;              // logical shift sign to fp16 position

    if (v.si <= maxZ) {
      // Handle eventual zeros here to ensure
      // vshift will not exceed 32 below.
      v.ui = 0;
    } else if (v.si < minN) {
      // Handle denorms
      uint32_t exp32 = v.ui >> fp32FractionBits;
      int32_t exp16 = exp32 - expAdjust;
      // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
      // Smaller (so negative) exp16 values should result in greater right shifts.
      uint32_t vshift = 1 - exp16;
      uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
      v.ui = significand >> vshift;
      v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
    } else if (v.si <= maxN) {
      // Handle norms
      v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
      v.ui -= expAdjust << fp32FractionBits;
    } else if (v.si <= infN) {
      v.si = infN;
    } else if (v.si < nanN) {
      v.si = nanN;
    }

    v.ui >>= shift;
    return sign | (v.ui & 0x7fff);
  }

  // Same as above routine, except for addition of volatile keyword
  TVM_XINLINE uint16_t float2half(
    const volatile float& value) const volatile {
    Bits v;
    v.f = value;
    uint32_t sign = v.si & signN;    // grab sign bit
    v.si ^= sign;                    // clear sign bit from v
    sign >>= shiftSign;              // logical shift sign to fp16 position

    if (v.si <= maxZ) {
      // Handle eventual zeros here to ensure
      // vshift will not exceed 32 below.
      v.ui = 0;
    } else if (v.si < minN) {
      // Handle denorms
      uint32_t exp32 = v.ui >> fp32FractionBits;
      int32_t exp16 = exp32 - expAdjust;
      // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
      // Smaller (so negative) exp16 values should result in greater right shifts.
      uint32_t vshift = 1 - exp16;
      uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
      v.ui = significand >> vshift;
      v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
    } else if (v.si <= maxN) {
      // Handle norms
      v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
      v.ui -= expAdjust << fp32FractionBits;
    } else if (v.si <= infN) {
      v.si = infN;
    } else if (v.si < nanN) {
      v.si = nanN;
    }

    v.ui >>= shift;
    return sign | (v.ui & 0x7fff);
  }

  TVM_XINLINE float half2float(const uint16_t& value) const {
    Bits v;
    v.ui = value;
    int32_t sign = v.si & signC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;
  }

  TVM_XINLINE float half2float(
    const volatile uint16_t& value) const volatile {
    Bits v;
    v.ui = value;
    int32_t sign = v.si & signC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;
  }

  template<typename T>
  TVM_XINLINE void constructor(const T& value) {
    half_ = float2half(float(value));
  }
};

TVM_HALF_OPERATOR(half, +)
TVM_HALF_OPERATOR(half, -)
TVM_HALF_OPERATOR(half, *)
TVM_HALF_OPERATOR(half, /)
TVM_HALF_OPERATOR(bool, >)
TVM_HALF_OPERATOR(bool, <)
TVM_HALF_OPERATOR(bool, >=)
TVM_HALF_OPERATOR(bool, <=)

TVM_XINLINE half __float2half_rn(const float a) {
  return half(a);
}
#endif


// Pack two half values.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
  return (v1 << 16) | v0;
}
extern "C" __global__ void default_function_kernel0(void* __restrict__ A, void* __restrict__ B, void* __restrict__ compute) {
  float compute_local[8];
  __shared__ half A_shared[4608];
  __shared__ half B_shared[1024];
  half A_shared_local[16];
  half B_shared_local[128];
  for (int j_c_init = 0; j_c_init < 8; ++j_c_init) {
    compute_local[(j_c_init)] = 0.000000e+00f;
  }
  for (int k_outer = 0; k_outer < 64; ++k_outer) {
    __syncthreads();
    for (int ax0_ax1_outer_fused_outer = 0; ax0_ax1_outer_fused_outer < 4; ++ax0_ax1_outer_fused_outer) {
      ((uint4*)(A_shared + (((((ax0_ax1_outer_fused_outer * 1152) + ((((int)threadIdx.y) >> 2) * 72)) + ((((int)threadIdx.y) & 3) * 16)) + (((int)threadIdx.x) * 8)))))[0] = ((uint4*)((half*)A + (((((((((int)blockIdx.y) * 262144) + (ax0_ax1_outer_fused_outer * 65536)) + ((((int)threadIdx.y) >> 2) * 4096)) + (k_outer * 64)) + ((((int)threadIdx.y) & 3) * 16)) + (((int)threadIdx.x) * 8)))))[0];
    }
    ((uint4*)(B_shared + (((((int)threadIdx.y) * 16) + (((int)threadIdx.x) * 8)))))[0] = ((uint4*)((half*)B + ((((((((int)blockIdx.x) * 65536) + ((((int)threadIdx.y) >> 2) * 4096)) + (k_outer * 64)) + ((((int)threadIdx.y) & 3) * 16)) + (((int)threadIdx.x) * 8)))))[0];
    __syncthreads();
    for (int k_inner_outer = 0; k_inner_outer < 4; ++k_inner_outer) {
      for (int ax1 = 0; ax1 < 16; ++ax1) {
        A_shared_local[(ax1)] = A_shared[((((((int)threadIdx.y) * 72) + (k_inner_outer * 16)) + ax1))];
      }
      for (int ax0 = 0; ax0 < 8; ++ax0) {
        for (int ax11 = 0; ax11 < 16; ++ax11) {
          B_shared_local[(((ax0 * 16) + ax11))] = B_shared[(((((((int)threadIdx.x) * 512) + (ax0 * 64)) + (k_inner_outer * 16)) + ax11))];
        }
      }
      for (int k_inner_inner = 0; k_inner_inner < 16; ++k_inner_inner) {
        for (int j_c = 0; j_c < 8; ++j_c) {
          compute_local[(j_c)] = (compute_local[(j_c)] + (((float)A_shared_local[(k_inner_inner)]) * ((float)B_shared_local[(((j_c * 16) + k_inner_inner))])));
        }
      }
    }
  }
  for (int j_inner_inner_inner = 0; j_inner_inner_inner < 8; ++j_inner_inner_inner) {
    ((float*)compute)[((((((((int)blockIdx.y) * 262144) + (((int)threadIdx.y) * 4096)) + (((int)blockIdx.x) * 16)) + (((int)threadIdx.x) * 8)) + j_inner_inner_inner))] = compute_local[(j_inner_inner_inner)];
  }
}

Time cost of this operator: 0.045881
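For reference, that timing line is printed by the tutorial's evaluation step; roughly, assuming the tutorial's variable names (func, ctx, and the a/b/c arrays):

# Sketch of how the timing line is produced (tutorial code, from memory;
# the repetition count is whatever the tutorial uses).
evaluator = func.time_evaluator(func.entry_name, ctx, number=5)
print('Time cost of this operator: %f' % evaluator(a, b, c).mean)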