Understanding virtual threads


#1

I am trying to understand the concept and use of virtual threads. The GPU convolution tutorial shows how to use vthreads to avoid shared-memory conflicts, but can I achieve the same goal without using vthreads?

The following simple test case uses tiling to get three levels of loop nesting, and I have three versions of the schedule. In all three versions, the outermost level of loops is bound to blockIdx and the innermost level to threadIdx. The first version interchanges the middle and innermost levels, so that all the outer loops are bound. The second version simply leaves the middle level unbound. The last version binds the middle level to virtual threads.

All three versions generate practically the same CUDA code.

‘Virtual threads’ seems to be an important concept and tool in TVM. I could create a tutorial on this topic if you can help me understand it first :).

Thanks.

Yuan


import tvm

def show_cuda(s, A, B):
    # Build for CUDA and print the generated kernel source; dump_pass_ir
    # also writes each lowering pass's IR to files in the working directory.
    with tvm.build_config(dump_pass_ir=True):
        func = tvm.build(s, [A, B], target="cuda", name='test')
    print(func.imported_modules[0].get_source())

# Each dimension factors into blocks, threads, and a middle level of 2:
# M = 7 blocks * 5 threads * 2, N = 9 blocks * 3 threads * 2.
M = 7*5*2
N = 9*3*2

A = tvm.placeholder((M,N), name='A')
B = tvm.compute((M,N), lambda i,j: 3.14 * A[i,j], name='B')

block_x = tvm.thread_axis("blockIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
vthread_x = tvm.thread_axis("vthread", name="vx")
vthread_y = tvm.thread_axis("vthread", name="vy")
thread_x = tvm.thread_axis("threadIdx.x")
thread_y = tvm.thread_axis("threadIdx.y")

#
# Schedule 1: Manual loop interchange
#
#    [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 7
#    [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 9
#    [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 5
#    [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 3
#    for (i.inner.outer, 0, 2) {
#      for (j.inner.outer, 0, 2) {
#
s = tvm.create_schedule(B.op)

Mblock = 5*2
Nblock = 3*2

i_outer, j_outer, i_inner, j_inner = s[B].tile(B.op.axis[0], B.op.axis[1], Mblock, Nblock)
i_inner_outer, j_inner_outer, i_inner_inner, j_inner_inner = s[B].tile(i_inner, j_inner, 5, 3)
s[B].reorder(i_outer, j_outer, i_inner_inner, j_inner_inner, i_inner_outer, j_inner_outer)

s[B].bind(i_outer, block_y)
s[B].bind(j_outer, block_x)
s[B].bind(i_inner_inner, thread_y)
s[B].bind(j_inner_inner, thread_x)

show_cuda(s, A, B)


#
# Schedule 2: No loop interchange
#
#    [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 7
#    [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 9
#    for (i.inner.outer, 0, 2) {
#      for (j.inner.outer, 0, 2) {
#        [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 5
#        [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 3
#
s = tvm.create_schedule(B.op)

Mblock = 5*2
Nblock = 3*2

i_outer, j_outer, i_inner, j_inner = s[B].tile(B.op.axis[0], B.op.axis[1], Mblock, Nblock)
i_inner_outer, j_inner_outer, i_inner_inner, j_inner_inner = s[B].tile(i_inner, j_inner, 5, 3)

s[B].bind(i_outer, block_y)
s[B].bind(j_outer, block_x)
s[B].bind(i_inner_inner, thread_y)
s[B].bind(j_inner_inner, thread_x)

show_cuda(s, A, B)


#
# Schedule 3: use virtual threads
#
#    [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 7
#    [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 9
#    [iter_var(vy, , vthread)] virtual_thread = 2
#    [iter_var(vx, , vthread)] virtual_thread = 2
#    [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 5
#    [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 3
#
s = tvm.create_schedule(B.op)

Mblock = 5*2
Nblock = 3*2

i_outer, j_outer, i_inner, j_inner = s[B].tile(B.op.axis[0], B.op.axis[1], Mblock, Nblock)
i_inner_outer, j_inner_outer, i_inner_inner, j_inner_inner = s[B].tile(i_inner, j_inner, 5, 3)

s[B].bind(i_outer, block_y)
s[B].bind(j_outer, block_x)
s[B].bind(i_inner_outer, vthread_y)
s[B].bind(j_inner_outer, vthread_x)
s[B].bind(i_inner_inner, thread_y)
s[B].bind(j_inner_inner, thread_x)

show_cuda(s, A, B)
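
(For reference, the attribute lines in the schedule comments above appear to come from the per-pass IR dumps written by dump_pass_ir=True. The final lowered IR can also be printed directly, at which point the InjectVirtualThread pass has already expanded the vthread binding:)

# Print the fully lowered IR for the last schedule; by this point the
# virtual_thread annotation has been turned back into explicit per-thread code.
print(tvm.lower(s, [A, B], simple_mode=True))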

#2

vthread’s definition is quite straightforward: we create innermost serial loops to simulate the concurrent execution of the threads. Because vthreads execute in the same real thread, the vthread lowering performs an optimization to detect shareable computation among the different vthreads and compute it only once.

Such a compound effect is useful for creating shared strided access patterns, such as those in GEMM.
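
For example, a minimal sketch of this sharing with the same API as above: in a GEMV y = A·x, cache x per thread and bind part of the row axis to vthread; the cached loads of x do not depend on the vthread index, so they are computed once per real thread rather than once per virtual thread:

import tvm

n = 64
A = tvm.placeholder((n, n), name='A')
x = tvm.placeholder((n,), name='x')
k = tvm.reduce_axis((0, n), name='k')
y = tvm.compute((n,), lambda i: tvm.sum(A[i, k] * x[k], axis=k), name='y')

s = tvm.create_schedule(y.op)
xx = s.cache_read(x, "local", [y])  # per-thread copy of x

bx, ti = s[y].split(y.op.axis[0], factor=8)
vx, tx = s[y].split(ti, factor=4)
s[y].bind(bx, tvm.thread_axis("blockIdx.x"))
s[y].bind(vx, tvm.thread_axis("vthread", name="vx"))
s[y].bind(tx, tvm.thread_axis("threadIdx.x"))
s[xx].compute_at(s[y], tx)

# The loads filling xx are invariant in vx, so the lowering does not
# need to replicate them per virtual thread.
print(tvm.lower(s, [A, x, y], simple_mode=True))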


#3

Tianqi, thanks for your explanation. Would loop splitting with interchange not achieve the effect, as shown by ‘schedule 1’ in my sample code above?


#4

I was lurking in this thread because I also had a question about vthreads (in the special case of the VTA).
In the VTA tech report there is a whole subsection about latency hiding using virtual threads.

Also, checking the source code of the inject_virtual_thread pass, it requires that the axis be labelled as “vthread”.

In the VTA tutorial, however, the label is “cthread”.
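
A paraphrased sketch of that binding, with hypothetical stage and axis names (s, res, oc_out), not the tutorial's exact code:

v_threads = 2
_, v_t = s[res].split(oc_out, factor=v_threads)
# The split axis is bound to a thread_axis labelled "cthread" rather than "vthread".
s[res].bind(v_t, tvm.thread_axis("cthread"))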

  1. Are all “cthreads” also virtual threads?
  2. Why is it in this case better to define this axis as “cthread” and not “vthread”?

Also, I have a question concerning this part of the VTA environment:

  1. Can I interpret this as: “the coprocessor sync function for the VTA architecture is an external call to the VTASynchronize routine”?
  2. Why isn’t the VTASynchronize routine (and likewise VTADepPush and VTADepPop) inlined in the output of tvm.lower()? For example, in my lowered output:
// attr [res_conv] storage_scope = "local.acc_buffer"
// attr [data_buf] storage_scope = "local.inp_buffer"
// attr [kernel_buf] storage_scope = "local.wgt_buffer"
produce res {
/* There are a lot of lines here */
}
vta.coproc_sync()