[AutoTVM] Difficult to tune conv2d_transpose for CUDA and bad performance

I’m trying to autotune an encoder-decoder model that has some deconvolutional layers, but auto-tuning seems impossible on my NVIDIA GTX 1050 Ti with CUDA.

I tried many different tuners for thousands of trials (10k), but no trial succeeds: GFLOPS stays at 0, and the log only shows error codes (mainly 1, 2, 4, 6, 7). Any idea why it is so difficult to find a combination of parameters that runs on my target?
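
To see which failure mode dominates, the `error_no` values can be tallied from the debug output. The following is a minimal sketch, not part of TVM: the helper name `tally_errors` is hypothetical, and the code-to-name table is written from memory of `tvm.autotvm.measure.MeasureErrorNo`, so check it against your own checkout before relying on it.

```python
import re
from collections import Counter

# Assumption: names recalled from tvm.autotvm.measure.MeasureErrorNo;
# verify against the TVM version you are running.
ERROR_NAMES = {
    0: "NO_ERROR",
    1: "INSTANTIATION_ERROR",  # template could not be instantiated with the config
    2: "COMPILE_HOST",         # error while building on the host
    3: "COMPILE_DEVICE",       # error while compiling on the device
    4: "RUNTIME_DEVICE",       # error while running on the device
    5: "WRONG_ANSWER",
    6: "BUILD_TIMEOUT",
    7: "RUN_TIMEOUT",
    8: "UNKNOWN_ERROR",
}

# Matches the error_no field of a MeasureResult(...) repr in the debug log.
_ERROR_NO = re.compile(r"error_no=(\d+), all_cost=")


def tally_errors(log_text):
    """Count how often each error code appears in AutoTVM debug output."""
    counts = Counter(int(code) for code in _ERROR_NO.findall(log_text))
    return {ERROR_NAMES.get(code, f"code {code}"): n for code, n in counts.items()}


sample = (
    "result: MeasureResult(costs=(TimeoutError(),), error_no=6, all_cost=10, timestamp=1.0)\n"
    "result: MeasureResult(costs=(TimeoutError(),), error_no=6, all_cost=10, timestamp=2.0)\n"
    "result: MeasureResult(costs=('...',), error_no=7, all_cost=10, timestamp=3.0)\n"
)
print(tally_errors(sample))
```

The regular expression only matches the `error_no` field inside a `MeasureResult`, so any separate label lines such as a bare `error_no=1` are not counted twice.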

The autotuning code is the standard one from the tutorials.

P.S. I can tune conv2d layers fine.

Example of errors:

error_no=2
DEBUG:autotvm:No: 4371	GFLOPS: 0.00/0.00	result: MeasureResult(costs=(TVMError('Traceback (most recent call last):
  [bt] (5) /tvm/incubator-tvm-gpu/build/libtvm.so(TVMFuncCall+0x61) [0x7f18621b3021]
  [bt] (4) /tvm/incubator-tvm-gpu/build/libtvm.so(+0x62de01) [0x7f1861a3be01]
  [bt] (3) /tvm/incubator-tvm-gpu/build/libtvm.so(tvm::te::Stage::fuse(tvm::Array<tvm::tir::IterVar, void> const&, tvm::tir::IterVar*)+0x12a) [0x7f1861a398ba]
  [bt] (2) /tvm/incubator-tvm-gpu/build/libtvm.so(tvm::te::Stage::fuse(tvm::tir::IterVar, tvm::tir::IterVar, tvm::tir::IterVar*)+0x405) [0x7f1861a34585]
  [bt] (1) /tvm/incubator-tvm-gpu/build/libtvm.so(tvm::te::FindLeafVar(tvm::ArrayNode*, tvm::ArrayNode*, tvm::tir::IterVar const&)+0xe6) [0x7f1861a306d6]
  [bt] (0) /tvm/incubator-tvm-gpu/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x32) [0x7f1861825ae2]
  File "/tvm/incubator-tvm-gpu/src/te/schedule/schedule_lang.cc", line 51
TVMError: Operate on iter var iter_var(w.inner.inner.outer, )that has already been split',),), error_no=2, all_cost=0.019483089447021484, timestamp=1589436232.3741207)	[('tile_n', [-1, 1, 1, 1]), ('tile_f', [-1, 4, 4, 8]), ('tile_y', [-1, 20, 1, 1]), ('tile_x', [-1, 2, 2, 3]), ('tile_rc', [-1, 4, 2]), ('auto_unroll_max_step', 512), ('unroll_explicit', 0), ('fuse_yx', 1)],None,23160694877
error_no=1
DEBUG:autotvm:No: 4370	GFLOPS: 0.00/0.00	result: MeasureResult(costs=(InstantiationError('Traceback (most recent call last):
  [bt] (4) /tvm/incubator-tvm-gpu/build/libtvm.so(TVMFuncCall+0x61) [0x7f18621b3021]
  [bt] (3) /tvm/incubator-tvm-gpu/build/libtvm.so(+0x48bbf1) [0x7f1861899bf1]
  [bt] (2) /tvm/incubator-tvm-gpu/build/libtvm.so(tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x3ca) [0x7f186189868a]
  [bt] (1) /tvm/incubator-tvm-gpu/build/libtvm.so(tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x50c) [0x7f1861b0856c]
  [bt] (0) /tvm/incubator-tvm-gpu/build/libtvm.so(+0xda16ab) [0x7f18621af6ab]
  File "/tvm/incubator-tvm-gpu/python/tvm/_ffi/_ctypes/packed_func.py", line 78, in cfun
    rv = local_pyfunc(*pyargs)
  File "/tvm/incubator-tvm-gpu/python/tvm/autotvm/measure/measure_methods.py", line 622, in verify_pass
    raise InstantiationError("Skipped because of invalid gpu kernel")
tvm.autotvm.task.space.InstantiationError: Skipped because of invalid gpu kernel',),), error_no=1, all_cost=0.02739715576171875, timestamp=1589436232.3713)	[('tile_n', [-1, 1, 1, 1]), ('tile_f', [-1, 2, 4, 16]), ('tile_y', [-1, 4, 10, 4]), ('tile_x', [-1, 2, 8, 3]), ('tile_rc', [-1, 64, 4]), ('auto_unroll_max_step', 64), ('unroll_explicit', 0), ('fuse_yx', 0)],None,1660783880
error_no=4
DEBUG:autotvm:No: 4373	GFLOPS: 0.00/0.00	result: MeasureResult(costs=(RuntimeError('Traceback (most recent call last):
  [bt] (5) /tvm/incubator-tvm-gpu/build/libtvm.so(TVMFuncCall+0x61) [0x7f18621b3021]
  [bt] (4) /tvm/incubator-tvm-gpu/build/libtvm.so(+0xdf4b02) [0x7f1862202b02]
  [bt] (3) /tvm/incubator-tvm-gpu/build/libtvm.so(tvm::runtime::RPCWrappedFunc::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const+0x21f) [0x7f186220655f]
  [bt] (2) /tvm/incubator-tvm-gpu/build/libtvm.so(tvm::runtime::RPCClientSession::CallFunc(void*, TVMValue const*, int const*, int, std::function<void (tvm::runtime::TVMArgs)> const&)+0x57) [0x7f18621fe637]
  [bt] (1) /tvm/incubator-tvm-gpu/build/libtvm.so(tvm::runtime::RPCEndpoint::CallFunc(void*, TVMValue const*, int const*, int, std::function<void (tvm::runtime::TVMArgs)>)+0x36e) [0x7f18621f679e]
  [bt] (0) /tvm/incubator-tvm-gpu/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x32) [0x7f1861825ae2]
  File "/tvm/incubator-tvm-gpu/src/runtime/rpc/rpc_endpoint.cc", line 799
TVMError: Check failed: code == RPCCode::kReturn: code=1',),), error_no=4, all_cost=13.760782241821289, timestamp=1589436246.2543547)	[('tile_n', [-1, 1, 1, 1]), ('tile_f', [-1, 1, 2, 8]), ('tile_y', [-1, 80, 4, 1]), ('tile_x', [-1, 4, 4, 2]), ('tile_rc', [-1, 256, 1]), ('auto_unroll_max_step', 512), ('unroll_explicit', 0), ('fuse_yx', 0)],None,3779905915
error_no=6
DEBUG:autotvm:No: 9951	GFLOPS: 0.00/0.00	result: MeasureResult(costs=(TimeoutError(),), error_no=6, all_cost=10, timestamp=1589442262.925702)	[('tile_n', [-1, 1, 1, 1]), ('tile_f', [-1, 2, 32, 4]), ('tile_y', [-1, 64, 1, 1]), ('tile_x', [-1, 10, 6, 1]), ('tile_rc', [-1, 8, 1]), ('auto_unroll_max_step', 1500), ('unroll_explicit', 1), ('fuse_yx', 0)],None,16186660157
error_no=7
DEBUG:autotvm:No: 4596	GFLOPS: 0.00/0.00	result: MeasureResult(costs=('Traceback (most recent call last):
  [bt] (3) /tvm/incubator-tvm-gpu/build/libtvm.so(TVMFuncCall+0x61) [0x7fbf55233021]
  [bt] (2) /tvm/incubator-tvm-gpu/build/libtvm.so(std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::PackedFunc tvm::runtime::detail::PackFuncVoidAddr_<4, tvm::runtime::CUDAWrappedFunc>(tvm::runtime::CUDAWrappedFunc, std::vector<tvm::runtime::detail::ArgConvertCode, std::allocator<tvm::runtime::detail::ArgConvertCode> > const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)+0xbc) [0x7fbf552a811c]
  [bt] (1) /tvm/incubator-tvm-gpu/build/libtvm.so(tvm::runtime::CUDAWrappedFunc::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*, void**) const+0x665) [0x7fbf552a7bf5]
  [bt] (0) /tvm/incubator-tvm-gpu/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x32) [0x7fbf548a5ae2]
  File "/tvm/incubator-tvm-gpu/src/runtime/cuda/cuda_module.cc", line 190
  File "/tvm/incubator-tvm-gpu/src/runtime/rpc/rpc_endpoint.cc", line 370
RPCError: Error caught from RPC call:
[06:07:58] /tvm/incubator-tvm-gpu/src/runtime/library_module.cc:78: Check failed: ret == 0 (-1 vs. 0) : TVMError: CUDALaunch Error: CUDA_ERROR_INVALID_VALUE
 grid=(48,1,1),  block=(1,1,1)
// func_name=default_function_kernel0
// CUDA Source

<error is too long to post>

',), error_no=7, all_cost=10, timestamp=1589436478.2384377)	[('tile_n', [-1, 1, 1, 1]), ('tile_f', [-1, 128, 1, 2]), ('tile_y', [-1, 40, 1, 8]), ('tile_x', [-1, 1, 1, 20]), ('tile_rc', [-1, 1, 1]), ('auto_unroll_max_step', 1500), ('unroll_explicit', 0), ('fuse_yx', 0)],None,6447087787
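
One quick sanity check on a failing configuration is its thread-block size. The sketch below is not TVM code, and `threads_per_block` is a hypothetical helper. It assumes that in each 4-way `tile_*` split the third entry is the extent bound to a CUDA thread axis, as in the usual `[block, vthread, thread, inner]` layout of TVM's CUDA templates. I have not verified that for this operator's schedule.

```python
from math import prod  # Python 3.8+

# Hardware limit on threads per block for current NVIDIA GPUs.
MAX_THREADS_PER_BLOCK = 1024


def threads_per_block(config, thread_index=2):
    """Multiply the thread extents of the spatial tiling knobs.

    Assumption: entry `thread_index` of each 4-way tile_* split is the
    extent bound to a CUDA thread axis.
    """
    knobs = dict(config)
    return prod(knobs[name][thread_index]
                for name in ("tile_n", "tile_f", "tile_y", "tile_x"))


# Configuration copied from the error_no=7 record above.
config = [('tile_n', [-1, 1, 1, 1]), ('tile_f', [-1, 128, 1, 2]),
          ('tile_y', [-1, 40, 1, 8]), ('tile_x', [-1, 1, 1, 20]),
          ('tile_rc', [-1, 1, 1]), ('auto_unroll_max_step', 1500),
          ('unroll_explicit', 0), ('fuse_yx', 0)]

n_threads = threads_per_block(config)
print(n_threads, n_threads <= MAX_THREADS_PER_BLOCK)
```

For this record the product is 1, which is consistent with the `block=(1,1,1)` reported in its log, so that launch failure is unlikely to come from the thread count alone.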

Has anyone had a similar issue? Is it because the schedule is not optimized?

@Huyuwei Any idea about this issue?

In the end, I found a few configurations that actually run on my GPU, but conv2d_transpose performance is poor, especially compared with other implementations (e.g. cuDNN), which run more than 2x faster than the best implementation found by AutoTVM.

Is this because the conv2d_transpose schedule is not highly optimized? @kevinthesun @Huyuwei