[Auto Tuning] Auto-tuning a vectoradd on GPU



Hello,

To help understand how TVM auto-tuning works on GPU, I created a simple example, tune_vectoradd_gpu.py, which splits the addition of two vectors by a factor (e.g., 64) and lets each CUDA block process one partition. The split factor is set up to be auto-tuned over a set of candidates [32, 64, 256, 1024, 2048]. However, even though manually specifying the split factor works (as in vectoradd_naive()), combining it with cfg.define_knob() and split() does not: every trial fails (0 GFLOPS). Without auto-tuning, the GPU performance I get is poor for some new models with large dense and conv2d layers, which is why I am trying to get tuning working. But why does even a simple example like this fail? I would really appreciate it if you could shed some light on what might cause the failure.
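
For quick reference, the tunable template boils down to the following pattern (a condensed sketch of the full script below; the name vectoradd_pattern is only used in this snippet):

import tvm
from tvm import autotvm

@autotvm.template
def vectoradd_pattern(N, dtype):
    A = tvm.placeholder((N,), name='A', dtype=dtype)
    B = tvm.placeholder((N,), name='B', dtype=dtype)
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
    s = tvm.create_schedule(C.op)

    # one knob: how many threads each CUDA block gets
    cfg = autotvm.get_config()
    cfg.define_knob("tile_x", [32, 64, 256, 1024, 2048])
    bx, tx = s[C].split(s[C].op.axis[0], factor=cfg["tile_x"].val)
    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
    return s, [A, B, C]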

Here is the code for tune_vectoradd_gpu.py:

import tvm
import tvm.testing
import time
import numpy
from tvm import autotvm
import logging
import sys

device = "cuda"

log_file = "cuda_vectoradd.log"
dtype = 'float32'

ctx = tvm.context(device, 0)

tuning_option = {
    'log_filename': log_file,

    'tuner': 'random',
    'n_trial': 100,
    'early_stopping': 100,

    'measure_option': autotvm.measure_option(
        builder=autotvm.LocalBuilder(timeout=10),
        runner=autotvm.RPCRunner(
            'titanv100',  # change the device key to your key
            '0.0.0.0', 9190,
            number=20, repeat=3, timeout=4, min_repeat_ms=150),
    ),
}
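
# Side note (my addition, not part of the failing run): the RPCRunner above
# assumes a tracker is listening on 0.0.0.0:9190 with a GPU server registered
# under the key 'titanv100'. A quick way to check that assumption from the
# client side:
#
#   from tvm import rpc
#   tracker = rpc.connect_tracker('0.0.0.0', 9190)
#   print(tracker.text_summary())  # should list a free 'titanv100' device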

def vectoradd_naive():
    N = 1048576
    A = tvm.placeholder((N,), name='A', dtype=dtype)
    B = tvm.placeholder((N,), name='B', dtype=dtype)
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
    s = tvm.create_schedule(C.op)

    bx, tx = s[C].split(C.op.axis[0], factor=64)
    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))

    module = tvm.build(s, [A, B, C], device, target_host="llvm")

    print(tvm.lower(s, [A, B, C], simple_mode=True))

    a = numpy.random.rand(N).astype(dtype)
    a_tvm = tvm.nd.array(a, ctx)
    b = numpy.random.rand(N).astype(dtype)
    b_tvm = tvm.nd.array(b, ctx)
    c_np = a + b

    c_tvm = tvm.nd.array(numpy.random.rand(N).astype(dtype), ctx)
    module(a_tvm, b_tvm, c_tvm)

    tvm.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)

    evaluator = module.time_evaluator(module.entry_name, ctx, number=100)
    print('Naive: %f' % evaluator(a_tvm, b_tvm, c_tvm).mean)

@autotvm.template
def vectoradd(N, dtype):
    A = tvm.placeholder((N,), name='A', dtype=dtype)
    B = tvm.placeholder((N,), name='B', dtype=dtype)
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
    s = tvm.create_schedule(C.op)

    x = s[C].op.axis[0]
    cfg = autotvm.get_config()
    cfg.define_knob("tile_x", [32, 64, 256, 1024, 2048])
    bx, tx = s[C].split(x, factor=cfg["tile_x"].val)

    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))

    print(tvm.lower(s, [A, B, C], simple_mode=True))

    return s, [A, B, C]
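
# Side note (my addition, not part of the failing run): the template can also
# be instantiated directly; without a tuning log, autotvm falls back to a
# default config, which helps separate template/codegen problems from
# measurement problems:
#
#   with tvm.target.create(device):
#       s, bufs = vectoradd(1048576, 'float32')
#       mod = tvm.build(s, bufs, device)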

def tune_task(task,
              measure_option,
              tuner='random',
              n_trial=10,
              early_stopping=None,
              log_filename='tuning.log',
              use_transfer_learning=True):

    # Note: this local measure_option overrides the RPCRunner-based one
    # passed in via tuning_option.
    measure_option = autotvm.measure_option(
        builder='local',
        runner=autotvm.LocalRunner(number=5))

    tuner_obj = autotvm.tuner.RandomTuner(task)
    print("do tuning")
    tuner_obj.tune(n_trial=min(n_trial, len(task.config_space)),
                   early_stopping=early_stopping,
                   measure_option=measure_option,
                   callbacks=[
                       autotvm.callback.progress_bar(n_trial),
                       autotvm.callback.log_to_file(log_filename)])

def tune_and_evaluate(tuning_opt):
    N = 1048576
    task = autotvm.task.create(vectoradd, args=(N, 'float32'), target=device)
    print(task.config_space)

    # run tuning tasks
    print("Tuning...")
    tune_task(task, **tuning_opt)

    # inspect the best config
    dispatch_context = autotvm.apply_history_best(log_file)
    best_config = dispatch_context.query(task.target, task.workload)
    print("\nBest config:")
    print(best_config)

    print("apply history best from log file")
    with autotvm.apply_history_best(log_file):
        with tvm.target.create(device):
            s, arg_bufs = vectoradd(N, 'float32')
            func = tvm.build(s, arg_bufs)

    a = numpy.random.rand(N).astype(dtype)
    a_tvm = tvm.nd.array(a, ctx)
    b = numpy.random.rand(N).astype(dtype)
    b_tvm = tvm.nd.array(b, ctx)
    c_np = a + b

    c_tvm = tvm.nd.array(numpy.random.rand(N).astype(dtype), ctx)
    func(a_tvm, b_tvm, c_tvm)

    tvm.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)

    evaluator = func.time_evaluator(func.entry_name, ctx, number=100)
    print('Opt: %f' % evaluator(a_tvm, b_tvm, c_tvm).mean)

    print("done")

vectoradd_naive()
tune_and_evaluate(tuning_option)


The schedules generated by the auto-tuner for the candidate factors look correct:
produce C {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 1024
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 1024
  C[((blockIdx.x*1024) + threadIdx.x)] = (A[((blockIdx.x*1024) + threadIdx.x)] + B[((blockIdx.x*1024) + threadIdx.x)])
}

produce C {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 4096
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 256
  C[((blockIdx.x*256) + threadIdx.x)] = (A[((blockIdx.x*256) + threadIdx.x)] + B[((blockIdx.x*256) + threadIdx.x)])
}

produce C {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 16384
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 64
  C[((blockIdx.x*64) + threadIdx.x)] = (A[((blockIdx.x*64) + threadIdx.x)] + B[((blockIdx.x*64) + threadIdx.x)])
}

produce C {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 32768
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32
  C[((blockIdx.x*32) + threadIdx.x)] = (A[((blockIdx.x*32) + threadIdx.x)] + B[((blockIdx.x*32) + threadIdx.x)])
}

But I’m getting the following errors:
No: 1 GFLOPS: 0.00/0.00 result: MeasureResult(costs=(RuntimeError('Traceback (most recent call last):\n [bt] (8) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::HandleRecvPackedSeqArg()+0x9c4) [0x7f2a2f8c0672]\n [bt] (7) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::SwitchToState(tvm::runtime::RPCSession::EventHandler::State)+0x303) [0x7f2a2f8bf7dd]\n [bt] (6) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::HandlePackedCall()+0x55d) [0x7f2a2f8bb8f7]\n [bt] (5) /home/minjiaz/workspace/TVM/build/libtvm.so(void tvm::runtime::RPCSession::EventHandler::CallHandler<void (*)(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>(void (*)(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*))+0xaf) [0x7f2a2f8c2eed]\n [bt] (4) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCModuleLoad(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)+0x150) [0x7f2a2f8babc3]\n [bt] (3) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::TVMRetValue tvm::runtime::PackedFunc::opera',),), error_no=4, all_cost=0.41570568084716797, timestamp=1559887427.08522) [('tile_x', 64)],None,1
No: 2 GFLOPS: 0.00/0.00 result: MeasureResult(costs=(RuntimeError('Traceback (most recent call last):\n [bt] (8) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::HandleRecvPackedSeqArg()+0x9c4) [0x7fce1a0fe672]\n [bt] (7) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::SwitchToState(tvm::runtime::RPCSession::EventHandler::State)+0x303) [0x7fce1a0fd7dd]\n [bt] (6) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::HandlePackedCall()+0x55d) [0x7fce1a0f98f7]\n [bt] (5) /home/minjiaz/workspace/TVM/build/libtvm.so(void tvm::runtime::RPCSession::EventHandler::CallHandler<void (*)(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>(void (*)(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*))+0xaf) [0x7fce1a100eed]\n [bt] (4) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCModuleLoad(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)+0x150) [0x7fce1a0f8bc3]\n [bt] (3) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::TVMRetValue tvm::runtime::PackedFunc::opera',),), error_no=4, all_cost=0.40395116806030273, timestamp=1559887427.3626578) [('tile_x', 1024)],None,3
No: 3 GFLOPS: 0.00/0.00 result: MeasureResult(costs=(InstantiationError('Traceback (most recent call last):\n [bt] (4) /home/minjiaz/workspace/TVM/build/libtvm.so(TVMFuncCall+0x95) [0x7f2be39a10d5]\n [bt] (3) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::PackedFunc::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const+0x30) [0x7f2be2fc434c]\n [bt] (2) /home/minjiaz/workspace/TVM/build/libtvm.so(std::function<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const+0x5a) [0x7f2be2e0c796]\n [bt] (1) /home/minjiaz/workspace/TVM/build/libtvm.so(+0x2801672) [0x7f2be39a2672]\n [bt] (0) /home/minjiaz/workspace/TVM/build/libtvm.so(+0x2800665) [0x7f2be39a1665]\n File "/home/minjiaz/workspace/TVM/python/tvm/_ffi/_ctypes/function.py", line 71, in cfun\n rv = local_pyfunc(*pyargs)\n File "/home/minjiaz/workspace/TVM/python/tvm/autotvm/measure/measure_methods.py", line 597, in verify_pass\n raise InstantiationError("Skipped because of invalid gpu kernel")\ntvm.autotvm.task.space.InstantiationError: Skipped because of invalid gpu kernel',),), error_no=1, all_cost=0.016252994537353516, timestamp=1559887426.9318764) [('tile_x', 2048)],None,4
No: 4 GFLOPS: 0.00/0.00 result: MeasureResult(costs=(RuntimeError('Traceback (most recent call last):\n [bt] (8) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::HandleRecvPackedSeqArg()+0x9c4) [0x7f2a2f8c0672]\n [bt] (7) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::SwitchToState(tvm::runtime::RPCSession::EventHandler::State)+0x303) [0x7f2a2f8bf7dd]\n [bt] (6) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::HandlePackedCall()+0x55d) [0x7f2a2f8bb8f7]\n [bt] (5) /home/minjiaz/workspace/TVM/build/libtvm.so(void tvm::runtime::RPCSession::EventHandler::CallHandler<void (*)(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>(void (*)(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*))+0xaf) [0x7f2a2f8c2eed]\n [bt] (4) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCModuleLoad(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)+0x150) [0x7f2a2f8babc3]\n [bt] (3) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::TVMRetValue tvm::runtime::PackedFunc::opera',),), error_no=4, all_cost=0.40870046615600586, timestamp=1559887427.673778) [('tile_x', 256)],None,2
No: 5 GFLOPS: 0.00/0.00 result: MeasureResult(costs=(RuntimeError('Traceback (most recent call last):\n [bt] (8) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::HandleRecvPackedSeqArg()+0x9c4) [0x7fce1a0fe672]\n [bt] (7) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::SwitchToState(tvm::runtime::RPCSession::EventHandler::State)+0x303) [0x7fce1a0fd7dd]\n [bt] (6) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCSession::EventHandler::HandlePackedCall()+0x55d) [0x7fce1a0f98f7]\n [bt] (5) /home/minjiaz/workspace/TVM/build/libtvm.so(void tvm::runtime::RPCSession::EventHandler::CallHandler<void (*)(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>(void (*)(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*))+0xaf) [0x7fce1a100eed]\n [bt] (4) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::RPCModuleLoad(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)+0x150) [0x7fce1a0f8bc3]\n [bt] (3) /home/minjiaz/workspace/TVM/build/libtvm.so(tvm::runtime::TVMRetValue tvm::runtime::PackedFunc::opera',),), error_no=4, all_cost=0.40895915031433105, timestamp=1559887427.954541) [('tile_x', 32)],None,0
Cannot find config for target=cuda, workload=('vectoradd', 1048576, 'float32'). A fallback configuration is used, which may bring great performance regression.