Traceback (most recent call last):
File "compile_onnx.py", line 147, in <module>
graph, lib, params= tune_and_evaluate(tuning_option,mod,params,input_shape,target=target)
File "compile_onnx.py", line 96, in tune_and_evaluate
# tune_graph(mod["main"], data_shape, log_file, graph_opt_sch_file, False) # turn graph
File "compile_onnx.py", line 80, in tune_graph
executor.benchmark_layout_transform(min_exec_num=2000)
File "/home/qqai-cv/yexing/my_python/lib/python3.6/site-packages/tvm/autotvm/graph_tuner/base_graph_tuner.py", line 431, in benchmark_layout_transform
self._iterate_layout_transform(_fetch_args_callback)
File "/home/qqai-cv/yexing/my_python/lib/python3.6/site-packages/tvm/autotvm/graph_tuner/base_graph_tuner.py", line 269, in _iterate_layout_transform
i_topi_op = in_node_entry["topi_op"][0]
KeyError: 'topi_op'
Hi, I hit a KeyError when trying to tune a new model, SINet (https://arxiv.org/abs/1911.09099), in the tune_graph step (after tune_kernels), following the tutorial.
I successfully converted the model from PyTorch to ONNX. With opset_version=10 and the nearest-neighbor Upsample op, compiling with relay.frontend.from_onnx failed, but with opset_version=11 it compiles successfully. I also verified that PyTorch, ONNX Runtime, and TVM produce the same results, which shows the PyTorch → ONNX → Relay IR conversion is correct. Environment: System: Ubuntu 18.04; TVM version: 0.7.dev1; Python: 3.6.9; PyTorch: 1.4.0+cu100; ONNX: 1.4.0. Model (.onnx file):
My code (barely changed from the tutorial):
import numpy as np
import tvm, os
from tvm import relay
import cv2
from tvm.contrib import graph_runtime as runtime
import onnx
import my_utils
from tvm import autotvm
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner
from tvm.contrib import util
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
num_threads = 6
# Size of the TVM runtime thread pool, exported through the environment.
os.environ["TVM_NUM_THREADS"] = str(num_threads)

# Model and test-data locations.
model_name = "Dnc_SINet"
img_path = "../test_img/0.png"
model_file = "../Dnc_SINet.onnx"

# Input geometry; input_shape below is laid out as (batch, 3, h, w).
w = 320
h = 256
batch_size = 1
dtype = 'float32'

# Compilation targets, graph I/O names and execution context.
target = "llvm"
target_host = 'llvm'
input_name = "input.1"
output_name = "962"
input_shape = (batch_size, 3, h, w)
ctx = tvm.cpu(0)

# Record files: kernel-level tuning log and graph-level schedule log.
log_file = "{}.log".format(model_name)
graph_opt_sch_file = "{}_graph_opt.log".format(model_name)

# Keyword arguments that are splatted into tune_kernels(...).
tuning_option = dict(
    log_filename=log_file,
    tuner='random',
    early_stopping=None,
    measure_option=autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=10, repeat=1, min_repeat_ms=1000),
    ),
)
# Kernel-level tuning.
def tune_kernels(tasks, measure_option,
                 tuner='gridsearch',
                 early_stopping=None,
                 log_filename='tuning.log'):
    """Tune every AutoTVM task in ``tasks`` and append records to a log.

    Each task is given a trial budget equal to the size of its
    configuration space.

    Args:
        tasks: sequence of AutoTVM tasks.
        measure_option: builder/runner options forwarded to ``tune``.
        tuner: one of 'xgb', 'xgb-rank', 'ga', 'random', 'gridsearch'.
        early_stopping: forwarded to ``tune``.
        log_filename: file that ``autotvm.callback.log_to_file`` writes to.

    Raises:
        ValueError: for an unknown ``tuner`` name. It is raised while
            processing the first task, so an empty ``tasks`` never raises.
    """
    # Lambdas defer looking up the tuner classes until a task needs one.
    factories = {
        'xgb': lambda tsk: XGBTuner(tsk, loss_type='rank'),
        'xgb-rank': lambda tsk: XGBTuner(tsk, loss_type='rank'),
        'ga': lambda tsk: GATuner(tsk, pop_size=50),
        'random': lambda tsk: RandomTuner(tsk),
        'gridsearch': lambda tsk: GridSearchTuner(tsk),
    }
    for position, task in enumerate(tasks, start=1):
        prefix = "[Task %2d/%2d] " % (position, len(tasks))
        factory = factories.get(tuner)
        if factory is None:
            raise ValueError("Invalid tuner: " + tuner)
        tuner_obj = factory(task)

        # Exhaustive budget: one trial per point in the config space.
        budget = len(task.config_space)
        callbacks = [
            autotvm.callback.progress_bar(budget, prefix=prefix),
            autotvm.callback.log_to_file(log_filename),
        ]
        tuner_obj.tune(n_trial=budget,
                       early_stopping=early_stopping,
                       measure_option=measure_option,
                       callbacks=callbacks)
