[cuda][Relay] Illegal memory access caused by OpFusion pass

I’m encountering a very annoying error when building a convolutional model. With small input sizes, I can build with opt_level=3 and everything works great. However, for larger inputs, opt_level >= 1 causes the error:

CUDA: Check failed: e == cudaSuccess || e == cudaErrorCudartUnloading: an illegal memory access was encountered

Although this may seem like my GPU running out of memory, everything works fine with opt_level=0. Since the only extra pass at opt_level=1 is operator fusion, I suspect operations are being fused in a way that triggers the illegal memory access. Any thoughts on how to work around this?
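
For reference, here is roughly how I’m building (mod, params, and the input are placeholders for my actual setup):

import tvm
from tvm import relay

# mod and params come from the frontend importer (placeholder names)
target = "cuda"
with relay.build_config(opt_level=0):  # opt_level=0 works; opt_level >= 1 triggers the error
    graph, lib, params = relay.build(mod, target=target, params=params)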

Can you try to look into the specific kernel? We could print out the kernel names in the graph runtime, see which one causes the crash, and then inspect the IR after fusion.
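
For example, a rough sketch using the debug graph runtime (assuming graph/lib/params from relay.build; since it executes node by node, the node running when the error fires points at the offending kernel):

import tvm
from tvm.contrib.debugger import debug_runtime

ctx = tvm.gpu(0)
m = debug_runtime.create(graph, lib, ctx)   # graph, lib from relay.build
m.set_input(**params)
m.set_input("data", input_array)            # "data" / input_array are placeholders
m.run()                                     # executes per node; the node active at the failure is the suspect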

Using cuda-gdb and memcheck, I can see that the error occurs in fused_reshape_transpose_reshape_nn_leaky_relu_nn_pad_kernel0. Is there a better way to dig into what specifically causes the issue?

Hmm, so it is not even a conv2d kernel. We could use lib.imported_modules[0].get_source() to print out the generated CUDA source code and see what exactly is in there.
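
Something like this (a sketch; lib is the module returned by relay.build on the CUDA target):

from tvm import relay

# mod, params as before (from the frontend importer)
graph, lib, params = relay.build(mod, target="cuda", params=params)

# the device code lives in the imported CUDA module; dump its generated source
cuda_src = lib.imported_modules[0].get_source()
print(cuda_src)

# or just check for the failing kernel
print("fused_reshape_transpose_reshape_nn_leaky_relu_nn_pad" in cuda_src)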

Then we can isolate it. Some possibilities:

  • could be due to the input size being too big, making the index calculation go OOB (of int32); we are in the process of upgrading the infra to use 64-bit indices if that is the case, but it might take a bit of time
  • a simplification error
  • the schedule template not inlining certain things properly

It looks pretty hard to parse, but here is the source for the function, just in case you can spot something obvious.

extern "C" __global__ void fused_reshape_transpose_reshape_nn_leaky_relu_nn_pad_kernel0( float* __restrict__ T_pad,  float* __restrict__ placeholder) {
  for (int ax0_ax1_fused_ax2_fused_ax3_fused_outer = 0; ax0_ax1_fused_ax2_fused_ax3_fused_outer < 102; ++ax0_ax1_fused_ax2_fused_ax3_fused_outer) {
    if (((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) < (13307928 - ((int)threadIdx.x))) {
      T_pad[(((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x))] = (((((1922 <= ((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 2217988)) && (((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 2217988) < 2216066)) && (1 <= ((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 1922))) && (((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 1922) < 1921)) ? ((0.000000e+00f < placeholder[((((((((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 2217988) / 1922) + ((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 1922) - 1921) / 1920)) % 4) * 3317760) + (((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 1922) - 1) % 4) * 829440)) + ((((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 13307928) / 2217988) + (((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 2217988) / 1922) + ((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 1922) - 1921) / 1920)) / 1152)) % 6) * 138240)) + (((((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 2217988) / 1922) + ((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 1922) - 1921) / 1920)) % 1152) / 4) * 480)) + (((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 1922) - 1921) % 1920) / 4))]) ? 
placeholder[((((((((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 2217988) / 1922) + ((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 1922) - 1921) / 1920)) % 4) * 3317760) + (((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 1922) - 1) % 4) * 829440)) + ((((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 13307928) / 2217988) + (((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 2217988) / 1922) + ((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 1922) - 1921) / 1920)) / 1152)) % 6) * 138240)) + (((((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 2217988) / 1922) + ((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 1922) - 1921) / 1920)) % 1152) / 4) * 480)) + (((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 1922) - 1921) % 1920) / 4))] : (placeholder[((((((((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 2217988) / 1922) + ((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 1922) - 1921) / 1920)) % 4) * 3317760) + (((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 1922) - 1) % 4) * 829440)) + ((((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 13307928) / 2217988) + (((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 2217988) / 1922) + ((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 1922) - 1921) / 1920)) / 1152)) % 6) * 138240)) + (((((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 2217988) / 1922) + ((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 1922) - 1921) / 1920)) % 1152) / 4) * 480)) + (((((((ax0_ax1_fused_ax2_fused_ax3_fused_outer * 131072) + (((int)blockIdx.x) * 512)) + ((int)threadIdx.x)) % 1922) - 1921) % 1920) / 4))] * 2.000000e-01f)) : 0.000000e+00f);
    } 
  }
}

Also potentially relevant: this fused op is the result of fusing DepthToSpace (which is implemented as reshape->transpose->reshape), LeakyRelu, and the padding for the convolution that follows. It seems like fusing into DepthToSpace can cause indices to get out of control due to the large transpose. Is there a good way to annotate the entire DepthToSpace op as unfusable?
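
For context, the decomposition I’m using looks roughly like this (NCHW layout, block size b; the sizes below are illustrative only, not the ones from this model):

from tvm import relay

n, c, h, w, b = 1, 64, 128, 128, 4   # illustrative sizes only
data = relay.var("data", shape=(n, c, h, w), dtype="float32")

# depth-to-space: reshape -> transpose -> reshape
x = relay.reshape(data, (n, b, b, c // (b * b), h, w))
x = relay.transpose(x, (0, 3, 4, 1, 5, 2))
out = relay.reshape(x, (n, c // (b * b), h * b, w * b))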

I think you’re definitely right about this being an int32 index overflow issue. Although I’d like to get this model working with large images eventually, it’s not urgent. Do you know approximately how long the int64 indexing infra upgrade will take?

Marking the op as opaque might help.
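
Something along these lines (a sketch; register_pattern overrides the fusion pattern for an op, and the level is an assumption chosen to be above the default registration so the override takes effect):

from tvm.relay.op import register_pattern, OpPattern

# treat reshape as a fusion barrier so nothing gets fused across it
register_pattern("reshape", OpPattern.OPAQUE, level=11)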

Since DepthToSpace isn’t a real Relay operator and is instead only a composition of transpose and reshape, marking it opaque is a little tricky. It should probably be implemented as a full op down the line (maybe under relay.image). However, marking reshape as Opaque has resolved the issue on my end for now. Interestingly, marking transpose (and not reshape) as opaque stopped it from being fused, but the resulting fused reshape_leaky_pad kernel still triggered the same error. Thanks for the help @tqchen!

Perhaps we could manually insert the stop_fusion annotation hint, which I believe is implemented. We will likely tackle the int32 indexing issue around the end of the summer.
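
A sketch of the annotation route (sizes and the surrounding ops are illustrative; stop_fusion wraps an expression so the fuser will not fuse anything past it):

from tvm import relay

n, c, h, w, b = 1, 64, 128, 128, 4   # illustrative sizes only
data = relay.var("data", shape=(n, c, h, w), dtype="float32")

# the depth-to-space block from above
x = relay.reshape(data, (n, b, b, c // (b * b), h, w))
x = relay.transpose(x, (0, 3, 4, 1, 5, 2))
d2s = relay.reshape(x, (n, c // (b * b), h * b, w * b))

# barrier: the ops below will not be fused into the block above
d2s = relay.annotation.stop_fusion(d2s)

out = relay.nn.leaky_relu(d2s, alpha=0.2)
out = relay.nn.pad(out, pad_width=((0, 0), (0, 0), (1, 1), (1, 1)))
func = relay.Function([data], out)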