How to integrate a handcrafted CUDA device function into TVM

Hello TVM developers:

I am currently working on CUDA-related optimization and code generation on the TVM stack.

As we know, for GPU applications, memory access is the most common and important performance constraint.

TVM already provides a CUDA code generator and scheduler for us, which is great.

For some applications, the CUDA code generated by TVM is not optimal, especially in the data-transfer part. I would like to replace the memory-access-related part with a handcrafted CUDA device function. Handcrafted CUDA code gives more freedom: we can change the order of threads and make sure sequential accesses are adjacent, i.e.

shared_mem[thread_id + offset_local] = global_mem[thread_id + offset_global]

This handcrafted coalesced access pattern should be more efficient.
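To make the coalescing argument concrete, here is a small standalone Python model (not TVM or CUDA code; the segment size and warp size are the usual illustrative constants) that counts how many memory transactions one warp triggers under a coalesced versus a strided access pattern:

```python
# Model of a warp's global-memory access pattern (illustrative only).
# A warp of 32 threads issues one 4-byte load each; the hardware coalesces
# loads that fall in the same 128-byte segment into one transaction.

WARP_SIZE = 32
SEGMENT_BYTES = 128
FLOAT_BYTES = 4

def transactions(index_fn):
    """Count distinct 128-byte segments touched by one warp."""
    addrs = [index_fn(t) * FLOAT_BYTES for t in range(WARP_SIZE)]
    return len({a // SEGMENT_BYTES for a in addrs})

# Coalesced: consecutive threads read consecutive floats -> 1 transaction.
coalesced = transactions(lambda t: t)

# Strided: consecutive threads read every 32nd float -> 32 transactions.
strided = transactions(lambda t: t * 32)

print(coalesced, strided)  # 1 32
```

The strided pattern needs 32 times as many transactions for the same amount of useful data, which is exactly what the handcrafted access order is meant to avoid.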

For instance, suppose we have a CUDA kernel generated by TVM:

extern "C" __global__ void kernel(float* __restrict__ Input, float* __restrict__ Output) {
  __shared__ float shared_mem[A];

  ...
  // code generated by TVM:
  // load data into shared memory
  shared_mem[thread_x] = Input[y];

  // compute on the data in shared_mem, then store back to global memory
  Output[thread_x] = shared_mem[thread_x] * 123.0f;
}

And what I want to do is to replace the automatically generated part

shared_mem[thread_x] = Input[y];

with

handcrafted_read(shared_mem, Input);

where handcrafted_read is defined as

__device__ void handcrafted_read(float* shared_buffer, void* data_to_be_read)
{
    // handcrafted codes here 
}

Could you give me some tips/directions on how to implement this?

Thanks a lot!

@tqchen @jcf94 @FrozenGene


TVM provides intrinsics support that can help users generate low-level function calls. In my understanding, this should work for you.

However, it seems the intrinsic math tutorial may not give enough examples of how to create your own C functions.

See if @tqchen @merrymercy have more comments.

Hi, thanks for your tips! I just quickly checked the intrinsic math tutorial and "Use Tensorize to Leverage Hardware Intrinsics". It seems that the second one matches my purpose better.

In most cases, the generated CUDA code is not just a single line of reading; for-loops and index computations are also involved. For instance:

for i in range(M):
  for j in range(N):
    shared_mem[f(thread_id,i,j)] = Input[g(thread_id,i,j)]

where f() and g() are formulas (automatically generated by TVM's bound inference, I guess) that compute the real addresses from thread_id (and block_id), i, and j.
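To illustrate why the generated index order matters, here is a small Python model (not TVM output; the specific formulas for g are hypothetical) comparing two layouts of the same copy for a block of 32 threads each moving 4 elements:

```python
# Illustrative model of two index formulas g(thread_id, i) for a copy loop
# executed by 32 threads, each handling 4 elements. (Hypothetical formulas,
# not TVM-generated.)

NUM_THREADS = 32
ELEMS_PER_THREAD = 4

# Thread-major layout: each thread walks a private contiguous chunk.
# At every loop step, adjacent threads are 4 elements apart -> uncoalesced.
def g_thread_major(thread_id, i):
    return thread_id * ELEMS_PER_THREAD + i

# Iteration-major layout: at each step i, adjacent threads touch adjacent
# elements -> coalesced.
def g_iter_major(thread_id, i):
    return i * NUM_THREADS + thread_id

# At loop step i = 0, compare the indices issued across the thread block.
step0_thread_major = [g_thread_major(t, 0) for t in range(NUM_THREADS)]
step0_iter_major = [g_iter_major(t, 0) for t in range(NUM_THREADS)]

print(step0_thread_major[:4])  # [0, 4, 8, 12]  (stride 4 between threads)
print(step0_iter_major[:4])    # [0, 1, 2, 3]   (stride 1 between threads)
```

Both layouts copy the same 128 elements; only the assignment of elements to threads differs, which is exactly the freedom a handcrafted function would exploit.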

In the tensorize tutorial, I found that it is possible to replace a for-loop with handcrafted code.

However, it seems that if I want to replace this part with a handcrafted function, the index information (f() and g()) is no longer valid, because I want to change the order of accesses and the memory access pattern to achieve memory coalescing.

Is it possible to pass the loop variables (i and j in this case) and do the index computation inside the handcrafted function? Would this approach and modification still result in a compatible schedule?

Thanks for your answer!