[SOLVED] X86 autotvm: NCHW -> NCHWc conversion seems to cause config mismatch

I’m experiencing an issue that I believe may be affecting others as well.

I am using the following PyTorch network, which is later exported to ONNX:

import torch.nn as nn

net = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=128, kernel_size=3, dilation=(1,1)),
                    nn.ReLU(),
                    nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, dilation=(1,1)),
                    nn.ReLU(),
                    nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, dilation=(1,1)))
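The input shape (1, 3, 612, 612) used later in get_network() accounts for the three conv2d workloads in the logs: with the nn.Conv2d defaults (stride 1, no padding), each 3x3 convolution removes 2 pixels from each spatial dimension. A small sketch of that arithmetic, using the usual convolution output-size formula (the helper name conv2d_out_size is hypothetical):

```python
def conv2d_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of a 2D convolution along one axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# (in_channels, out_channels) of the three Conv2d layers in the network above.
layers = [(3, 128), (128, 128), (128, 64)]

size = 612  # input height/width used in get_network()
for in_ch, out_ch in layers:
    # Print the NCHW input each conv2d workload sees, then shrink by 2.
    print("conv2d: input (1, %d, %d, %d) -> %d channels" % (in_ch, size, size, out_ch))
    size = conv2d_out_size(size, kernel=3)
```

This reproduces the 612, 610 and 608 spatial sizes that appear in both the tuning log and the warnings.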

When I run it through the example at https://docs.tvm.ai/tutorials/autotvm/tune_nnvm_x86.html, everything works until the final network is compiled, at which point I get:

WARNING:autotvm:Cannot find config for target=llvm -mcpu=core-avx2, 
    workload=('conv2d_NCHWc', (1, 16, 608, 608, 8, 'float32'), (8, 16, 3, 3, 8, 8, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW8c', 'NCHW8c', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm -mcpu=core-avx2,
    workload=('conv2d_NCHWc', (1, 16, 610, 610, 8, 'float32'), (16, 16, 3, 3, 8, 8, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW8c', 'NCHW8c', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm -mcpu=core-avx2,
    workload=('conv2d_NCHWc', (1, 1, 612, 612, 3, 'float32'), (16, 1, 3, 3, 3, 8, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW3c', 'NCHW8c', 'float32'). A fallback configuration is used, which may bring great performance regression.

This seemed odd at first, but the .log file contains:

{"r": [[0.715419564], 0, 3.2787282466888428, 1551285782.4269547], "i": ["llvm -mcpu=core-avx2", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 128, 608, 608], "float32"], 
    ["TENSOR", [64, 128, 3, 3], "float32"], [1, 1], [0, 0], [1, 1], "NCHW", "float32"], {}, ["conv2d", [1, 128, 608, 608, "float32"], [64, 128, 3, 3, "float32"], [1, 1], [0, 0], [1, 1], "NCHW", "float32"], {"t": "direct", "e": [["tile_ic", "sp", [32, 4]], ["tile_oc", "sp", [4, 16]], ["tile_ow", "sp", [101, 6]], ["unroll_kw", "ot", true]], "c": null, "i": 202}], "v": 0.1}
{"r": [[0.991738546], 0, 4.449911117553711, 1551285848.6736164], "i": ["llvm -mcpu=core-avx2", 
    "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 128, 610, 610], "float32"], ["TENSOR", [128, 128, 3, 3], "float32"], [1, 1], [0, 0], [1, 1], "NCHW", "float32"], {}, ["conv2d", [1, 128, 610, 610, "float32"], [128, 128, 3, 3, "float32"], [1, 1], [0, 0], [1, 1], "NCHW", "float32"], {"t": "direct", "e": [["tile_ic", "sp", [8, 16]], ["tile_oc", "sp", [8, 16]], ["tile_ow", "sp", [152, 4]], ["unroll_kw", "ot", false]], "c": null, "i": 676}], "v": 0.1}
{"r": [[0.03143428715625], 0, 1.458106517791748, 1551285879.0393698], "i": ["llvm -mcpu=core-avx2",
    "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 3, 612, 612], "float32"], ["TENSOR", [128, 3, 3, 3], "float32"], [1, 1], [0, 0], [1, 1], "NCHW", "float32"], {}, ["conv2d", [1, 3, 612, 612, "float32"], [128, 3, 3, 3, "float32"], [1, 1], [0, 0], [1, 1], "NCHW", "float32"], {"t": "direct", "e": [["tile_ic", "sp", [1, 3]], ["tile_oc", "sp", [2, 64]], ["tile_ow", "sp", [610, 1]], ["unroll_kw", "ot", false]], "c": null, "i": 93}], "v": 0.1}

My understanding is that the task args passed to autotvm.task.create are in the original NCHW layout, while during the final optimization phase TVM looks up a workload in NCHWc layout, so the lookup fails (compare the indented lines above).
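The warnings and the log do describe the same three convolutions; only the layout differs. As a sketch, assuming the blocked layouts visible in the warnings (data as (N, C/ic_bn, H, W, ic_bn) and kernel as (O/oc_bn, I/ic_bn, KH, KW, ic_bn, oc_bn)), with the block sizes ic_bn and oc_bn read off the warnings rather than obtained from TVM, the NCHW shapes in the log map onto the NCHWc shapes like this (nchw_to_nchwc is a hypothetical helper):

```python
def nchw_to_nchwc(data_shape, kernel_shape, ic_bn, oc_bn):
    """Map NCHW data / OIHW kernel shapes to blocked NCHWc-style shapes."""
    n, c, h, w = data_shape
    o, i, kh, kw = kernel_shape
    assert c % ic_bn == 0 and i % ic_bn == 0 and o % oc_bn == 0, \
        "block sizes must divide the channel counts"
    data = (n, c // ic_bn, h, w, ic_bn)
    kernel = (o // oc_bn, i // ic_bn, kh, kw, ic_bn, oc_bn)
    return data, kernel


# Tuned NCHW workloads from the log, paired with block sizes from the warnings.
cases = [
    ((1, 128, 608, 608), (64, 128, 3, 3), 8, 8),
    ((1, 128, 610, 610), (128, 128, 3, 3), 8, 8),
    ((1, 3, 612, 612), (128, 3, 3, 3), 3, 8),
]
for data, kernel, ic_bn, oc_bn in cases:
    print(nchw_to_nchwc(data, kernel, ic_bn, oc_bn))
```

Each printed pair matches the data and kernel shapes of one conv2d_NCHWc warning, e.g. (1, 16, 608, 608, 8) and (8, 16, 3, 3, 8, 8) for the first one.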

If I lower the optimization level to opt_level=2, the build completes without warnings.

Is there a way to extract the tasks already in NCHWc layout? Or am I missing something important that causes these warnings?

Below is a ready-to-run example, using the ONNX file here and the following code:

import onnx

import os
import numpy as np

import tvm
from tvm import autotvm
from tvm import relay
from tvm.relay import testing
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
import tvm.contrib.graph_runtime as runtime

def get_network(name):
    """Get the symbol definition and random weight of a network"""
    onnx_model = onnx.load('example_net.onnx')

    input_name = 'input'
    input_shape = (1, 3, 612, 612)

    shape_dict = {input_name: input_shape}
    net, params = relay.frontend.from_onnx(onnx_model, shape_dict)

    return net, params, input_shape

target = "llvm -mcpu=core-avx2"

dtype = "float32"
model_name = "optlog"
log_file = "%s.log" % model_name

# Set number of threads used for tuning based on the number of
# physical cpu cores on your machine.
num_threads = 1
os.environ["TVM_NUM_THREADS"] = str(num_threads)

tuning_option = {
    'log_filename': log_file,
    'tuner': 'random',
    'early_stopping': 1,  # make it stop early so we don't wait too long

    'measure_option': autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        # these values should probably be higher to get statistically significant measurements
        runner=autotvm.LocalRunner(number=1, repeat=1, min_repeat_ms=1000),
    ),
}

# You can skip the implementation of this function for this tutorial.
def tune_kernels(tasks,
                 measure_option,
                 tuner='gridsearch',
                 early_stopping=None,
                 log_filename='tuning.log'):

    # create tmp log file
    tmp_log_file = log_filename + ".tmp"
    if os.path.exists(tmp_log_file):
        os.remove(tmp_log_file)

    for i, tsk in enumerate(tasks):
        prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
        print("TASK ARGS:", tsk.args)  # debug

        # converting conv2d tasks to conv2d_NCHWc tasks
        op_name = tsk.workload[0]
        if op_name == 'conv2d':
            func_create = 'topi_x86_conv2d_NCHWc'
        elif op_name == 'depthwise_conv2d_nchw':
            func_create = 'topi_x86_depthwise_conv2d_NCHWc_from_nchw'
        else:
            raise ValueError("Tuning {} is not supported on x86".format(op_name))

        task = autotvm.task.create(func_create, args=tsk.args,
                                   target=target, template_key='direct')
        task.workload = tsk.workload

        # create tuner
        if tuner == 'xgb' or tuner == 'xgb-rank':
            tuner_obj = XGBTuner(task, loss_type='rank')
        elif tuner == 'ga':
            tuner_obj = GATuner(task, pop_size=50)
        elif tuner == 'random':
            tuner_obj = RandomTuner(task)
        elif tuner == 'gridsearch':
            tuner_obj = GridSearchTuner(task)
        else:
            raise ValueError("Invalid tuner: " + tuner)

        # do tuning
        n_trial=len(task.config_space)
        tuner_obj.tune(n_trial=n_trial,
                       early_stopping=early_stopping,
                       measure_option=measure_option,
                       callbacks=[
                           autotvm.callback.progress_bar(n_trial, prefix=prefix),
                           autotvm.callback.log_to_file(tmp_log_file)])

    # pick best records to a cache file
    autotvm.record.pick_best(tmp_log_file, log_filename)
    os.remove(tmp_log_file)

def tune_and_evaluate(tuning_opt):
    # extract workloads from relay program
    print("Extract tasks...")
    net, params, data_shape = get_network(model_name)
    tasks = autotvm.task.extract_from_program(net, target=target,
                                              params=params, ops=(relay.op.nn.conv2d,))

    # run tuning tasks
    print("Tuning...", len(tasks))
    tune_kernels(tasks, **tuning_opt)

    # compile kernels with history best records
    with autotvm.apply_history_best(log_file):
        print("Compile...")
        with relay.build_config(opt_level=3):
            graph, lib, params = relay.build_module.build(
                net, target=target,  params=params)

        # upload parameters to device
        ctx = tvm.cpu()
        data_tvm = tvm.nd.array((np.random.uniform(size=data_shape)).astype(dtype))
        module = runtime.create(graph, lib, ctx)
        module.set_input('input', data_tvm)
        module.set_input(**params)

        # evaluate
        print("Evaluate inference time cost...")
        ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=1)
        prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
              (np.mean(prof_res), np.std(prof_res)))  # gives 1.8 seconds

tune_and_evaluate(tuning_option)

Output:

TASK ARGS: (('TENSOR', (1, 128, 608, 608), 'float32'), ('TENSOR', (64, 128, 3, 3), 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32')
[Task  1/ 3]  Current/Best:   96.45/  96.45 GFLOPS | Progress: (8/448) | 47.03 s Done.
TASK ARGS: (('TENSOR', (1, 128, 610, 610), 'float32'), ('TENSOR', (128, 128, 3, 3), 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32')
[Task  2/ 3]  Current/Best:   52.34/  69.36 GFLOPS | Progress: (8/1024) | 57.68 s Done.
TASK ARGS: (('TENSOR', (1, 3, 612, 612), 'float32'), ('TENSOR', (128, 3, 3, 3), 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32')
[Task  3/ 3]  Current/Best:   27.87/  96.55 GFLOPS | Progress: (8/160) | 13.83 s Done.
Compile...
WARNING:autotvm:Cannot find config for target=llvm -mcpu=core-avx2, workload=('conv2d', (1, 3, 612, 612, 'float32'), (128, 3, 3, 3, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', ''). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm -mcpu=core-avx2, workload=('conv2d', (1, 128, 610, 610, 'float32'), (128, 128, 3, 3, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', ''). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm -mcpu=core-avx2, workload=('conv2d', (1, 128, 608, 608, 'float32'), (64, 128, 3, 3, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', ''). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm -mcpu=core-avx2, workload=('conv2d_NCHWc', (1, 16, 608, 608, 8, 'float32'), (8, 16, 3, 3, 8, 8, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW8c', 'NCHW8c', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm -mcpu=core-avx2, workload=('conv2d_NCHWc', (1, 16, 610, 610, 8, 'float32'), (16, 16, 3, 3, 8, 8, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW8c', 'NCHW8c', 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=llvm -mcpu=core-avx2, workload=('conv2d_NCHWc', (1, 1, 612, 612, 3, 'float32'), (16, 1, 3, 3, 3, 8, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW3c', 'NCHW8c', 'float32'). A fallback configuration is used, which may bring great performance regression.
Evaluate inference time cost...

Thanks for the post, I will take a look at this.

I think this is a subtle issue caused by a dtype not being propagated during the alter_layout_pass. Can you check whether PR #2707 fixes it for you?
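The dtype explanation is consistent with the first three warnings in the output: their conv2d workloads end with '' where the tuned records end with "float32". A minimal sketch that diffs one logged workload (the "conv2d" entry of the third log record) against the corresponding workload from the warnings, assuming as a simplification that the config lookup needs the workload tuple to match exactly (to_tuple is a hypothetical helper):

```python
import json

# conv2d workload stored in the third tuning-log record, as JSON.
logged = json.loads(
    '["conv2d", [1, 3, 612, 612, "float32"], [128, 3, 3, 3, "float32"],'
    ' [1, 1], [0, 0], [1, 1], "NCHW", "float32"]'
)
# Workload from the first "Cannot find config" warning at compile time.
requested = ('conv2d', (1, 3, 612, 612, 'float32'), (128, 3, 3, 3, 'float32'),
             (1, 1), (0, 0), (1, 1), 'NCHW', '')


def to_tuple(x):
    """Recursively convert JSON lists to tuples so they compare like workloads."""
    return tuple(to_tuple(v) for v in x) if isinstance(x, list) else x


logged = to_tuple(logged)
# Collect (index, logged value, requested value) for every mismatching field.
diffs = [(i, a, b) for i, (a, b) in enumerate(zip(logged, requested)) if a != b]
print(diffs)
```

Only the last field (the output dtype) differs, so under the exact-match assumption that single missing dtype is enough to turn a tuned entry into a fallback.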

Hi Eqy. Thanks, that works!
