Possible vthread+tensorize bug

Hello everyone! I will keep the story short. I am working on the conv2d_int8 template with tensor core support. Using vthread leads to broken kernels (wrong results) after tuning. I was able to narrow the problem down to the simplest use case:

import tvm
# Simple algorithm [A[i] @ B for i in range(len(A))]
VIRTUAL_THREAD = 2
A_shape = (VIRTUAL_THREAD, 16, 16)
B_shape = (16, 16)

A = tvm.placeholder(A_shape, name='A', dtype='float16')
B = tvm.placeholder(B_shape, name='B', dtype='float16')

r = tvm.reduce_axis((0, 16), name='r')
C = tvm.compute((VIRTUAL_THREAD, 16, 16),
                   lambda vth, row, col: tvm.sum(
                       A[vth, row, r].astype("float32") * B[r, col].astype("float32"),
                       axis=[r]),
                   name="C")

s = tvm.create_schedule(C.op)
# print(tvm.lower(s, [A, B, C], simple_mode=True))

# Memory hierarchy
AF = s.cache_read(A, 'wmma.matrix_a', [C])
BF = s.cache_read(B, 'wmma.matrix_b', [C])
CF = s.cache_write(C, 'wmma.accumulator')


def intrin_wmma_load_matrix(scope):
    n = 16
    A = tvm.placeholder((n, n), name='A', dtype='float16')
    BA = tvm.decl_buffer(A.shape, A.dtype, scope='global', data_alignment=32, offset_factor=256)
    C = tvm.compute((n, n), lambda i, j: A[i, j], name='C')
    BC = tvm.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256)

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()

        BA = ins[0]
        BC = outs[0]
        ib.emit(tvm.call_intrin('handle', 'tvm_load_matrix_sync',
                                BC.data, n, n, n, BC.elem_offset // 256,
                                BA.access_ptr('r'), n, 'row_major'))
        return ib.get()

    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})


def intrin_wmma_gemm():
    n = 16
    A = tvm.placeholder((n, n), name='A', dtype='float16')
    B = tvm.placeholder((n, n), name='B', dtype='float16')
    k = tvm.reduce_axis((0, n), name="k")
    C = tvm.compute((n, n),
                    lambda ii, jj:
                    tvm.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k),
                    name='C')
    BA = tvm.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256)
    BB = tvm.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256)
    BC = tvm.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=256)

    def intrin_func(ins, outs):
        BA, BB = ins
        BC, = outs

        def init():
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0))
            return ib.get()

        def update():
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_intrin('handle', 'tvm_mma_sync',
                                    BC.data, BC.elem_offset // 256,
                                    BA.data, BA.elem_offset // 256,
                                    BB.data, BB.elem_offset // 256,
                                    BC.data, BC.elem_offset // 256))
            return ib.get()

        return update(), init(), update()

    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC})


def intrin_wmma_store_matrix():
    n = 16
    A = tvm.placeholder((n, n), name='A', dtype='float32')
    BA = tvm.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=256)
    C = tvm.compute((n, n), lambda i, j: A[i, j], name='C')
    BC = tvm.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=256)

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()
        BA = ins[0]
        BC = outs[0]
        ib.emit(tvm.call_intrin('handle', 'tvm_store_matrix_sync',
                                BA.data, n, n, n, BA.elem_offset // 256,
                                BC.access_ptr('w'), n, 'row_major'))
        return ib.get()

    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})


vth, row, col = C.op.axis
s[CF].compute_at(s[C], vth)
s[AF].compute_at(s[CF], CF.op.axis[-3])
s[BF].compute_at(s[CF], CF.op.axis[-3])
s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_a'))
s[BF].tensorize(BF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b'))
s[C].tensorize(row, intrin_wmma_store_matrix())
s[CF].tensorize(CF.op.axis[-2], intrin_wmma_gemm())
print(tvm.lower(s, [A, B, C], simple_mode=True))

s[C].bind(C.op.axis[-3], tvm.thread_axis('vthread'))
print(tvm.lower(s, [A, B, C], simple_mode=True))

So the problem is that after binding vthread we get multiple load/store calls for the same memory fragment but only a single gemm call:

// attr [A.wmma.matrix_a] storage_scope = "wmma.matrix_a"
allocate A.wmma.matrix_a[float16 * 256]
// attr [B.wmma.matrix_b] storage_scope = "wmma.matrix_b"
allocate B.wmma.matrix_b[float16 * 256]
// attr [C.wmma.accumulator] storage_scope = "wmma.accumulator"
allocate C.wmma.accumulator[float32 * 256]
produce C {
  for (vth, 0, 2) {
    produce C.wmma.accumulator {
      produce A.wmma.matrix_a {
        tvm_load_matrix_sync(A.wmma.matrix_a, 16, 16, 16, 0, tvm_access_ptr(type_annotation(), A, (vth*256), 256, 1), 16, "row_major")
      }
      produce B.wmma.matrix_b {
        tvm_load_matrix_sync(B.wmma.matrix_b, 16, 16, 16, 0, tvm_access_ptr(type_annotation(), B, 0, 256, 1), 16, "row_major")
      }
      tvm_mma_sync(C.wmma.accumulator, 0, A.wmma.matrix_a, 0, B.wmma.matrix_b, 0, C.wmma.accumulator, 0)
    }
    tvm_store_matrix_sync(C.wmma.accumulator, 16, 16, 16, 0, tvm_access_ptr(type_annotation(), C, (vth*256), 256, 2), 16, "row_major")
  }
}

turns into:

// attr [A.wmma.matrix_a] storage_scope = "wmma.matrix_a"
allocate A.wmma.matrix_a[float16 * 256]
// attr [B.wmma.matrix_b] storage_scope = "wmma.matrix_b"
allocate B.wmma.matrix_b[float16 * 256]
// attr [C.wmma.accumulator] storage_scope = "wmma.accumulator"
allocate C.wmma.accumulator[float32 * 256]
produce C {
  produce C.wmma.accumulator {
    produce A.wmma.matrix_a {
      tvm_load_matrix_sync(A.wmma.matrix_a, 16, 16, 16, 0, tvm_access_ptr(type_annotation(), A, 0, 256, 1), 16, "row_major")
      tvm_load_matrix_sync(A.wmma.matrix_a, 16, 16, 16, 0, tvm_access_ptr(type_annotation(), A, 256, 256, 1), 16, "row_major")
    }
    produce B.wmma.matrix_b {
      tvm_load_matrix_sync(B.wmma.matrix_b, 16, 16, 16, 0, tvm_access_ptr(type_annotation(), B, 0, 256, 1), 16, "row_major")
    }
    tvm_mma_sync(C.wmma.accumulator, 0, A.wmma.matrix_a, 0, B.wmma.matrix_b, 0, C.wmma.accumulator, 0)
  }
  tvm_store_matrix_sync(C.wmma.accumulator, 16, 16, 16, 0, tvm_access_ptr(type_annotation(), C, 0, 256, 2), 16, "row_major")
  tvm_store_matrix_sync(C.wmma.accumulator, 16, 16, 16, 0, tvm_access_ptr(type_annotation(), C, 256, 256, 2), 16, "row_major")
}

I am curious: is this a bug, or am I doing something wrong? The pipeline seems fairly close to the one used in TOPI.

Please note that, for simplicity, I omitted the threadIdx bindings; the behavior can be observed even in simple mode. The intrinsics are taken from the TVM tensor core tutorial with a single memory-scope change (shared -> global).
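For now, a possible way to sidestep the duplication in this toy example (just a sketch on my side, and it gives up vthread's latency-hiding interleave, so it is not a real substitute) is to keep the outer axis as a plain loop and mark it for unrolling instead of binding it to vthread:

# instead of s[C].bind(C.op.axis[-3], tvm.thread_axis('vthread')):
s[C].unroll(C.op.axis[-3])  # plain loop, so each iteration keeps its full load/mma/store pipeline, as in the first dump
print(tvm.lower(s, [A, B, C], simple_mode=True))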

I’m afraid it is a bug. I will find out whether it can be fixed.

BTW, can you please tell me why you need vthread here?

@Hzfengsy, thanks for the reply!

Really good question. I think they might not be needed for the conv2d template, since every wmma call is synchronous across the warp, so we can't hide the latency with the vthread trick. On the other hand, I may be overlooking something, so I decided to rely on benchmarks and check the template both with and without vthreads.
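To make that comparison concrete, my rough plan (just a sketch; the knob name 'use_vthread' is made up for illustration) is to expose the binding as a tunable knob inside the template, so the tuner measures both variants:

import tvm
from tvm import autotvm

# inside the template body, where `s` is the schedule and `C` the output tensor
cfg = autotvm.get_config()
cfg.define_knob('use_vthread', [False, True])
if cfg['use_vthread'].val:
    # bind one of the outer output axes to vthread (axis choice here follows the toy example above)
    s[C].bind(C.op.axis[-3], tvm.thread_axis('vthread'))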

Please let me know if I can help with this bug. I am somewhat new to TVM and would love some guidance so I can dive deeper into the code base.