[Codegen][CUDA] Bug in global barrier

The global barrier on CUDA is implemented using shared memory and an atomic add.

In this line, each thread adds num_blocks to vid_global_barrier_expect_, which is a variable in shared memory. The intent is that vid_global_barrier_expect_ == num_blocks * num_threads_per_block after all threads finish this line, but the increment is a plain read-modify-write, so concurrent threads can lose each other's updates and the final value is unpredictable. Since this is a shared-memory variable, I think this should use atomicAdd or atomicAdd_block?
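For illustration, here is a minimal standalone sketch of the pattern (the kernel and variable names are hypothetical, and num_blocks is passed as a parameter here, whereas the generated code bakes it in as a constant):

__global__ void barrier_expect_race(unsigned num_blocks) {
  __shared__ unsigned expect;
  if (threadIdx.x == 0) {
    expect = 0;
  }
  __syncthreads();
  // Racy: every thread performs a plain read-modify-write on the same
  // shared-memory word, so increments can be lost and the final value
  // of expect is unpredictable.
  expect += num_blocks;
  // Race-free alternative suggested above; atomicAdd_block requires
  // compute capability 6.0+, while plain atomicAdd also works:
  // atomicAdd_block(&expect, num_blocks);
  __syncthreads();
}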

@tqchen

Sorry, I was a bit busy recently. It would be great if you could also tag a few others who are familiar with CUDA.

Can you give an example test case that generates code through this path? I am not familiar with this as most kernels I have seen do not use this synchronization mechanism.

Here is a test case:

import tvm

def ir(in_buf, out_buf):
    ib = tvm.ir_builder.create()
    p_in = ib.buffer_ptr(in_buf)
    p_out = ib.buffer_ptr(out_buf)
    nthreads = 256
    nblocks = in_buf.shape[0] // nthreads
    bx = tvm.thread_axis("blockIdx.x")
    tx = tvm.thread_axis("threadIdx.x")
    ib.scope_attr(tx, "thread_extent", nthreads)
    ib.scope_attr(bx, "thread_extent", nblocks)
    # Initialize the global barrier state.
    ib.emit(tvm.make.Call(None, 'tvm_global_barrier_kinit', None,
                          tvm.expr.Call.Intrinsic, None, 0))
    i = bx * nthreads + tx
    p_out[i] = p_in[i]
    # Global barrier across all nblocks blocks.
    ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                          tvm.convert(['global', True, nblocks]),
                          tvm.expr.Call.Intrinsic, None, 0))
    return ib.get()

a = tvm.placeholder((1024,))
b = tvm.extern([a.shape], [a], lambda ins, outs: ir(ins[0], outs[0]))
s = tvm.create_schedule(b.op)
print(tvm.build(s, [a, b], 'cuda').imported_modules[0].get_source())

which produces

extern "C" __device__ unsigned __tvm_global_barrier_state;
extern "C" __global__ void default_function_kernel0( float* __restrict__ extern1,  float* __restrict__ placeholder) {
  __shared__ unsigned __barrier_expect;
  if (threadIdx.x == 0) {
    __barrier_expect = 0;
  }
  extern1[(((int)threadIdx.x) + (((int)blockIdx.x) * 256))] = placeholder[(((int)threadIdx.x) + (((int)blockIdx.x) * 256))];
  __threadfence_system();
  if ((bool)1) {
    atomicAdd(&__tvm_global_barrier_state, 1);
    volatile unsigned* pf = &__tvm_global_barrier_state;
    __barrier_expect += 4;
    while (pf[0] < __barrier_expect);
  }
  __syncthreads();
}
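For comparison, here is a hypothetical hand-written version of the barrier with the race removed (a sketch only, not a proposed patch): the expected count lives in a per-thread local instead of shared memory, so no cross-thread update is needed at all, and only thread 0 of each block ticks the global counter. Note this still assumes all blocks are resident at once, which is the separate problem discussed below.

__device__ unsigned global_barrier_state;

__device__ void global_barrier(unsigned num_blocks, unsigned* expect) {
  __threadfence();  // make this block's prior writes visible to other blocks
  __syncthreads();  // wait until every thread in this block has arrived
  if (threadIdx.x == 0) {
    atomicAdd(&global_barrier_state, 1);  // one tick per block, not per thread
  }
  *expect += num_blocks;  // expect is a per-thread local, so no race
  volatile unsigned* pf = &global_barrier_state;
  while (pf[0] < *expect) {
    // spin until every block has ticked for this barrier round
  }
  __syncthreads();
}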

Does this issue come up in the op PR you are proposing?

Yes, I found some data races in argsort.

@vinx you are right; without the atomic, this doesn't make a correct global barrier.

As an aside, I am confused about the meaning of "global" in the global barrier here. Is it supposed to be a true global barrier, as in all threads in all blocks?

If the second argument of tvm_storage_sync is True, all threads in all blocks will be synchronized.

If I understand correctly, the implementation strategy is just to wait on a counter until all threads are accounted for. I thought that CUDA provides no guarantees that this style of implementation does not deadlock, as nothing prevents the execution of blocks from being serialized. But my knowledge of CUDA could just be very stale, so if someone has updated info that would be great.
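For reference, CUDA 9 did add a supported grid-wide barrier via cooperative groups; with a cooperative launch the runtime refuses grids that cannot be fully co-resident, instead of deadlocking at the barrier. A minimal sketch, unrelated to what this codegen path emits:

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__global__ void copy_then_sync(float* out, const float* in, int n) {
  cg::grid_group grid = cg::this_grid();
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = in[i];
  }
  grid.sync();  // true grid-wide barrier under a cooperative launch
  // everything after this point sees all writes made before the sync
}

// Host side: launch with cudaLaunchCooperativeKernel(...) instead of the
// <<<grid, block>>> syntax; it returns an error if the requested grid
// exceeds what can be co-resident on the device.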

You are right; I indeed observed deadlocks when I repeated it many times.


But then the original code is not doing what it is supposed to do (the global barrier hack), is it? Should we remove it?

The original one is wrong as well. The global barrier is only used in https://github.com/dmlc/tvm/blob/master/topi/recipe/rnn/lstm.py, and I'm not sure whether it works.