[AutoTVM] Interpreting the output of the tuner

Hi,
I’m running AutoTVM to tune a network written with the Relay interface. My tuning script is based on the tune_relay_x86.py tutorial.

Here are my tuning options:

tuning_option = {
    'log_filename': log_file,
    'tuner': 'random',
    'early_stopping': 500,
    'measure_option': autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=10, repeat=1,
                                   min_repeat_ms=1000),
    ),
}
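The `early_stopping` option stops a task once that many trials pass without finding a better config, which is why some tasks in the output end well before the full search space is measured. A minimal sketch of that loop, assuming configs are measured in batches and the stop check runs after each batch (a simplified model, not AutoTVM's actual tuner code). The batch size of 112 used below is an assumption inferred from the step size of the progress counters in the output.

```python
def simulate_tuning(scores, early_stopping, batch_size):
    """Return (trials_run, best_iter) for a simplified tuning loop.

    `scores` is a stand-in for the measured GFLOPS of each candidate
    config, in the order they are tried; len(scores) plays n_trial.
    This is a toy model, not the AutoTVM implementation.
    """
    n_trial = len(scores)
    best_score, best_iter = float("-inf"), 0
    i = 0
    while i < n_trial:
        batch = scores[i:i + batch_size]  # configs measured together
        for k, score in enumerate(batch):
            if score > best_score:
                best_score, best_iter = score, i + k
        i += len(batch)
        # Stop once `early_stopping` trials have passed with no improvement.
        if i >= best_iter + early_stopping:
            break
    return i, best_iter


# Toy scores: the best config is the very first one tried.
print(simulate_tuning([float(1152 - n) for n in range(1152)],
                      early_stopping=500, batch_size=112))
```

With a 1152-config space, `early_stopping=500`, and batches of 112, a best config found early makes the loop stop at 560 trials, the same counter Task 1 reports.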

And here’s my output.

Extract tasks...
Tuning...
[Task  1/ 5]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (0/1152) | 0.00 s
[Task  1/ 5]  Current/Best: 1414.35/6037.21 GFLOPS | Progress: (560/1152) | 1654.08 s Done.
[Task  2/ 5]  Current/Best:  551.32/6469.98 GFLOPS | Progress: (336/2048) | 1691.42 s
[Task  2/ 5]  Current/Best:  458.95/6469.98 GFLOPS | Progress: (672/2048) | 3531.38 s
[Task  2/ 5]  Current/Best: 1044.44/6469.98 GFLOPS | Progress: (784/2048) | 4109.08 s Done.
[Task  3/ 5]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (0/1792) | 0.00 s
[Task  3/ 5]  Current/Best:  762.93/6319.01 GFLOPS | Progress: (1680/1792) | 5208.71 s Done.
[Task  4/ 5]  Current/Best: 1997.02/6840.22 GFLOPS | Progress: (784/784) | 3528.16 s Done.
[Task  5/ 5]  Current/Best: 1310.34/6094.27 GFLOPS | Progress: (112/112) | 101.15 s Done.
Cannot find config for target=llvm -device=tracing, workload=('conv2d', (128, 3, 227, 227, 'float32'), (64, 3, 11, 11, 'float32'), (4, 4), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('conv2d', (128, 64, 55, 55, 'float32'), (192, 64, 5, 5, 'float32'), (1, 1), (2, 2), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('conv2d', (128, 192, 27, 27, 'float32'), (384, 192, 3, 3, 'float32'), (1, 1), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('conv2d', (128, 384, 27, 27, 'float32'), (384, 384, 3, 3, 'float32'), (1, 1), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('conv2d', (128, 384, 27, 27, 'float32'), (256, 384, 3, 3, 'float32'), (1, 1), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
2019-08-19 16:24:52,311 INFO Start to benchmark layout transformation...
2019-08-19 16:27:42,939 INFO Benchmarking layout transformation successful.
2019-08-19 16:27:42,941 INFO Start to run PBQP algorithm...
2019-08-19 16:27:42,942 INFO Finished PBQPExecutor run. Got optimal solution.
2019-08-19 16:27:42,944 INFO Writing optimal schedules to alexnet_graph_opt.log successfully.
Compile...
Config for target=llvm -mcpu=skylake-avx512, workload=('dense', (128, 4096, 'float32'), (1008, 4096, 'float32'), 0, 'float32') is missing in ApplyGraphBest context. A fallback configuration is used, which may bring great performance regression.
Config for target=llvm -mcpu=skylake-avx512, workload=('dense', (128, 4096, 'float32'), (4096, 4096, 'float32'), 0, 'float32') is missing in ApplyGraphBest context. A fallback configuration is used, which may bring great performance regression.
Config for target=llvm -mcpu=skylake-avx512, workload=('dense', (128, 43264, 'float32'), (4096, 43264, 'float32'), 0, 'float32') is missing in ApplyGraphBest context. A fallback configuration is used, which may bring great performance regression.
Evaluate inference time cost...
Mean inference time (std dev): 484.82 ms (6.10 ms)

1. As you can see, all the operators (5 conv2d layers and 3 dense layers) fall back to default configurations. Can someone please explain why this happens?

2. Also, what does `-device=tracing` mean?

Thanks,

Could you specify the network, the full target (e.g., “llvm -mcpu=” or similar), and the machine you used? I’d like to reproduce the problem and dig into the root cause if possible.
This also seems similar to this issue.

Hi, here is my network description. The implementation resides in relay/testing/

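from tvm import relay
# Assumption: `wrapper` refers to the layer helpers in tvm.relay.testing.layers.
from tvm.relay.testing import layers as wrapper

def get_net(batch_size, image_shape, num_classes, dtype, batch_norm=False):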

    data_shape = (batch_size,) + image_shape
    data = relay.var("data", shape=data_shape, dtype=dtype)

    feature = wrapper.conv2d(data=data, channels=64, kernel_size=(11,11), strides=(4,4), padding=(0,0), name="conv1")
    feature = relay.nn.bias_add(feature, relay.var("conv1_bias"))
    feature = relay.nn.relu(data=feature)
    feature = relay.nn.max_pool2d(data=feature, pool_size=(5, 5), strides=(1, 1), padding=(2,2))

    feature = wrapper.conv2d(data=feature, channels=192, kernel_size=(5,5), strides=(1,1), padding=(2,2), name="conv2")
    feature = relay.nn.bias_add(feature, relay.var("conv2_bias"))
    feature = relay.nn.relu(data=feature)
    feature = relay.nn.max_pool2d(data=feature, pool_size=(3, 3), strides=(2, 2), padding=(0,0))

    feature = wrapper.conv2d(data=feature, channels=384, kernel_size=(3,3), strides=(1,1), padding=(1,1), name="conv3")
    feature = relay.nn.bias_add(feature, relay.var("conv3_bias"))
    feature = relay.nn.relu(data=feature)
    feature = wrapper.conv2d(data=feature, channels=384, kernel_size=(3,3), strides=(1,1), padding=(1,1), name="conv4")
    feature = relay.nn.bias_add(feature, relay.var("conv4_bias"))
    feature = relay.nn.relu(data=feature)

    feature = wrapper.conv2d(data=feature, channels=256, kernel_size=(3,3), strides=(1,1), padding=(1,1), name="conv5")
    feature = relay.nn.bias_add(feature, relay.var("conv5_bias"))
    feature = relay.nn.relu(data=feature)
    feature = relay.nn.max_pool2d(data=feature, pool_size=(3, 3), strides=(2, 2), padding=(0,0))

    flatten = relay.nn.batch_flatten(data=feature)
    fc6 = wrapper.dense_add_bias(data=flatten, units=4096, name="fc6")
    fc7 = wrapper.dense_add_bias(data=fc6, units=4096, name="fc7")
    fc8 = wrapper.dense_add_bias(data=fc7, units=num_classes, name="fc8")

    args = relay.analysis.free_vars(fc8)
    return relay.Function(args, fc8)
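As a sanity check, the spatial sizes in the conv2d and dense workloads from the log can be recomputed from the layer parameters above with the usual output-size formula, out = (in + 2*pad - kernel) // stride + 1. A small sketch (the helper names are hypothetical; a 227x227 input is assumed, as in the first conv2d workload):

```python
def conv_out(size, kernel, stride, pad):
    """Output spatial size of a conv2d or max_pool2d window (no dilation)."""
    return (size + 2 * pad - kernel) // stride + 1


def alexnet_like_shapes(image=227):
    """Walk the network above; return (final spatial size, flattened length)."""
    s = conv_out(image, 11, 4, 0)   # conv1: 227 -> 55
    s = conv_out(s, 5, 1, 2)        # max_pool2d 5x5, stride 1, pad 2: 55 -> 55
    s = conv_out(s, 5, 1, 2)        # conv2: 55 -> 55
    s = conv_out(s, 3, 2, 0)        # max_pool2d 3x3, stride 2: 55 -> 27
    for _ in range(3):              # conv3..conv5 (3x3, pad 1): 27 -> 27
        s = conv_out(s, 3, 1, 1)
    s = conv_out(s, 3, 2, 0)        # max_pool2d 3x3, stride 2: 27 -> 13
    return s, 256 * s * s           # conv5 has 256 channels


print(alexnet_like_shapes())
```

The flattened length, 256 * 13 * 13 = 43264, matches the input width of the first dense workload in the log.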

This is how I call the tuner. I’m running on an Intel Xeon Platinum 8280 (Cascade Lake).

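import os
import numpy as np
import tvm
from tvm import autotvm, relay
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner
import tvm.contrib.graph_runtime as runtime

target = "llvm -mcpu=skylake-avx512"
dtype = "float32"
model_name = "custom_net"
input_name = "data"
log_file = "%s.log" % model_name
graph_opt_sch_file = "%s_graph_opt.log" % model_name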

num_threads = 56
os.environ["TVM_NUM_THREADS"] = str(num_threads)

batch_size = 128

tuning_option = {
    'log_filename': log_file,
    'tuner': 'random',
    'early_stopping': 100,

    'measure_option': autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=10, repeat=1, timeout=1000),
    ),
}

def tune_kernels(tasks,
                 measure_option,
                 tuner='gridsearch',
                 early_stopping=None,
                 log_filename='tuning.log'):

    for i, tsk in enumerate(tasks):
        prefix = "[Task %2d/%2d] " % (i+1, len(tasks))

        op_name = tsk.workload[0]
        if op_name == 'conv2d':
            func_create = 'topi_x86_conv2d_NCHWc'
        elif op_name == 'depthwise_conv2d_nchw':
            func_create = 'topi_x86_depthwise_conv2d_NCHWc_from_nchw'
        else:
            raise ValueError("Tuning {} is not supported on x86".format(op_name))

        task = autotvm.task.create(func_create, args=tsk.args,
                                   target=target, template_key='direct')
        task.workload = tsk.workload

        if tuner == 'xgb' or tuner == 'xgb-rank':
            tuner_obj = XGBTuner(task, loss_type='rank')
        elif tuner == 'ga':
            tuner_obj = GATuner(task, pop_size=50)
        elif tuner == 'random':
            tuner_obj = RandomTuner(task)
        elif tuner == 'gridsearch':
            tuner_obj = GridSearchTuner(task)
        else:
            raise ValueError("Invalid tuner: " + tuner)

        n_trial = len(task.config_space)
        tuner_obj.tune(n_trial=n_trial,
                       early_stopping=early_stopping,
                       measure_option=measure_option,
                       callbacks=[
                           autotvm.callback.progress_bar(n_trial, prefix=prefix),
                           autotvm.callback.log_to_file(log_filename)])

def tune_graph(graph, dshape, records, opt_sch_file, use_DP=False):
    target_op = [relay.nn.conv2d]
    Tuner = DPTuner if use_DP else PBQPTuner
    executor = Tuner(graph, {input_name: dshape}, records, target_op, target)
    executor.benchmark_layout_transform(min_exec_num=10)
    executor.run()
    executor.write_opt_sch2record_file(opt_sch_file)

def tune_and_evaluate(tuning_opt):
    print("Extract tasks...")
    mod, params, data_shape, out_shape = get_network(model_name, batch_size)
    tasks = autotvm.task.extract_from_program(mod["main"], target=target,  params=params, ops=(relay.op.nn.conv2d,))
    print("Tuning...")
    tune_kernels(tasks, **tuning_opt)
    tune_graph(mod["main"], data_shape, log_file, graph_opt_sch_file)

    with autotvm.apply_graph_best(graph_opt_sch_file):
        print("Compile...")
        with relay.build_config(opt_level=3):
            graph, lib, params = relay.build_module.build(
                mod, target=target, params=params)
     
        ctx = tvm.cpu()
        data_tvm = tvm.nd.array((np.random.uniform(size=data_shape)).astype(dtype))
        module = runtime.create(graph, lib, ctx)
        module.set_input(input_name, data_tvm)
        module.set_input(**params)

        print("Evaluate inference time cost...")
        ftimer = module.module.time_evaluator("run", ctx, number=5, repeat=3)
        prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
              (np.mean(prof_res), np.std(prof_res)))


tune_and_evaluate(tuning_option)

Thanks,

After some code tracing, here is my understanding of your questions. Please correct me if I missed anything.

First of all, the `-device=tracing` warnings should not matter: you should still get the performance found by AutoTVM (if not, it is probably a separate issue). Let me briefly explain the workflow that triggers the warning and why it is harmless; this also answers your second question.

When the graph tuner is constructed (e.g., `executor = Tuner(graph, {input_name: dshape}, records, target_op, target)`), its constructor invokes `expr2graph` to traverse the graph for the later tuning process. The traversal is implemented in the function `_expr2graph_impl`, which builds the module with the “tracing” target so that it can trace all calls to TOPI (code). That’s the answer to your second question.

In addition, when building the module, TVM creates a new `ApplyHistoryBest` context and loads your tuning log to look up the best configs. Your tuning history naturally has no config for the “tracing” device, so TVM falls back to the default context, `FallbackContext`, which prints the message you’ve seen (code). I agree that this warning is confusing and we should remove it.
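A toy model of that lookup may help: keep the best config per (target, workload) from the tuning records, and fall back with a warning when the key is missing. This is a hypothetical simplification (the class name, record layout, and keying by the full target string are assumptions, not AutoTVM's API); it only shows why a lookup under the tracing target misses records tuned for `llvm -mcpu=skylake-avx512`.

```python
import warnings


class SimpleHistoryBest:
    """Toy dispatch context: best config per (target, workload), else fallback."""

    def __init__(self, records):
        # records: iterable of (target, workload, config, gflops)
        self.best = {}
        for target, workload, config, gflops in records:
            key = (target, workload)
            if key not in self.best or gflops > self.best[key][1]:
                self.best[key] = (config, gflops)

    def query(self, target, workload):
        key = (target, workload)
        if key in self.best:
            return self.best[key][0]
        warnings.warn(f"Cannot find config for target={target}, "
                      f"workload={workload}. A fallback configuration is used.")
        return "fallback"


# Toy records; the GFLOPS values are placeholders, not measurements.
workload = ("conv2d", (128, 3, 227, 227), (64, 3, 11, 11))
ctx = SimpleHistoryBest([
    ("llvm -mcpu=skylake-avx512", workload, "cfg_a", 1.0),
    ("llvm -mcpu=skylake-avx512", workload, "cfg_b", 2.0),
])
print(ctx.query("llvm -mcpu=skylake-avx512", workload))
print(ctx.query("llvm -device=tracing", workload))
```

The first query returns the better of the two tuned configs; the second misses because no record was tuned for the tracing target, so it warns and returns the fallback.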

Thanks for the clarification. Do you know whether AutoTVM has been benchmarked on the `llvm -mcpu=skylake-avx512` backend, so I can compare the results I’m getting?

I think it should be. However, since AutoTVM only finds a config with decent performance within its search space, you can dump the lowered function with the best config applied and check whether it makes sense. If there is a performance gap to the optimum, check whether the optimal config is not covered by the search space, or whether the search time was too short for AutoTVM to find it. Either finding would help improve AutoTVM.
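One quick way to judge whether the search was too short is to check how much of each task’s config space was actually measured before early stopping. A small sketch that parses the progress lines printed in the original post (the regex assumes exactly that line format; the function name is hypothetical):

```python
import re

PROGRESS = re.compile(
    r"\[Task\s+(\d+)/\s*(\d+)\]\s+Current/Best:\s+([\d.]+)/\s*([\d.]+) GFLOPS"
    r" \| Progress: \((\d+)/(\d+)\)")


def explored_fraction(log_text):
    """Map task id -> (best GFLOPS, fraction of the config space measured).

    Later lines for the same task overwrite earlier ones, so the last
    progress line of each task wins.
    """
    summary = {}
    for m in PROGRESS.finditer(log_text):
        task, _, _, best, done, total = m.groups()
        summary[int(task)] = (float(best), int(done) / int(total))
    return summary


LOG = """\
[Task  1/ 5]  Current/Best: 1414.35/6037.21 GFLOPS | Progress: (560/1152) | 1654.08 s Done.
[Task  2/ 5]  Current/Best: 1044.44/6469.98 GFLOPS | Progress: (784/2048) | 4109.08 s Done.
"""

for task, (best, frac) in sorted(explored_fraction(LOG).items()):
    print(f"Task {task}: best {best:.2f} GFLOPS, explored {frac:.0%} of the space")
```

Task 2, for example, stopped after measuring 784 of 2048 configs; a larger `early_stopping` value (or none) would let it explore more of the space at the cost of tuning time.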