# Graph-level tuning on top of the kernel records.
def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True):
    """Run the AutoTVM graph tuner over ``nn.conv2d`` nodes.

    Reads kernel tuning records from ``records``, benchmarks layout
    transforms, runs the search, and writes the selected schedules to
    ``opt_sch_file``. Pass ``use_DP=False`` to use PBQPTuner instead of
    DPTuner when the DP search takes too long.

    NOTE: uses the module-level ``input_name`` and ``target``.
    NOTE(review): with TVM 0.7.dev1 this was observed to raise
    KeyError('topi_op') in benchmark_layout_transform when a multi-input
    elementwise op (multiply/add) is not treated as multi-input.
    """
    conv_ops = [relay.op.get("nn.conv2d")]
    tuner_cls = PBQPTuner if not use_DP else DPTuner
    print(dshape)
    input_shapes = {input_name: dshape}
    executor = tuner_cls(graph, input_shapes, records, conv_ops, target)
    executor.benchmark_layout_transform(min_exec_num=2000)
    executor.run()
    executor.write_opt_sch2record_file(opt_sch_file)
def tune_and_evaluate(tuning_opt, mod, params, data_shape, target="llvm"):
    """Tune conv2d kernels, tune the graph, compile, and benchmark.

    Args:
        tuning_opt: keyword arguments forwarded to ``tune_kernels``. Its
            ``log_filename`` (default 'tuning.log', same as tune_kernels)
            is the record file reused for graph tuning and compilation.
        mod: Relay module; its ``"main"`` function is tuned and built.
        params: parameter dict for ``mod``.
        data_shape: shape of the input tensor bound to ``input_name``.
        target: target used for task extraction and the final build.

    Returns:
        ``(graph, lib, params)`` from ``relay.build_module.build``.

    NOTE: also reads the module-level ``input_name``, ``dtype`` and
    ``graph_opt_sch_file``; ``tune_graph`` itself uses the module-level
    ``target`` rather than this function's argument.
    """
    # extract workloads from relay program
    print("Extract tasks...")
    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                              params=params,
                                              ops=(relay.op.get("nn.conv2d"),))

    # run kernel-level tuning tasks
    tune_kernels(tasks, **tuning_opt)
    # Graph tuning must read the records tune_kernels just wrote.
    records = tuning_opt.get('log_filename', 'tuning.log')

    # Graph-level tuning. The TVM 0.7.dev graph tuner can fail with
    # KeyError('topi_op') in _iterate_layout_transform when a multi-input
    # elementwise op (e.g. multiply / add) is not recognized as one. Instead
    # of discarding the finished kernel tuning, fall back to the best
    # per-kernel records. Any other KeyError is re-raised.
    try:
        tune_graph(mod["main"], data_shape, records, graph_opt_sch_file, False)
    except KeyError as err:
        if err.args != ('topi_op',):
            raise
        print("Graph tuning failed with KeyError('topi_op'); "
              "falling back to kernel-level records in %s" % records)
        best_records = autotvm.apply_history_best(records)
    else:
        best_records = autotvm.apply_graph_best(graph_opt_sch_file)

    # compile kernels with the selected best records
    with best_records:
        print("Compile...")
        with relay.build_config(opt_level=4):
            graph, lib, params = relay.build_module.build(
                mod, target=target, params=params)

    # upload parameters to device
    ctx = tvm.cpu(0)
    data_tvm = tvm.nd.array((np.random.uniform(size=data_shape)).astype(dtype))
    module = runtime.create(graph, lib, ctx)
    module.set_input(input_name, data_tvm)
    module.set_input(**params)

    # evaluate
    print("Evaluate inference time cost...")
    ftimer = module.module.time_evaluator("run", ctx, number=100, repeat=3)
    prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
    print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
          (np.mean(prof_res), np.std(prof_res)))
    return graph, lib, params
#######################################################################
# load test image (cv2.imread returns None instead of raising on failure)
img = cv2.imread(img_path)
if img is None:
    raise FileNotFoundError("cannot read test image: %s" % img_path)
# Named img_input to avoid shadowing the builtin input().
img_input = my_utils.preprocess(img, w, h)
print("input shape:{}".format(img_input.shape))
input_array = tvm.nd.array(img_input.astype(dtype))

