Error on Jetson TX2 with resnet50_v2: CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES

I compiled resnet50_v2 using the latest TVM:

import tvm
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
from mxnet.gluon.model_zoo.vision import get_model

model_name = 'resnet50_v2'
block = get_model(model_name, pretrained=True)

target = tvm.target.cuda()
target_host = 'llvm -target=aarch64-linux-gnu'
set_cuda_target_arch('sm_62')

# ... Relay conversion and relay.build(...) produce `lib` (omitted) ...
lib.export_library(path_lib, cc="aarch64-linux-gnu-g++")
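
For completeness, the part of the flow elided above (the Relay conversion and build that produce `lib`) looks roughly like this; the input shape and opt_level are assumptions, not from the snippet above:

from tvm import relay

# Convert the Gluon block to Relay (1x3x224x224 is the usual ImageNet input; an assumption here)
shape_dict = {'data': (1, 3, 224, 224)}
mod, params = relay.frontend.from_mxnet(block, shape_dict)

# Cross-compile: CUDA kernels for the TX2 GPU (sm_62), aarch64 host code
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, target=target,
                                     target_host=target_host, params=params)

# `lib` is then exported with the cross-compiler as shown above.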

But when I tried to run it on the Jetson TX2, I got CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES:

Relay-compiled version:
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (3) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(TVMFuncCall+0x70) [0x7fa0d926a8]
  [bt] (2) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(+0x859b8) [0x7fa0e039b8]
  [bt] (1) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(+0x854cc) [0x7fa0e034cc]
  [bt] (0) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(+0x10064) [0x7fa0d8e064]
  File "/home/nvidia/tvm/src/runtime/cuda/cuda_module.cc", line 215
  File "/home/nvidia/tvm/src/runtime/module_util.cc", line 73
TVMError: Check failed: ret == 0 (-1 vs. 0) : CUDALaunch Error: CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES
 grid=(2,14,4),  block=(28,1,16)
// func_name=fused_nn_conv2d_add_3_kernel0
// CUDA Source
// -----------
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-24817639
// Cuda compilation tools, release 10.0, V10.0.130
// Based on LLVM 3.4svn
//

.version 6.3
.target sm_62
.address_size 64

        // .globl       fused_nn_dense_add_kernel0
// _ZZ26fused_nn_dense_add_kernel0E8red_buf0 has been demoted
// _ZZ26fused_nn_dense_add_kernel0E7compute has been demoted
// _ZZ48fused_nn_conv2d_add_multiply_add_nn_relu_kernel0E15pad_temp_shared has been demoted
// _ZZ48fused_nn_conv2d_add_multiply_add_nn_relu_kernel0E18placeholder_shared has been demoted
// _ZZ37fused_nn_conv2d_add_nn_relu_5_kernel0E15pad_temp_shared has been demoted
NNVM-compiled version:
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (3) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(TVMFuncCall+0x70) [0x7f8b45c6a8]
  [bt] (2) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(+0x859b8) [0x7f8b4cd9b8]
  [bt] (1) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(+0x854cc) [0x7f8b4cd4cc]
  [bt] (0) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(+0x10064) [0x7f8b458064]
  File "/home/nvidia/tvm/src/runtime/cuda/cuda_module.cc", line 215
  File "/home/nvidia/tvm/src/runtime/module_util.cc", line 73
TVMError: Check failed: ret == 0 (-1 vs. 0) : CUDALaunch Error: CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES
 grid=(2,14,4),  block=(28,1,16)
// func_name=fuse_conv2d_kernel0
// CUDA Source
// -----------
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-24817639
// Cuda compilation tools, release 10.0, V10.0.130
// Based on LLVM 3.4svn
//

.version 6.3
.target sm_62
.address_size 64

        // .globl       fuse_broadcast_add_kernel0
// _ZZ38fuse_conv2d_broadcast_add_relu_kernel0E15pad_temp_shared has been demoted
// _ZZ38fuse_conv2d_broadcast_add_relu_kernel0E13input1_shared has been demoted

I'm using CUDA 10.0.

Try

target = tvm.target.cuda("-model=tx2")

because the TX2 has a smaller amount of shared memory and fewer threads.
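
A minimal sketch of selecting that preset, assuming this TVM version takes the device model as the first argument of tvm.target.cuda():

import tvm

# Assumption: passing the model name makes the target carry "-model=tx2",
# which lets TVM pick TX2-specific pre-tuned/fallback configs.
target = tvm.target.cuda("tx2")
print(target)   # expected to print something like: cuda -model=tx2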


I tried target = tvm.target.cuda("tx2") and got the same error, but in a different function:

  [bt] (3) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(TVMFuncCall+0x70) [0x7f993ff6a8]
  [bt] (2) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(+0x859b8) [0x7f994709b8]
  [bt] (1) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(+0x854cc) [0x7f994704cc]
  [bt] (0) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(+0x10064) [0x7f993fb064]
  File "/home/nvidia/tvm/src/runtime/cuda/cuda_module.cc", line 215
  File "/home/nvidia/tvm/src/runtime/module_util.cc", line 73
TVMError: Check failed: ret == 0 (-1 vs. 0) : CUDALaunch Error: CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES
 grid=(1,1,16),  block=(14,2,8)
// func_name=fuse_conv2d_elemwise_add_2_kernel0
// CUDA Source
// -----------
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-24817639
// Cuda compilation tools, release 10.0, V10.0.130
// Based on LLVM 3.4svn
//

.version 6.3
.target sm_62
.address_size 64

Interestingly, TVM can compile MXNet ssd_resnet50_512 and it runs fine on the Jetson TX2 GPU, but regular resnet50 at 224x224 fails with CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES.

$ ./deviceQuery 
./deviceQuery Starting...

 CUDA Device Query (Runtime API) version (CUDART static linking)

Detected 1 CUDA Capable device(s)

Device 0: "NVIDIA Tegra X2"
  CUDA Driver Version / Runtime Version          10.0 / 10.0
  CUDA Capability Major/Minor version number:    6.2
  Total amount of global memory:                 7852 MBytes (8233566208 bytes)
  ( 2) Multiprocessors, (128) CUDA Cores/MP:     256 CUDA Cores
  GPU Max Clock rate:                            1020 MHz (1.02 GHz)
  Memory Clock rate:                             1300 Mhz
  Memory Bus Width:                              128-bit
  L2 Cache Size:                                 524288 bytes
  Maximum Texture Dimension Size (x,y,z)         1D=(131072), 2D=(131072, 65536), 3D=(16384, 16384, 16384)
  Maximum Layered 1D Texture Size, (num) layers  1D=(32768), 2048 layers
  Maximum Layered 2D Texture Size, (num) layers  2D=(32768, 32768), 2048 layers
  Total amount of constant memory:               65536 bytes
  Total amount of shared memory per block:       49152 bytes
  Total number of registers available per block: 32768
  Warp size:                                     32
  Maximum number of threads per multiprocessor:  2048
  Maximum number of threads per block:           1024
  Max dimension size of a thread block (x,y,z): (1024, 1024, 64)
  Max dimension size of a grid size    (x,y,z): (2147483647, 65535, 65535)
  Maximum memory pitch:                          2147483647 bytes
  Texture alignment:                             512 bytes
  Concurrent copy and kernel execution:          Yes with 1 copy engine(s)
  Run time limit on kernels:                     No
  Integrated GPU sharing Host Memory:            Yes
  Support host page-locked memory mapping:       Yes
  Alignment requirement for Surfaces:            Yes
  Device has ECC support:                        Disabled
  Device supports Unified Addressing (UVA):      Yes
  Device supports Compute Preemption:            Yes
  Supports Cooperative Kernel Launch:            Yes
  Supports MultiDevice Co-op Kernel Launch:      Yes
  Device PCI Domain ID / Bus ID / location ID:   0 / 0 / 0
  Compute Mode:
     < Default (multiple host threads can use ::cudaSetDevice() with device simultaneously) >

deviceQuery, CUDA Driver = CUDART, CUDA Driver Version = 10.0, CUDA Runtime Version = 10.0, NumDevs = 1
Result = PASS

@merrymercy I tried to run the compiled resnet50_v2 on an AWS EC2 P2 (Tesla K80, sm_37) and it works!
BTW, the P2 has the same amount of shared memory per block as the Jetson TX2: 49152 bytes.

Device 0: "Tesla K80"
  CUDA Driver Version / Runtime Version          10.1 / 10.1
  CUDA Capability Major/Minor version number:    3.7
  Total amount of global memory:                 11441 MBytes (11996954624 bytes)
  (13) Multiprocessors, (192) CUDA Cores/MP:     2496 CUDA Cores
  GPU Max Clock rate:                            824 MHz (0.82 GHz)
  Memory Clock rate:                             2505 Mhz
  Memory Bus Width:                              384-bit
  L2 Cache Size:                                 1572864 bytes
  Maximum Texture Dimension Size (x,y,z)         1D=(65536), 2D=(65536, 65536), 3D=(4096, 4096, 4096)
  Maximum Layered 1D Texture Size, (num) layers  1D=(16384), 2048 layers
  Maximum Layered 2D Texture Size, (num) layers  2D=(16384, 16384), 2048 layers
  Total amount of constant memory:               65536 bytes
  Total amount of shared memory per block:       49152 bytes
  Total number of registers available per block: 65536
  Warp size:                                     32
  Maximum number of threads per multiprocessor:  2048
  Maximum number of threads per block:           1024
  Max dimension size of a thread block (x,y,z): (1024, 1024, 64)
  Max dimension size of a grid size    (x,y,z): (2147483647, 65535, 65535)
  Maximum memory pitch:                          2147483647 bytes
  Texture alignment:                             512 bytes
  Concurrent copy and kernel execution:          Yes with 2 copy engine(s)
  Run time limit on kernels:                     No
  Integrated GPU sharing Host Memory:            No
  Support host page-locked memory mapping:       Yes
  Alignment requirement for Surfaces:            Yes
  Device has ECC support:                        Enabled
  Device supports Unified Addressing (UVA):      Yes
  Device supports Compute Preemption:            No
  Supports Cooperative Kernel Launch:            No
  Supports MultiDevice Co-op Kernel Launch:      No
  Device PCI Domain ID / Bus ID / location ID:   0 / 0 / 30
  Compute Mode:
     < Default (multiple host threads can use ::cudaSetDevice() with device simultaneously) >

deviceQuery, CUDA Driver = CUDART, CUDA Driver Version = 10.1, CUDA Runtime Version = 10.1, NumDevs = 1
Result = PASS

Jetson TX2 vs. EC2 P2 comparison (parameters that are worse on the TX2):

                                        TX2        EC2 P2
Total global memory                 7852 MB      11441 MB
Multiprocessors                           2            13
CUDA Cores/MP                           128           192
CUDA Cores (total)                      256          2496
L2 cache size                      524288 B     1572864 B
Registers available per block         32768         65536
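
One possible reading of that table: CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES usually means the launch exceeded a per-block limit such as registers, and the failing launch with block=(28,1,16) leaves a much smaller register budget per thread on the TX2 than on the K80 (a back-of-the-envelope check, assuming register pressure is the limiting resource):

# block=(28,1,16) from the failing launch -> 448 threads per block
threads_per_block = 28 * 1 * 16

regs_per_block_tx2 = 32768             # from the TX2 deviceQuery above
regs_per_block_k80 = 65536             # from the K80 deviceQuery above

print(regs_per_block_tx2 // threads_per_block)   # 73  registers/thread available on the TX2
print(regs_per_block_k80 // threads_per_block)   # 146 registers/thread available on the K80

# If the generated kernel needs more than ~73 registers per thread, the launch
# fails on the TX2 but still fits on the K80.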

Do you get any warnings about fallback configs being used? Otherwise, the pre-tuned kernels should not produce errors unless something has changed at the code-generation level.

I got 3 warnings:

Cannot find config for target=cuda -model=unknown, workload=('conv2d', (1, 128, 56, 56, 'float32'), (128, 128, 3, 3, 'float32'), (2, 2), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=cuda -model=unknown, workload=('conv2d', (1, 256, 28, 28, 'float32'), (256, 256, 3, 3, 'float32'), (2, 2), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=cuda -model=unknown, workload=('conv2d', (1, 512, 14, 14, 'float32'), (512, 512, 3, 3, 'float32'), (2, 2), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.

Can you try autotuning the network to resolve the missing configs?
It is likely that one of the fallback configs is causing the out-of-resources error.
https://docs.tvm.ai/tutorials/autotvm/tune_relay_cuda.html
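
A condensed sketch of that tutorial's flow, adapted for this thread; the device key 'tx2', tracker host/port, trial count, and the mod/params/target variables from the build sketch above are assumptions:

from tvm import autotvm, relay

log_file = 'tx2.resnet50_v2.log'   # hypothetical log name

# Extract the conv2d tasks that currently fall back to default configs
tasks = autotvm.task.extract_from_program(mod['main'], target=target,
                                          target_host=target_host,
                                          params=params,
                                          ops=(relay.op.nn.conv2d,))

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(timeout=10),
    # Measure on the TX2 over RPC; 'tx2' must match the key used when
    # registering the board with the tracker (assumption)
    runner=autotvm.RPCRunner('tx2', host='0.0.0.0', port=9190,
                             number=20, repeat=3, timeout=4))

for task in tasks:
    tuner = autotvm.tuner.XGBTuner(task, loss_type='rank')
    tuner.tune(n_trial=min(2000, len(task.config_space)),
               measure_option=measure_option,
               callbacks=[autotvm.callback.progress_bar(2000),
                          autotvm.callback.log_to_file(log_file)])

# Re-build with the tuned configs so no fallback is used
with autotvm.apply_history_best(log_file):
    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(mod, target=target,
                                         target_host=target_host, params=params)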

Autotuning solved the issue. Thank you, eqy!

tx2 tx2.mxnet.resnet50_models.float32.log resnet50_v2 float32:
Mean inference time (std dev): 32.56 ms (0.32 ms)

tx2 tx2.mxnet.resnet50_models.float32.log resnet50_v1 float32:
Mean inference time (std dev): 29.62 ms (0.67 ms)

Interestingly, it was not just a matter of missing lines in the original cuda_v0.04.log file.
I tried adding the 4 missing lines to the original cuda_v0.04.log file but still got the same error.
So I had to replace the corresponding lines in the original cuda_v0.04.log file with the lines generated during autotuning of the resnet50 models on the Jetson TX2.
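
For anyone reproducing this: if I recall correctly, the pre-tuned cuda_v0.04.log is the TOPHub package cached under ~/.tvm/tophub/ and applied automatically during relay.build, while a locally generated log can be applied explicitly so its entries take precedence:

from tvm import autotvm, relay

# Apply the locally tuned TX2 log (file name from the results above) so the
# tuned schedules are used instead of the fallback configs.
with autotvm.apply_history_best('tx2.mxnet.resnet50_models.float32.log'):
    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(mod, target=target,
                                         target_host=target_host, params=params)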

The following PR replaces 23 lines and adds 4 new lines in the cuda_v0.04.log file.

The resnet50 models compiled with the new cuda_v0.04.log file work fine on the Jetson TX2.

tx2 cuda_v0.04.log resnet50_v1 float32: Mean inference time (std dev): 29.61 ms (0.66 ms)

tx2 cuda_v0.04.log resnet50_v2 float32: Mean inference time (std dev): 32.89 ms (0.66 ms)

I had the same problem. After I changed set_cuda_target_arch('sm_62') to set_cuda_target_arch('sm_37'), like the AWS EC2 P2 you used, I could run the code with no problem.