[autoTVM] Crash when tuning "dense" on x86 CPU

I just modified the tutorial tune_relay_x86 (replacing "nn.conv2d" with "nn.dense") to tune "dense".

The model is VGG16, which contains 13 conv2d layers and 3 dense layers.

Q1: I run into issues when trying to run the function tune_graph().

# `input_tensor` is the graph input name defined earlier in my script.
from tvm import relay
from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner

def tune_graph(graph, dshape, records, opt_sch_file, target, use_DP=True):
    print('\033[32;43m starting tune_graph \033[0m')
    # The graph tuner is asked to tune nn.dense here (this is what triggers Q1).
    target_op = [relay.op.get("nn.dense"),]
    Tuner = DPTuner if use_DP else PBQPTuner
    executor = Tuner(graph, {input_tensor: dshape}, records, target_op, target)
    executor.benchmark_layout_transform(min_exec_num=2000)
    executor.run()
    executor.write_opt_sch2record_file(opt_sch_file)

The error info is as follows:

Cannot find config for target=llvm -keys=tracing,cpu -device=tracing, workload=('dense_nopack.x86', ('TENSOR', (1, 25088), 'float32'), ('TENSOR', (4096, 25088), 'float32'), None, 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -keys=tracing,cpu -device=tracing, workload=('dense_nopack.x86', ('TENSOR', (1, 4096), 'float32'), ('TENSOR', (4096, 4096), 'float32'), None, 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -keys=tracing,cpu -device=tracing, workload=('dense_nopack.x86', ('TENSOR', (1, 4096), 'float32'), ('TENSOR', (1000, 4096), 'float32'), None, 'float32'). A fallback configuration is used, which may bring great performance regression.
Traceback (most recent call last):
  File "run_low_level.py", line 211, in <module>
    compare(model_path)
  File "run_low_level.py", line 146, in compare
    tune_graph(irmod["main"], input_shape, log_file, graph_opt_sch_file, target=target)
  File "run_low_level.py", line 78, in tune_graph
    executor = Tuner(graph, {input_tensor: dshape}, records, target_op, target)
  File "/tensorflow/install_env/tvm-0.7/tvm/python/tvm/autotvm/graph_tuner/dynamic_programming_tuner.py", line 43, in __init__
    super(DPTuner, self).__init__(*args, **kwargs)
  File "/tensorflow/install_env/tvm-0.7/tvm/python/tvm/autotvm/graph_tuner/base_graph_tuner.py", line 156, in __init__
    self._fetch_cfg()
  File "/tensorflow/install_env/tvm-0.7/tvm/python/tvm/autotvm/graph_tuner/base_graph_tuner.py", line 215, in _fetch_cfg
    infer_layout_func = get_infer_layout(node_entry["topi_op"][0])
  File "/tensorflow/install_env/tvm-0.7/tvm/python/tvm/autotvm/graph_tuner/base_graph_tuner.py", line 44, in get_infer_layout
    raise ValueError("Cannot find infer layout for task %s" % task_name)
ValueError: Cannot find infer layout for task dense_nopack.x86

Q2: When I try to debug this error by using autotvm.record.pick_best(log_file, graph_opt_sch_file) and deleting the call to tune_graph(), the best-schedule log is generated correctly. However, another error occurs. The error info is shown below:

Config for target=llvm -keys=cpu -mcpu=cascadelake, workload=None is missing in ApplyGraphBest context. A fallback configuration is used, which may bring great performance regression.
Traceback (most recent call last):
  File "run_low_level.py", line 211, in <module>
    compare(model_path)
  File "run_low_level.py", line 154, in compare
    irmod, target=target, params=params)
  File "/tensorflow/install_env/tvm-0.7/tvm/python/tvm/relay/build_module.py", line 255, in build
    graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
  File "/tensorflow/install_env/tvm-0.7/tvm/python/tvm/relay/build_module.py", line 121, in build
    self._build(mod, target, target_host)
  File "/tensorflow/install_env/tvm-0.7/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 225, in __call__
    raise get_last_ffi_error()
AttributeError: Traceback (most recent call last):
  [bt] (8) /tensorflow/install_env/tvm-0.7/tvm/build/libtvm.so(tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)+0x70) [0x7f15ecba5940]
  [bt] (7) /tensorflow/install_env/tvm-0.7/tvm/build/libtvm.so(tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)+0x62) [0x7f15ecbacb02]
  [bt] (6) /tensorflow/install_env/tvm-0.7/tvm/build/libtvm.so(tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*)#6}::_FUN(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*)+0x14) [0x7f15ecba7a94]
  [bt] (5) /tensorflow/install_env/tvm-0.7/tvm/build/libtvm.so(tvm::relay::MixedModeMutator::VisitExpr_(tvm::relay::CallNode const*)+0x37) [0x7f15eca16807]
  [bt] (4) /tensorflow/install_env/tvm-0.7/tvm/build/libtvm.so(tvm::relay::ForwardRewriter::Rewrite_(tvm::relay::CallNode const*, tvm::RelayExpr const&)+0x103b) [0x7f15eca46dab]
  [bt] (3) /tensorflow/install_env/tvm-0.7/tvm/build/libtvm.so(tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>(tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const+0x2bc) [0x7f15ec9c17cc]
  [bt] (2) /tensorflow/install_env/tvm-0.7/tvm/build/libtvm.so(tvm::RelayExpr tvm::relay::LayoutRewriter<tvm::relay::alter_op_layout::AlterTransformMemorizer>(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)+0xbe8) [0x7f15ec9c9748]
  [bt] (1) /tensorflow/install_env/tvm-0.7/tvm/build/libtvm.so(tvm::relay::alter_op_layout::AlterTransformMemorizer::CallWithNewLayouts(tvm::relay::Call const&, std::vector<tvm::RelayExpr, std::allocator<tvm::RelayExpr> > const&)+0x414) [0x7f15ec9c6e44]
  [bt] (0) /tensorflow/install_env/tvm-0.7/tvm/build/libtvm.so(+0x2062b49) [0x7f15ec14cb49]
  File "/tensorflow/install_env/tvm-0.7/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 78, in cfun
    rv = local_pyfunc(*pyargs)
  File "/tensorflow/install_env/tvm-0.7/tvm/python/tvm/relay/op/nn/_nn.py", line 98, in alter_op_layout_conv2d
    return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)
  File "<decorator-gen-53>", line 2, in conv2d_alter_layout
  File "/tensorflow/install_env/tvm-0.7/tvm/python/tvm/target/generic_func.py", line 267, in dispatch_func
    return dispatch_dict[k](*args, **kwargs)
  File "/tensorflow/install_env/tvm-0.7/tvm/topi/python/topi/x86/conv2d_alter_op.py", line 44, in _alter_conv2d_layout
    workload = cfg.workload
AttributeError: 'FallbackConfigEntity' object has no attribute 'workload'

How can I solve these bugs? Thanks in advance.

For the first error, I think you cannot include dense in the graph tuner, as it is only used for conv2d. You could remove the dense records and try again (cc @kevinthesun); a sketch of the filtering is below.
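A hedged sketch of dropping the dense records from the log before calling the graph tuner (assuming the log is AutoTVM's one-JSON-record-per-line format; the file names are hypothetical):

import json

log_file = "vgg16_tuning.log"       # hypothetical: mixed conv2d + dense records
conv_log_file = "vgg16_conv2d.log"  # hypothetical: conv2d-only output

# record["input"][1] holds the task name, e.g. "conv2d_NCHWc.x86"
# or "dense_nopack.x86"; keep everything that is not a dense task.
with open(log_file) as fin, open(conv_log_file, "w") as fout:
    for line in fin:
        record = json.loads(line)
        if "dense" not in record["input"][1]:
            fout.write(line)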

For the second question: since you turned off the graph tuner, you should use ApplyHistoryBest instead of ApplyGraphBest (a sketch follows).
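A minimal sketch of the compile step with ApplyHistoryBest (assuming log_file, irmod, params, and target are the names from your script):

import tvm
from tvm import autotvm, relay

# apply_history_best picks the best record per workload from the kernel
# tuning log, so no graph-level schedule file is required.
with autotvm.apply_history_best(log_file):
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(irmod, target=target, params=params)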

Thanks!!! Problem solved.

“I think you cannot include dense in the graph tuner as it is only used for conv2d”: does this mean tune_graph() only supports nn.conv2d? And what about nn.conv3d? Thank you!

I don’t think it supports nn.conv3d.

Yeah, when I try to tune nn.conv3d it reports an error here. But it seems that TVM already has some “conv3d” functions analogous to the “conv2d” ones. If I implement “conv3d_infer_layout” following “conv2d_infer_layout”, will that allow the program to run normally? It seems like a lot of work. Thanks!

If you need to support conv3d layout tuning in general, you need not only the infer-layout function but also the corresponding compute.

For example, conv2d has an additional compute, “NCHWc”, besides NCHW. Since you can choose any x in NCHW[x]c for the first conv2d and any y in NCHW[y]c for the second conv2d, we need the graph tuner to optimize the end-to-end performance of NCHW[x]c -> transform(x, y) -> NCHW[y]c (see the snippet below). In short, you’re right that this would be a lot of work.
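To make transform(x, y) concrete, here is a small Relay snippet (the shapes are hypothetical: 64 channels packed as NCHW4c, repacked to NCHW8c):

from tvm import relay

# 64 channels stored as NCHW4c: (N, C//4, H, W, 4).
x = relay.var("x", shape=(1, 16, 56, 56, 4), dtype="float32")
# The repacking whose cost the graph tuner accounts for when two adjacent
# conv2d layers pick different inner factors (x=4 -> y=8); the result has
# shape (1, 8, 56, 56, 8).
y = relay.layout_transform(x, "NCHW4c", "NCHW8c")
print(relay.Function([x], y))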

It seems difficult for me to write the “conv3d” functions the same way as “conv2d”. Could I instead change my 3x3x3 convolution kernel into a 3x1x1 and a 1x3x3 factorization? The network would still use “nn.conv3d”, and I may still need to modify some functions. Do you have any other good suggestions? (The factorization I have in mind is sketched below.)

I also have another question. I used from_pytorch.py from the official website (https://tvm.apache.org/docs/tutorials/frontend/from_pytorch.html#sphx-glr-tutorials-frontend-from-pytorch-py), and the model can run normally. For the llvm backend, the code compiled with TVM shows a certain speed improvement. But for the CUDA backend, I can see that the code is using the NVIDIA GPU, yet the running speed is about as slow as using the CPU. Do you have any suggestions on this issue?

Also, can I understand it this way: “tvm.transform.PassContext” just applies a fixed set of optimizations (the passes can be seen in the source code, such as “AlterOpLayout()”), but whether the model actually runs faster after these optimizations is not guaranteed; AutoTVM is then equivalent to searching for the best combination among these optimizations. Is that correct? Thanks!
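For reference, the kernel factorization I have in mind looks like this in PyTorch (a rough sketch; the channel counts are illustrative, not from my actual network):

import torch.nn as nn

# A 3x3x3 conv replaced by a spatial 1x3x3 conv followed by a temporal
# 3x1x1 conv (a P3D/R(2+1)D-style factorization). Both stages are still
# nn.Conv3d, so the imported Relay graph would still contain nn.conv3d ops.
factored = nn.Sequential(
    nn.Conv3d(64, 64, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
    nn.Conv3d(64, 64, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
)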

But for the CUDA backend, I can see that the code is using the NVIDIA GPU, yet the running speed is about as slow as using the CPU

That’s why you need tuning. Note that tuning on GPU doesn’t need the graph tuner, so there’s no problem tuning conv3d on GPUs (a rough sketch follows).
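A rough sketch of the kernel-tuning flow on CUDA, adapted from the tune_relay_cuda tutorial (mod and params are assumed to come from the PyTorch frontend; the log file name and trial count are illustrative):

from tvm import autotvm, relay

# Extract conv3d tasks directly; no graph tuner step is needed on GPU.
tasks = autotvm.task.extract_from_program(
    mod["main"], target="cuda", params=params,
    ops=(relay.op.get("nn.conv3d"),),
)
measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(timeout=10),
    runner=autotvm.LocalRunner(number=20, repeat=3, min_repeat_ms=150),
)
for task in tasks:
    tuner = autotvm.tuner.XGBTuner(task)
    tuner.tune(
        n_trial=min(1000, len(task.config_space)),
        measure_option=measure_option,
        callbacks=[autotvm.callback.log_to_file("conv3d_cuda.log")],
    )

After tuning, compile under autotvm.apply_history_best("conv3d_cuda.log") as in the earlier sketch.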

AutoTVM is then equivalent to searching for the best combination among these optimizations

It is not exactly as @sqchao described it.

I see, thank you very much!

I have a confusing problem. I implemented the NCDHW[x]c compute required by AutoTVM for my conv3d graph tuning, following the conv2d approach, and tune_relay_x86.py can run and print “Mean inference time (std dev):” normally. But when I save mod_lib as a .so file and use a separate test script to load it, I can’t get the result: model.run() throws an exception, “Process finished with exit code 139 (interrupted by signal 11: SIGSEGV)”. My puzzle is why it works inside tune_relay_x86.py but can’t be executed normally when saved as a separate .so file. (My export/load flow is sketched below.) Do you have any good suggestions on this problem? Thank you!
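The export/load round trip I’m using looks roughly like this (a minimal sketch based on the 0.7-era API; the file names are placeholders, and irmod/params come from earlier in my script):

import tvm
from tvm import relay
from tvm.contrib import graph_runtime

# Build and export the compiled artifacts.
graph_json, lib, params = relay.build(irmod, target="llvm", params=params)
lib.export_library("net.so")
with open("net.json", "w") as f:
    f.write(graph_json)
with open("net.params", "wb") as f:
    f.write(relay.save_param_dict(params))

# In the separate test script: load and run.
loaded_lib = tvm.runtime.load_module("net.so")
model = graph_runtime.create(open("net.json").read(), loaded_lib, tvm.cpu(0))
model.load_params(bytearray(open("net.params", "rb").read()))
model.run()  # this is where the SIGSEGV occurs for the larger inputs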

And there is another observation: if I only use several conv3d convolutions that don’t change the feature-map size, that is, the conv3d layers all convolve on the same DHW size and only C changes, then saving this model as a .so file and loading it separately works normally.

It might be due to some mismatch from the graph tuner, but I have no concrete clue from the information you provided. You might need to provide more details, or even a reproducible example, for people to dive into.

I’ll sort out the code that supports conv3d in tune_graph and the error examples later. Do I need to provide all the modified code, such as the NCDHW[x]c compute, the infer_layout function, the Torch network, and all the test code, in order to better reproduce the error? Thanks!

It depends on you. Of course, if you provide all the required materials to reproduce the error, it’s easier for people to jump in to help. However, since you’re building a new feature rather than debugging an existing one, it usually takes more time for people to figure out the details, which may be impractical for most people with limited bandwidth, and you may get nothing in the end.

Thus, you need to judge whether you’re looking for someone to fully understand your changes and perfectly solve your problems, or expecting someone to address, in just a few minutes, a well-summarized problem you formulated from the errors.

I uploaded my modified code to my GitHub: GitHub - aiblackmaner/tvm. The changes are in python/tvm/topi/x86/conv3d.py; the code marked with # modify was modified by me. I put my test code into the tests/test_3d folder, which mainly contains four files: network.py includes the main network structure, load_test.py includes the test code for loading the saved .so file, from_pytorch_3d.py is the same as the official TVM example from_pytorch.py but adapted for my 3D network, and likewise for tune_realy_x86_3d.py.

The test environment of my computer is: Ubuntu 16.04 x86_64, Torch 1.4.0, TorchVision 0.5.0, clang+llvm-9.0.0-x86_64; the commit id I pulled from https://github.com/apache/tvm.git is 1d5504b, version 0.7.dev1.

The error is: if I use network2 in network.py, set the input data shape to (16, 16, 16), and save the PyTorch params, both from_pytorch_3d.py and tune_realy_x86_3d.py work normally; but if I increase the data shape to (48, 48, 48) or (96, 96, 96), both of them fail with “Process finished with exit code 139 (interrupted by signal 11: SIGSEGV)”. If I use network1 in network.py and set the input data shape to (96, 96, 96), both from_pytorch_3d.py and tune_realy_x86_3d.py work normally. This is a puzzling problem; if you have time, please take a look. Thanks very much!