# load onnx model
onnx_model = onnx.load(model_file)

# compile model by relay
shape_dict = {input_name: input_shape}
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, dtype)

# Dump the Relay IR text (without meta data) for inspection.
meta_file = "./Dnc_SINet.meta"
with open(meta_file, 'w') as mf:
    print(mod.astext(show_meta_data=False), file=mf)
print("#######################################################################")

# Auto-tune
graph, lib, params = tune_and_evaluate(tuning_option, mod, params, input_shape,
                                       target=target)

# Save the graph, lib and params into separate files. A plain directory is
# used instead of tvm.contrib.util.tempdir: that helper manages a *temporary*
# directory (removed when the object goes out of scope) and fails on rerun if
# the path already exists, so the saved artifacts would not persist.
out_dir = "./model_x86"
os.makedirs(out_dir, exist_ok=True)
path_lib = os.path.join(out_dir, "DncSINet_lib.tar")
path_graph = os.path.join(out_dir, "DncSINet_graph.json")
path_params = os.path.join(out_dir, "DncSINet_param.params")
lib.export_library(path_lib)
with open(path_graph, "w") as fo:
    fo.write(graph)
with open(path_params, "wb") as fo:
    fo.write(relay.save_param_dict(params))
print(os.listdir(out_dir))

# load the module back.
with open(path_graph) as fi:
    loaded_graph = fi.read()
loaded_lib = tvm.runtime.load_module(path_lib)
with open(path_params, "rb") as fi:
    loaded_params = bytearray(fi.read())
module = runtime.create(loaded_graph, loaded_lib, ctx)
module.load_params(loaded_params)
module.set_input(input_name, input_array)
module.run()
tvm_output = module.get_output(0)
print(tvm_output)

