[Auto Tuning] Auto-tuning a vectoradd on GPU

Hello,

To help understand how TVM auto-tuning works on GPU, I created a simple example, tune_vectoradd_gpu.py, which splits the addition of two vectors by a factor (e.g., 64) and lets each CUDA block process one partition. The split factor is auto-tuned over a set of candidates [32, 64, 256, 1024, 2048]. However, while manually specifying the split factor works (e.g., as in vectoradd_naive()), combining it with cfg.define_knob() and split() does not: all the trials fail (i.e., they get 0 GFLOPS). Without auto-tuning, the performance I get on GPU is poor for some new models with large dense and conv2d layers, which is why I want tuning to work. But why does even a simple example like this fail? I would really appreciate it if you could shed some light on what might be causing the failure.

Here is the code for tune_vectoradd_gpu.py:

import tvm
import time
import numpy as np
import numpy
from tvm import autotvm
import logging
import sys

device = "cuda"

log_file = "cuda_vectoradd.log"
dtype = 'float32'

ctx = tvm.context(device, 0)

tuning_option = {
    'log_filename': log_file,

    'tuner': 'random',
    'n_trial': 100,
    'early_stopping': 100,

    'measure_option': autotvm.measure_option(
        builder=autotvm.LocalBuilder(timeout=10),
        runner=autotvm.RPCRunner(
            'titanv100',  # change the device key to your key
            '0.0.0.0', 9190,
            number=20, repeat=3, timeout=4, min_repeat_ms=150),
    ),
}

def vectoradd_naive():

    N = 1048576
    A = tvm.placeholder((N,), name='A', dtype=dtype)
    B = tvm.placeholder((N,), name='B', dtype=dtype)
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
    s = tvm.create_schedule(C.op)

    bx, tx = s[C].split(C.op.axis[0], factor=64)
    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))

    module = tvm.build(s, [A, B, C], device, target_host="llvm")

    print(tvm.lower(s, [A, B, C], simple_mode=True))

    a = numpy.random.rand(N).astype(dtype)
    a_np = tvm.nd.array(a, ctx)
    b = numpy.random.rand(N).astype(dtype)
    b_np = tvm.nd.array(b, ctx)
    c_np = a + b

    c_tvm = tvm.nd.array(numpy.random.rand(N).astype(dtype), ctx)
    module(a_np, b_np, c_tvm)

    tvm.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)

    evaluator = module.time_evaluator(module.entry_name, ctx, number=100)
    print('Naive: %f' % evaluator(a_np, b_np, c_tvm).mean)

@autotvm.template
def vectoradd(N, dtype):

    A = tvm.placeholder((N,), name='A', dtype=dtype)
    B = tvm.placeholder((N,), name='B', dtype=dtype)
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
    s = tvm.create_schedule(C.op)

    x = s[C].op.axis[0]
    cfg = autotvm.get_config()
    cfg.define_knob("tile_x", [32, 64, 256, 1024])
    bx, tx = s[C].split(x, factor=cfg["tile_x"].val)

    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))

    print(tvm.lower(s, [A, B, C], simple_mode=True))

    return s, [A, B, C]

def tune_task(task,
              measure_option,
              tuner='random',
              n_trial=10,
              early_stopping=None,
              log_filename='tuning.log',
              use_transfer_learning=True):

    measure_option = autotvm.measure_option(
        builder='local',
        runner=autotvm.LocalRunner(number=5))

    tuner_obj = autotvm.tuner.RandomTuner(task)
    print("do tuning")
    tuner_obj.tune(n_trial=min(n_trial, len(task.config_space)),
                   early_stopping=early_stopping,
                   measure_option=measure_option,
                   callbacks=[
                       autotvm.callback.progress_bar(n_trial),
                       autotvm.callback.log_to_file(log_filename)])

def tune_and_evaluate(tuning_opt):

    N = 1048576
    task = autotvm.task.create(vectoradd, args=(N, 'float32'), target=device)
    print(task.config_space)

    # run tuning tasks
    print("Tuning...")
    tune_task(task, **tuning_opt)

    # inspect the best config
    dispatch_context = autotvm.apply_history_best(log_file)
    best_config = dispatch_context.query(task.target, task.workload)
    print("\nBest config:")
    print(best_config)

    print("apply history best from log file")
    with autotvm.apply_history_best(log_file):
        with tvm.target.create(device):
            s, arg_bufs = vectoradd(N, 'float32')
            func = tvm.build(s, arg_bufs)

    a = numpy.random.rand(N).astype(dtype)
    a_np = tvm.nd.array(a, ctx)
    b = numpy.random.rand(N).astype(dtype)
    b_np = tvm.nd.array(b, ctx)
    c_np = a + b

    c_tvm = tvm.nd.array(numpy.random.rand(N).astype(dtype), ctx)
    func(a_np, b_np, c_tvm)

    tvm.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)

    evaluator = func.time_evaluator(func.entry_name, ctx, number=100)
    print('Opt: %f' % evaluator(a_np, b_np, c_tvm).mean)

    print("done")

vectoradd_naive()
tune_and_evaluate(tuning_option)


The schedule generated by the auto-tuner for each candidate factor seems to be correct:
produce C {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 1024
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 1024
  C[((blockIdx.x*1024) + threadIdx.x)] = (A[((blockIdx.x*1024) + threadIdx.x)] + B[((blockIdx.x*1024) + threadIdx.x)])
}

produce C {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 4096
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 256
  C[((blockIdx.x*256) + threadIdx.x)] = (A[((blockIdx.x*256) + threadIdx.x)] + B[((blockIdx.x*256) + threadIdx.x)])
}

produce C {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 16384
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 64
  C[((blockIdx.x*64) + threadIdx.x)] = (A[((blockIdx.x*64) + threadIdx.x)] + B[((blockIdx.x*64) + threadIdx.x)])
}

produce C {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 32768
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32
  C[((blockIdx.x*32) + threadIdx.x)] = (A[((blockIdx.x*32) + threadIdx.x)] + B[((blockIdx.x*32) + threadIdx.x)])
}

But I’m getting the following errors:
No: 1 GFLOPS: 0.00/0.00 result: MeasureResult(costs=(RuntimeError('Traceback (most recent call last):\n [bt] (8) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::HandleRecvPackedSeqArg()+0x9c4) [0x7f2a2f8c0672]\n [bt] (7) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::SwitchToState(tvm::runtime::RPCSession::EventHandler::State)+0x303) [0x7f2a2f8bf7dd]\n [bt] (6) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::HandlePackedCall()+0x55d) [0x7f2a2f8bb8f7]\n [bt] (5) /home/minjiaz/workspace/TVM/build/libtvm.so(void tvm::runtime::RPCSession::EventHandler::CallHandler<void (*)(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>(void (*)(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*))+0xaf) [0x7f2a2f8c2eed]\n [bt] (4) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCModuleLoad(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)+0x150) [0x7f2a2f8babc3]\n [bt] (3) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::TVMRetValue tvm::runtime::PackedFunc::opera',),), error_no=4, all_cost=0.41570568084716797, timestamp=1559887427.08522) [('tile_x', 64)],None,1
No: 2 GFLOPS: 0.00/0.00 result: MeasureResult(costs=(RuntimeError('Traceback (most recent call last):\n [bt] (8) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::HandleRecvPackedSeqArg()+0x9c4) [0x7fce1a0fe672]\n [bt] (7) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::SwitchToState(tvm::runtime::RPCSession::EventHandler::State)+0x303) [0x7fce1a0fd7dd]\n [bt] (6) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::HandlePackedCall()+0x55d) [0x7fce1a0f98f7]\n [bt] (5) /home/minjiaz/workspace/TVM/build/libtvm.so(void tvm::runtime::RPCSession::EventHandler::CallHandler<void (*)(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>(void (*)(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*))+0xaf) [0x7fce1a100eed]\n [bt] (4) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCModuleLoad(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)+0x150) [0x7fce1a0f8bc3]\n [bt] (3) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::TVMRetValue tvm::runtime::PackedFunc::opera',),), error_no=4, all_cost=0.40395116806030273, timestamp=1559887427.3626578) [('tile_x', 1024)],None,3
No: 3 GFLOPS: 0.00/0.00 result: MeasureResult(costs=(InstantiationError('Traceback (most recent call last):\n [bt] (4) /home/minjiaz/workspace/TVM/build/libtvm.so(TVMFuncCall+0x95) [0x7f2be39a10d5]\n [bt] (3) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::PackedFunc::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const+0x30) [0x7f2be2fc434c]\n [bt] (2) /home/minjiaz/workspace/TVM/build/libtvm.so(std::function<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const+0x5a) [0x7f2be2e0c796]\n [bt] (1) /home/minjiaz/workspace/TVM/build/libtvm.so(+0x2801672) [0x7f2be39a2672]\n [bt] (0) /home/minjiaz/workspace/TVM/build/libtvm.so(+0x2800665) [0x7f2be39a1665]\n File "/home/minjiaz/workspace/TVM/python/tvm/_ffi/_ctypes/function.py", line 71, in cfun\n rv = local_pyfunc(*pyargs)\n File "/home/minjiaz/workspace/TVM/python/tvm/autotvm/measure/measure_methods.py", line 597, in verify_pass\n raise InstantiationError("Skipped because of invalid gpu kernel")\ntvm.autotvm.task.space.InstantiationError: Skipped because of invalid gpu kernel',),), error_no=1, all_cost=0.016252994537353516, timestamp=1559887426.9318764) [('tile_x', 2048)],None,4
No: 4 GFLOPS: 0.00/0.00 result: MeasureResult(costs=(RuntimeError('Traceback (most recent call last):\n [bt] (8) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::HandleRecvPackedSeqArg()+0x9c4) [0x7f2a2f8c0672]\n [bt] (7) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::SwitchToState(tvm::runtime::RPCSession::EventHandler::State)+0x303) [0x7f2a2f8bf7dd]\n [bt] (6) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::HandlePackedCall()+0x55d) [0x7f2a2f8bb8f7]\n [bt] (5) /home/minjiaz/workspace/TVM/build/libtvm.so(void tvm::runtime::RPCSession::EventHandler::CallHandler<void (*)(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>(void (*)(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*))+0xaf) [0x7f2a2f8c2eed]\n [bt] (4) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCModuleLoad(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)+0x150) [0x7f2a2f8babc3]\n [bt] (3) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::TVMRetValue tvm::runtime::PackedFunc::opera',),), error_no=4, all_cost=0.40870046615600586, timestamp=1559887427.673778) [('tile_x', 256)],None,2
No: 5 GFLOPS: 0.00/0.00 result: MeasureResult(costs=(RuntimeError('Traceback (most recent call last):\n [bt] (8) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::HandleRecvPackedSeqArg()+0x9c4) [0x7fce1a0fe672]\n [bt] (7) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::SwitchToState(tvm::runtime::RPCSession::EventHandler::State)+0x303) [0x7fce1a0fd7dd]\n [bt] (6) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::HandlePackedCall()+0x55d) [0x7fce1a0f98f7]\n [bt] (5) /home/minjiaz/workspace/TVM/build/libtvm.so(void tvm::runtime::RPCSession::EventHandler::CallHandler<void (*)(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>(void (*)(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*))+0xaf) [0x7fce1a100eed]\n [bt] (4) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCModuleLoad(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)+0x150) [0x7fce1a0f8bc3]\n [bt] (3) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::TVMRetValue tvm::runtime::PackedFunc::opera',),), error_no=4, all_cost=0.40895915031433105, timestamp=1559887427.954541) [('tile_x', 32)],None,0
Cannot find config for target=cuda, workload=('vectoradd', 1048576, 'float32'). A fallback configuration is used, which may bring great performance regression.

Did you find a solution? I’m facing a similar problem…

@zhangninja your code works for me (I am using V100):

produce C {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 16384
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 64
  C[((blockIdx.x*64) + threadIdx.x)] = (A[((blockIdx.x*64) + threadIdx.x)] + B[((blockIdx.x*64) + threadIdx.x)])
}

Naive: 0.000030
produce C {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 32768
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32
  C[((blockIdx.x*32) + threadIdx.x)] = (A[((blockIdx.x*32) + threadIdx.x)] + B[((blockIdx.x*32) + threadIdx.x)])
}

ConfigSpace (len=4, space_map=
   0 tile_x: OtherOption([32, 64, 256, 1024]) len=4
)
Tuning...
do tuning
 Current/Best:    0.00/   0.00 GFLOPS | Progress: (0/100) | 0.00 s
produce C {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 16384
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 64
  C[((blockIdx.x*64) + threadIdx.x)] = (A[((blockIdx.x*64) + threadIdx.x)] + B[((blockIdx.x*64) + threadIdx.x)])
}

produce C {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 1024
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 1024
  C[((blockIdx.x*1024) + threadIdx.x)] = (A[((blockIdx.x*1024) + threadIdx.x)] + B[((blockIdx.x*1024) + threadIdx.x)])
}

produce C {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 32768
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32
  C[((blockIdx.x*32) + threadIdx.x)] = (A[((blockIdx.x*32) + threadIdx.x)] + B[((blockIdx.x*32) + threadIdx.x)])
}

produce C {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 4096
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 256
  C[((blockIdx.x*256) + threadIdx.x)] = (A[((blockIdx.x*256) + threadIdx.x)] + B[((blockIdx.x*256) + threadIdx.x)])
}

 Current/Best:   52.89/  52.89 GFLOPS | Progress: (4/100) | 2.69 s Done.

Best config:
[('tile_x', 256)],,None,2
apply history best from log file
produce C {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 4096
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 256
  C[((blockIdx.x*256) + threadIdx.x)] = (A[((blockIdx.x*256) + threadIdx.x)] + B[((blockIdx.x*256) + threadIdx.x)])
}

Opt: 0.000018
done

@gasgallo, could you describe your problem? I suspect it's a different issue.

I'm trying to autotune a resnet101 model on an Nvidia Jetson TX2. RPC is working fine, but the autotuning process completes with 0 GFLOPS; every trial fails with error_no=4. Here's a sample of the debug info:

DEBUG:autotvm:No: 171	GFLOPS: 0.00/0.00	result: MeasureResult(costs=
(RuntimeError('Traceback (most recent call last):\n  [bt] (8) 
/home/nvidia/tvm/build/libtvm.so(tvm::runtime::RPCSession::ServerLoop()+0x104) [0x7fa1eee4ac]\n  [bt] (7) 
/home/nvidia/tvm/build/libtvm.so(tvm::runtime::RPCSession::HandleUntilReturnEvent(tvm::runtime::
TVMRetValue*, bool, tvm::runtime::PackedFunc const*)+0x1b8) [0x7fa1eee180]\n  [bt] (6) 
/home/nvidia/tvm/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::HandleNextEvent(tvm::
runtime::TVMRetValue*, bool, tvm::runtime::PackedFunc const*)+0x554) [0x7fa1ef3cd4]\n  [bt] (5) 
/home/nvidia/tvm/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::
HandleRecvPackedSeqArg()+0x148) [0x7fa1ef3220]\n  [bt] (4) 
/home/nvidia/tvm/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::SwitchToState(tvm::
runtime::RPCSession::EventHandler::State)+0x320) [0x7fa1ef2140]\n  [bt] (3) 
/home/nvidia/tvm/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::
HandlePackedCall()+0x688) [0x7fa1eec620]\n 
 [bt] (2) /home/nvidia/tvm/build/libtvm.so(void tvm::runtime::RPCSession::EventHandler::C',),), 
error_no=4, all_cost=1.7421786785125732, timestamp=1570501536.5706928)	[('tile_b', [16, 1, 1, 1]), 
('tile_y', [4, 4, 1, 32]), ('tile_x', [2, 2, 1, 4]), ('tile_rc', [512, 1]), ('auto_unroll_max_step', 128), 
('unroll_explicit', 1)],winograd,None,313907

error_no=4 means the kernel encountered errors on the GPU. Since the traceback is too long to be recorded by AutoTVM, it's hard to figure out the root cause. Here are some suggestions for next steps:

  • Enlarge the trial number. When the number of errors exceeds 150, AutoTVM enables debug mode automatically. In debug mode, the generated GPU kernel code and the runtime error message are dumped.

  • Manually apply the schedule and see what's going on. This approach is more complicated but more effective. You could refer to here for the schedule implementation of Winograd conv2d for CUDA, and then build a simple case on top of it for debugging. A minimal sketch of this idea is shown below.
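
For what it's worth, here is a minimal sketch of that second approach, adapted to the vectoradd template from this thread rather than the Winograd conv2d schedule. It assumes the same TVM 0.6-era autotvm API used above (autotvm.task.create(), ConfigSpace.get(), Task.instantiate()); the config index is arbitrary. Building and running one candidate locally, outside the RPC runner, lets the full codegen or CUDA error surface instead of being truncated by AutoTVM:

import logging
import sys
import numpy as np
import tvm
from tvm import autotvm

# Optional: make AutoTVM print full measurement errors during tuning.
logging.getLogger('autotvm').setLevel(logging.DEBUG)
logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))

N = 1048576
# 'vectoradd' is the @autotvm.template defined earlier in this thread.
task = autotvm.task.create(vectoradd, args=(N, 'float32'), target='cuda')

# Pick one candidate config by index and materialize its schedule.
cfg = task.config_space.get(2)
with tvm.target.create('cuda'):
    s, arg_bufs = task.instantiate(cfg)
    func = tvm.build(s, arg_bufs)   # codegen problems surface here

# Run it once on the local GPU; launch/runtime problems surface here.
ctx = tvm.context('cuda', 0)
a = tvm.nd.array(np.random.rand(N).astype('float32'), ctx)
b = tvm.nd.array(np.random.rand(N).astype('float32'), ctx)
c = tvm.nd.array(np.zeros(N, dtype='float32'), ctx)
func(a, b, c)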