Got CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES on Jetson TX2 with resnet50_v2


#1

I compiled resnet50_v2 using the latest TVM, cross-compiling for the TX2:

import tvm
from mxnet.gluon.model_zoo.vision import get_model

model_name = 'resnet50_v2'
block = get_model(model_name, pretrained=True)

target = tvm.target.cuda()
target_host = 'llvm -target=aarch64-linux-gnu'

from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
set_cuda_target_arch('sm_62')   # TX2 is compute capability 6.2

# ... (Relay conversion and build elided here; see the sketch below) ...
lib.export_library(path_lib, cc="aarch64-linux-gnu-g++")
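
For reference, the elided conversion and build steps look roughly like this - a minimal sketch assuming the TVM 0.6 Relay API (relay.frontend.from_mxnet / relay.build), a standard 1x3x224x224 input, and a placeholder output path:

import tvm
from tvm import relay
from mxnet.gluon.model_zoo.vision import get_model
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch

block = get_model('resnet50_v2', pretrained=True)
shape_dict = {'data': (1, 3, 224, 224)}   # assumed ImageNet input shape

# Convert the Gluon model to Relay
mod, params = relay.frontend.from_mxnet(block, shape_dict)

target = tvm.target.cuda()
target_host = 'llvm -target=aarch64-linux-gnu'
set_cuda_target_arch('sm_62')             # TX2 is compute capability 6.2

# Build and export a library that can be loaded on the TX2
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, target=target,
                                     target_host=target_host, params=params)

path_lib = 'resnet50_v2.so'               # placeholder path
lib.export_library(path_lib, cc="aarch64-linux-gnu-g++")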

But when I tried to run it on the Jetson TX2, I got CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES:

Relay-compiled version:
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (3) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(TVMFuncCall+0x70) [0x7fa0d926a8]
  [bt] (2) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(+0x859b8) [0x7fa0e039b8]
  [bt] (1) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(+0x854cc) [0x7fa0e034cc]
  [bt] (0) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(+0x10064) [0x7fa0d8e064]
  File "/home/nvidia/tvm/src/runtime/cuda/cuda_module.cc", line 215
  File "/home/nvidia/tvm/src/runtime/module_util.cc", line 73
TVMError: Check failed: ret == 0 (-1 vs. 0) : CUDALaunch Error: CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES
 grid=(2,14,4),  block=(28,1,16)
// func_name=fused_nn_conv2d_add_3_kernel0
// CUDA Source
// -----------
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-24817639
// Cuda compilation tools, release 10.0, V10.0.130
// Based on LLVM 3.4svn
//

.version 6.3
.target sm_62
.address_size 64

        // .globl       fused_nn_dense_add_kernel0
// _ZZ26fused_nn_dense_add_kernel0E8red_buf0 has been demoted
// _ZZ26fused_nn_dense_add_kernel0E7compute has been demoted
// _ZZ48fused_nn_conv2d_add_multiply_add_nn_relu_kernel0E15pad_temp_shared has been demoted
// _ZZ48fused_nn_conv2d_add_multiply_add_nn_relu_kernel0E18placeholder_shared has been demoted
// _ZZ37fused_nn_conv2d_add_nn_relu_5_kernel0E15pad_temp_shared has been demoted
NNVM-compiled version:
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (3) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(TVMFuncCall+0x70) [0x7f8b45c6a8]
  [bt] (2) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(+0x859b8) [0x7f8b4cd9b8]
  [bt] (1) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(+0x854cc) [0x7f8b4cd4cc]
  [bt] (0) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(+0x10064) [0x7f8b458064]
  File "/home/nvidia/tvm/src/runtime/cuda/cuda_module.cc", line 215
  File "/home/nvidia/tvm/src/runtime/module_util.cc", line 73
TVMError: Check failed: ret == 0 (-1 vs. 0) : CUDALaunch Error: CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES
 grid=(2,14,4),  block=(28,1,16)
// func_name=fuse_conv2d_kernel0
// CUDA Source
// -----------
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-24817639
// Cuda compilation tools, release 10.0, V10.0.130
// Based on LLVM 3.4svn
//

.version 6.3
.target sm_62
.address_size 64

        // .globl       fuse_broadcast_add_kernel0
// _ZZ38fuse_conv2d_broadcast_add_relu_kernel0E15pad_temp_shared has been demoted
// _ZZ38fuse_conv2d_broadcast_add_relu_kernel0E13input1_shared has been demoted

I'm using CUDA 10.0.


#2

Try
target = tvm.target.cuda("tx2")

because the TX2 has a smaller amount of shared memory and fewer threads available than the default CUDA target assumes.
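
A quick way to check what the model tag expands to - a tiny sketch; the printed string is what I would expect from this TVM version:

import tvm

target = tvm.target.cuda("tx2")
print(target)   # expected: "cuda -model=tx2"
# The model tag lets autotvm look up TX2-specific pretuned/fallback schedule configs where available.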


#3

I tried target = tvm.target.cuda("tx2") and got the same error, just in a different function:

  [bt] (3) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(TVMFuncCall+0x70) [0x7f993ff6a8]
  [bt] (2) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(+0x859b8) [0x7f994709b8]
  [bt] (1) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(+0x854cc) [0x7f994704cc]
  [bt] (0) /usr/local/lib/python3.6/dist-packages/tvm-0.6.dev0-py3.6-linux-aarch64.egg/tvm/libtvm_runtime.so(+0x10064) [0x7f993fb064]
  File "/home/nvidia/tvm/src/runtime/cuda/cuda_module.cc", line 215
  File "/home/nvidia/tvm/src/runtime/module_util.cc", line 73
TVMError: Check failed: ret == 0 (-1 vs. 0) : CUDALaunch Error: CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES
 grid=(1,1,16),  block=(14,2,8)
// func_name=fuse_conv2d_elemwise_add_2_kernel0
// CUDA Source
// -----------
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-24817639
// Cuda compilation tools, release 10.0, V10.0.130
// Based on LLVM 3.4svn
//

.version 6.3
.target sm_62
.address_size 64

#4

What is interesting is that TVM can compile the MXNet ssd_resnet50_512 model and it runs fine on the Jetson TX2 GPU, but the regular resnet50 (224x224 input) fails with CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES.


#5
$ ./deviceQuery 
./deviceQuery Starting...

 CUDA Device Query (Runtime API) version (CUDART static linking)

Detected 1 CUDA Capable device(s)

Device 0: "NVIDIA Tegra X2"
  CUDA Driver Version / Runtime Version          10.0 / 10.0
  CUDA Capability Major/Minor version number:    6.2
  Total amount of global memory:                 7852 MBytes (8233566208 bytes)
  ( 2) Multiprocessors, (128) CUDA Cores/MP:     256 CUDA Cores
  GPU Max Clock rate:                            1020 MHz (1.02 GHz)
  Memory Clock rate:                             1300 Mhz
  Memory Bus Width:                              128-bit
  L2 Cache Size:                                 524288 bytes
  Maximum Texture Dimension Size (x,y,z)         1D=(131072), 2D=(131072, 65536), 3D=(16384, 16384, 16384)
  Maximum Layered 1D Texture Size, (num) layers  1D=(32768), 2048 layers
  Maximum Layered 2D Texture Size, (num) layers  2D=(32768, 32768), 2048 layers
  Total amount of constant memory:               65536 bytes
  Total amount of shared memory per block:       49152 bytes
  Total number of registers available per block: 32768
  Warp size:                                     32
  Maximum number of threads per multiprocessor:  2048
  Maximum number of threads per block:           1024
  Max dimension size of a thread block (x,y,z): (1024, 1024, 64)
  Max dimension size of a grid size    (x,y,z): (2147483647, 65535, 65535)
  Maximum memory pitch:                          2147483647 bytes
  Texture alignment:                             512 bytes
  Concurrent copy and kernel execution:          Yes with 1 copy engine(s)
  Run time limit on kernels:                     No
  Integrated GPU sharing Host Memory:            Yes
  Support host page-locked memory mapping:       Yes
  Alignment requirement for Surfaces:            Yes
  Device has ECC support:                        Disabled
  Device supports Unified Addressing (UVA):      Yes
  Device supports Compute Preemption:            Yes
  Supports Cooperative Kernel Launch:            Yes
  Supports MultiDevice Co-op Kernel Launch:      Yes
  Device PCI Domain ID / Bus ID / location ID:   0 / 0 / 0
  Compute Mode:
     < Default (multiple host threads can use ::cudaSetDevice() with device simultaneously) >

deviceQuery, CUDA Driver = CUDART, CUDA Driver Version = 10.0, CUDA Runtime Version = 10.0, NumDevs = 1
Result = PASS

#6

@merrymercy I tried to run the compiled resnet50_v2 on an AWS EC2 p2 instance (Tesla K80, sm_37) - it works!
BTW, the P2 has the same amount of shared memory per block as the Jetson TX2 - 49152 bytes.

Device 0: "Tesla K80"
  CUDA Driver Version / Runtime Version          10.1 / 10.1
  CUDA Capability Major/Minor version number:    3.7
  Total amount of global memory:                 11441 MBytes (11996954624 bytes)
  (13) Multiprocessors, (192) CUDA Cores/MP:     2496 CUDA Cores
  GPU Max Clock rate:                            824 MHz (0.82 GHz)
  Memory Clock rate:                             2505 Mhz
  Memory Bus Width:                              384-bit
  L2 Cache Size:                                 1572864 bytes
  Maximum Texture Dimension Size (x,y,z)         1D=(65536), 2D=(65536, 65536), 3D=(4096, 4096, 4096)
  Maximum Layered 1D Texture Size, (num) layers  1D=(16384), 2048 layers
  Maximum Layered 2D Texture Size, (num) layers  2D=(16384, 16384), 2048 layers
  Total amount of constant memory:               65536 bytes
  Total amount of shared memory per block:       49152 bytes
  Total number of registers available per block: 65536
  Warp size:                                     32
  Maximum number of threads per multiprocessor:  2048
  Maximum number of threads per block:           1024
  Max dimension size of a thread block (x,y,z): (1024, 1024, 64)
  Max dimension size of a grid size    (x,y,z): (2147483647, 65535, 65535)
  Maximum memory pitch:                          2147483647 bytes
  Texture alignment:                             512 bytes
  Concurrent copy and kernel execution:          Yes with 2 copy engine(s)
  Run time limit on kernels:                     No
  Integrated GPU sharing Host Memory:            No
  Support host page-locked memory mapping:       Yes
  Alignment requirement for Surfaces:            Yes
  Device has ECC support:                        Enabled
  Device supports Unified Addressing (UVA):      Yes
  Device supports Compute Preemption:            No
  Supports Cooperative Kernel Launch:            No
  Supports MultiDevice Co-op Kernel Launch:      No
  Device PCI Domain ID / Bus ID / location ID:   0 / 0 / 30
  Compute Mode:
     < Default (multiple host threads can use ::cudaSetDevice() with device simultaneously) >

deviceQuery, CUDA Driver = CUDART, CUDA Driver Version = 10.1, CUDA Runtime Version = 10.1, NumDevs = 1
Result = PASS

#7

Jetson TX2 and EC2 P2 comparison (parameters that are worse on the TX2):

                                        TX2       EC2 P2
Total global memory (MBytes)           7852        11441
Multiprocessors                           2           13
CUDA cores / MP                         128          192
CUDA cores (total)                      256         2496
L2 cache size (bytes)                524288      1572864
Registers available per block         32768        65536

#8

Do you get any warnings about fallback configs being used? Otherwise, the pretuned kernels should not hit errors unless something has changed at the code-generation level.


#9

I got 3 warnings:

Cannot find config for target=cuda -model=unknown, workload=('conv2d', (1, 128, 56, 56, 'float32'), (128, 128, 3, 3, 'float32'), (2, 2), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=cuda -model=unknown, workload=('conv2d', (1, 256, 28, 28, 'float32'), (256, 256, 3, 3, 'float32'), (2, 2), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=cuda -model=unknown, workload=('conv2d', (1, 512, 14, 14, 'float32'), (512, 512, 3, 3, 'float32'), (2, 2), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.

#10

Can you try autotuning the network to resolve the issue of missing configs?
It is likely that one of the fallback configs is causing the out of resources error.
https://docs.tvm.ai/tutorials/autotvm/tune_relay_cuda.html
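
Roughly, the tuning flow from that tutorial looks like this - a condensed sketch assuming the mod/params/target from the build script above and a running RPC tracker with a registered 'tx2' device (host, port, log name, and trial counts are placeholders):

from tvm import autotvm, relay
from tvm.autotvm.tuner import XGBTuner

log_file = 'resnet50_v2.tx2.log'   # placeholder tuning log

# Extract the conv2d workloads (including the three that triggered the fallback warnings)
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                          params=params,
                                          ops=(relay.op.nn.conv2d,))

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(timeout=10),
    runner=autotvm.RPCRunner('tx2', '0.0.0.0', 9190,   # tracker key/host/port (placeholders)
                             number=20, repeat=3,
                             timeout=4, min_repeat_ms=150))

# Tune each task on the TX2 and append the best configs to the log
for task in reversed(tasks):
    tuner = XGBTuner(task, loss_type='rank')
    tuner.tune(n_trial=min(1000, len(task.config_space)),
               early_stopping=400,
               measure_option=measure_option,
               callbacks=[autotvm.callback.log_to_file(log_file)])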


#11

Autotuning solved the issue. Thank you eqy!

tx2 tx2.mxnet.resnet50_models.float32.log resnet50_v2 float32:
Mean inference time (std dev): 32.56 ms (0.32 ms)

tx2 tx2.mxnet.resnet50_models.float32.log resnet50_v1 float32:
Mean inference time (std dev): 29.62 ms (0.67 ms)
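
For anyone reproducing these numbers: they can be measured with the graph runtime's time_evaluator - a rough sketch, assuming graph/lib/params come from relay.build with the tuned log applied and that it runs on the TX2 itself:

import numpy as np
import tvm
from tvm.contrib import graph_runtime

ctx = tvm.gpu(0)
module = graph_runtime.create(graph, lib, ctx)
module.set_input(**params)
module.set_input('data', np.random.uniform(size=(1, 3, 224, 224)).astype('float32'))

# Time the whole graph; the repeat count is a placeholder
ftimer = module.module.time_evaluator('run', ctx, number=1, repeat=100)
prof_res = np.array(ftimer().results) * 1000   # seconds -> milliseconds
print('Mean inference time (std dev): %.2f ms (%.2f ms)'
      % (np.mean(prof_res), np.std(prof_res)))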

#12

What is interesting is that it was not just a matter of missing lines in the original cuda_v0.04.log file.
I tried adding the 4 missing lines to the original cuda_v0.04.log file but still got the same error.
So I had to replace the corresponding lines in the original cuda_v0.04.log file with the lines generated while autotuning the resnet50 models on the Jetson TX2.

The following PR replaces 23 lines and adds 4 new lines in the cuda_v0.04.log file.

The resnet50 models compiled with the new cuda_v0.04.log file work fine on the Jetson TX2.

tx2 cuda_v0.04.log resnet50_v1 float32: Mean inference time (std dev): 29.61 ms (0.66 ms)

tx2 cuda_v0.04.log resnet50_v2 float32: Mean inference time (std dev): 32.89 ms (0.66 ms)
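
For completeness, compiling against a local tuning log (instead of the packaged tophub entries) looks roughly like this - a sketch assuming the tuned log and the same mod/params/target/target_host as in the first post:

from tvm import autotvm, relay

log_file = 'tx2.mxnet.resnet50_models.float32.log'   # the tuned log from the posts above

# Configs from this log take precedence over the built-in fallback/tophub entries
with autotvm.apply_history_best(log_file):
    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(mod, target=target,
                                         target_host=target_host, params=params)

lib.export_library('resnet50_v2_tuned.so', cc="aarch64-linux-gnu-g++")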