# onnxruntime for evaluation
import onnxruntime as ort
from onnxruntime.capi import _pybind_state as C
so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
so.execution_mode = ort.ExecutionMode.ORT_PARALLEL
so.intra_op_num_threads = 2
so.inter_op_num_threads = 2
# Reuse model_file / input_name instead of duplicating the literals.
ort_sess = ort.InferenceSession(model_file, sess_options=so)
print(C.get_available_providers())
ort_sess.set_providers(["CPUExecutionProvider"])
# Feed ONNX Runtime the same dtype-cast input that TVM received.
onnx_out = ort_sess.run(None, {input_name: img_input.astype(dtype)})
print("ONNX_output")
print(onnx_out)
print()
Furthermore, the tune_kernels step seems to run without errors and writes the log Dnc_SINet.log:
{"input": ["llvm", "conv2d_NCHWc.x86", [["TENSOR", [1, 2, 256, 320], "float32"], ["TENSOR", [2, 2, 3, 3], "float32"], [1, 1], [1, 1, 1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}], "config": {"index": 54, "code_hash": null, "entity": [["tile_ic", "sp", [-1, 1]], ["tile_oc", "sp", [-1, 2]], ["tile_ow", "sp", [-1, 2]], ["unroll_kw", "ot", false]]}, "result": [[0.00013595731521333498], 0, 1.8156261444091797, 1590487913.6498568], "version": 0.2, "tvm_version": "0.7.dev1"}
{"input": ["llvm", "conv2d_NCHWc.x86", [["TENSOR", [1, 2, 256, 320], "float32"], ["TENSOR", [2, 2, 3, 3], "float32"], [1, 1], [1, 1, 1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}], "config": {"index": 10, "code_hash": null, "entity": [["tile_ic", "sp", [-1, 1]], ["tile_oc", "sp", [-1, 2]], ["tile_ow", "sp", [-1, 2]], ["unroll_kw", "ot", true]]}, "result": [[0.00013540219174497443], 0, 1.3245575428009033, 1590487914.9348269], "version": 0.2, "tvm_version": "0.7.dev1"}
{"input": ["llvm", "conv2d_NCHWc.x86", [["TENSOR", [1, 2, 256, 320], "float32"], ["TENSOR", [2, 2, 3, 3], "float32"], [1, 1], [1, 1, 1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}], "config": {"index": 70, "code_hash": null, "entity": [["tile_ic", "sp", [-1, 1]], ["tile_oc", "sp", [-1, 2]], ["tile_ow", "sp", [-1, 10]], ["unroll_kw", "ot", false]]}, "result": [[0.00015527120916877555], 0, 3.4014461040496826, 1590487918.2519333], "version": 0.2, "tvm_version": "0.7.dev1"}
{"input": ["llvm", "conv2d_NCHWc.x86", [["TENSOR", [1, 2, 256, 320], "float32"], ["TENSOR", [2, 2, 3, 3], "float32"], [1, 1], [1, 1, 1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}], "config": {"index": 6, "code_hash": null, "entity": [["tile_ic", "sp", [-1, 1]], ["tile_oc", "sp", [-1, 2]], ["tile_ow", "sp", [-1, 1]], ["unroll_kw", "ot", true]]}, "result": [[0.0001641825172311788], 0, 1.314802885055542, 1590487919.5376344], "version": 0.2, "tvm_version": "0.7.dev1"}
{"input": ["llvm", "conv2d_NCHWc.x86", [["TENSOR", [1, 2, 256, 320], "float32"], ["TENSOR", [2, 2, 3, 3], "float32"], [1, 1], [1, 1, 1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}], "config": {"index": 36, "code_hash": null, "entity": [["tile_ic", "sp", [-1, 1]], ["tile_oc", "sp", [-1, 1]], ["tile_ow", "sp", [-1, 32]], ["unroll_kw", "ot", true]]}, "result": [[0.00014663836520854528], 0, 3.4310293197631836, 1590487922.846251], "version": 0.2, "tvm_version": "0.7.dev1"}
{"input": ["llvm", "conv2d_NCHWc.x86", [["TENSOR", [1, 2, 256, 320], "float32"], ["TENSOR", [2, 2, 3, 3], "float32"], [1, 1], [1, 1, 1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}], "config": {"index": 9, "code_hash": null, "entity": [["tile_ic", "sp", [-1, 2]], ["tile_oc", "sp", [-1, 1]], ["tile_ow", "sp", [-1, 2]], ["unroll_kw", "ot", true]]}, "result": [[0.00018114181677483774], 0, 1.3032824993133545, 1590487924.0816596], "version": 0.2, "tvm_version": "0.7.dev1"}
{"input": ["llvm", "conv2d_NCHWc.x86", [["TENSOR", [1, 2, 256, 320], "float32"], ["TENSOR", [2, 2, 3, 3], "float32"], [1, 1], [1, 1, 1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}], "config": {"index": 68, "code_hash": null, "entity": [["tile_ic", "sp", [-1, 1]], ["tile_oc", "sp", [-1, 1]], ["tile_ow", "sp", [-1, 10]], ["unroll_kw", "ot", false]]}, "result": [[0.00011425483953435289], 0, 1.730971097946167, 1590487925.7408793], "version": 0.2, "tvm_version": "0.7.dev1"}
{"input": ["llvm", "conv2d_NCHWc.x86", [["TENSOR", [1, 2, 256, 320], "float32"], ["TENSOR", [2, 2, 3, 3], "float32"], [1, 1], [1, 1, 1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}], "config": {"index": 81, "code_hash": null, "entity": [["tile_ic", "sp", [-1, 2]], ["tile_oc", "sp", [-1, 1]], ["tile_ow", "sp", [-1, 32]], ["unroll_kw", "ot", false]]}, "result": [[0.00012996511115645148], 0, 3.3890671730041504, 1590487929.0349448], "version": 0.2, "tvm_version": "0.7.dev1"}
{"input": ["llvm", "conv2d_NCHWc.x86", [["TENSOR", [1, 2, 256, 320], "float32"], ["TENSOR", [2, 2, 3, 3], "float32"], [1, 1], [1, 1, 1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}], "config": {"index": 37, "code_hash": null, "entity": [["tile_ic", "sp", [-1, 2]], ["tile_oc", "sp", [-1, 1]], ["tile_ow", "sp", [-1, 32]], ["unroll_kw", "ot", true]]}, "result": [[0.00011664882574600265], 0, 1.3587851524353027, 1590487930.3270526], "version": 0.2, "tvm_version": "0.7.dev1"}
{"input": ["llvm", "conv2d_NCHWc.x86", [["TENSOR", [1, 2, 256, 320], "float32"], ["TENSOR", [2, 2, 3, 3], "float32"], [1, 1], [1, 1, 1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}], "config": {"index": 40, "code_hash": null, "entity": [["tile_ic", "sp", [-1, 1]], ["tile_oc", "sp", [-1, 1]], ["tile_ow", "sp", [-1, 40]], ["unroll_kw", "ot", true]]}, "result": [[0.00014529477851287214], 0, 1.8773386478424072, 1590487932.0300732], "version": 0.2, "tvm_version": "0.7.dev1"}
...
I checked the convolution shapes in my IR: they are all NCHW (same as ONNX and PyTorch), so the KeyError must have a different cause. I also checked the log file: every record's error number is 0, not 2/3/4.
I also printed the contents of in_node_entry:
keys: ['node', 'inputs', 'types', 'op', 'name']
node:
free_var %input.1: Tensor[(1, 3, 256, 320), float32]
free_var %encoder.level1.conv.weight: Tensor[(12, 3, 3, 3), float32]
%0 = nn.conv2d(%input.1, %encoder.level1.conv.weight, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3]) /* ty=Tensor[(1, 12, 128, 160), float32] */;
free_var %encoder.level1.bn.weight: Tensor[(12), float32]
free_var %encoder.level1.bn.bias: Tensor[(12), float32]
free_var %encoder.level1.bn.running_mean: Tensor[(12), float32]
free_var %encoder.level1.bn.running_var: Tensor[(12), float32]
%1 = nn.batch_norm(%0, %encoder.level1.bn.weight, %encoder.level1.bn.bias, %encoder.level1.bn.running_mean, %encoder.level1.bn.running_var, epsilon=0.001f) /* ty=(Tensor[(1, 12, 128, 160), float32], Tensor[(12), float32], Tensor[(12), float32]) */;
%2 = %1.0;
free_var %v953: Tensor[(12, 1, 1), float32]
%3 = reshape(%v953, meta[relay.Constant][0] /* ty=Tensor[(1), int32] */ /* ty=Tensor[(1), int32] */, newshape=[-1]) /* ty=Tensor[(12), float32] */;
%4 = nn.prelu(%2, %3) /* ty=Tensor[(1, 12, 128, 160), float32] */;
free_var %encoder.level2_0.conv.0.weight: Tensor[(12, 1, 3, 3), float32]
%5 = nn.conv2d(%4, %encoder.level2_0.conv.0.weight, strides=[2, 2], padding=[1, 1, 1, 1], groups=12, kernel_size=[3, 3]) /* ty=Tensor[(1, 12, 64, 80), float32] */;
%6 = nn.pad(%5, pad_width=[[0, 0], [0, 0], [0, 0], [0, 0]]) /* ty=Tensor[(1, 12, 64, 80), float32] */;
%7 = nn.avg_pool2d(%6, pool_size=[64, 80], strides=[64, 80], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 12, 1, 1), float32] */;
%8 = nn.batch_flatten(%7) /* ty=Tensor[(1, 12), float32] */;
%9 = nn.batch_flatten(%8) /* ty=Tensor[(1, 12), float32] */;
%10 = multiply(1f /* ty=float32 */, %9) /* ty=Tensor[(1, 12), float32] */;
free_var %encoder.level2_0.conv.1.dense.0.weight: Tensor[(12, 12), float32]
%11 = nn.dense(%10, %encoder.level2_0.conv.1.dense.0.weight, units=12) /* ty=Tensor[(1, 12), float32] */;
free_var %encoder.level2_0.conv.1.dense.0.bias: Tensor[(12), float32]
%12 = multiply(1f /* ty=float32 */, %encoder.level2_0.conv.1.dense.0.bias) /* ty=Tensor[(12), float32] */;
%13 = nn.bias_add(%11, %12) /* ty=Tensor[(1, 12), float32] */;
free_var %encoder.level2_0.conv.1.dense.1.weight: Tensor[(12), float32]
%14 = nn.prelu(%13, %encoder.level2_0.conv.1.dense.1.weight) /* ty=Tensor[(1, 12), float32] */;
%15 = reshape(%14, meta[relay.Constant][1] /* ty=Tensor[(4), int32] */ /* ty=Tensor[(4), int32] */, newshape=[1, 12, 1, 1]) /* ty=Tensor[(1, 12, 1, 1), float32] */;
multiply(%15, %5) /* ty=Tensor[(1, 12, 64, 80), float32] */
// meta data omitted. you can use show_meta_data=True to include meta data
inputs:
[[356, 0, 0], [344, 0, 0]]
types:
[TensorType([1, 12, 64, 80], float32)]
op:
multiply
name:
None
Compared with a correct node, which has the keys ['node', 'inputs', 'types', 'op', 'name', 'topi_op', 'workloads', 'record_candidates'], the "bug node" only has the keys ['node', 'inputs', 'types', 'op', 'name'].
What is the reason for this error, and how can I fix it? In the meantime, I will try compiling the model directly from PyTorch.
The model IR is too long to post here, so I printed it to a log file.
It looks like multiply is not recognized as a multiple-input op. Can you dig into https://github.com/apache/incubator-tvm/blob/master/python/tvm/autotvm/graph_tuner/base_graph_tuner.py#L167 to see why it is not recognized?
In my case, Add is not recognized as a multiple-input operation as well.