[BUG] tune_relay_x86.py with mobilenet failed

I followed https://docs.tvm.ai/tutorials/autotvm/tune_relay_x86.html exactly, except that I commented out tune_kernels(…) to save time. My TVM version is 0.6.dev.

The Python code is as below:

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Auto-tuning a convolutional network for x86 CPU
===============================================
**Author**: `Yao Wang <https://github.com/kevinthesun>`_, `Eddie Yan <https://github.com/eqy>`_

This is a tutorial on how to tune a convolutional neural network
for x86 CPUs.
"""
import os
import numpy as np

import tvm
from tvm import autotvm
from tvm import relay
from tvm.relay import testing
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner
import tvm.contrib.graph_runtime as runtime
#################################################################
# Define network
# --------------
# First we need to define the network in the relay frontend API.
# We can either load a pre-defined network from :code:`relay.testing`
# or build one, such as :any:`relay.testing.resnet`, with relay.
# We can also load models from MXNet, ONNX and TensorFlow.
#
# In this tutorial, we choose resnet-18 as the tuning example.


def get_network(name, batch_size):
    """Get the symbol definition and random weight of a network"""
    input_shape = (batch_size, 3, 224, 224)
    output_shape = (batch_size, 1000)

    if "resnet" in name:
        n_layer = int(name.split('-')[1])
        mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
    elif "vgg" in name:
        n_layer = int(name.split('-')[1])
        mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
    elif name == 'mobilenet':
        mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
    elif name == 'squeezenet_v1.1':
        mod, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
    elif name == 'inception_v3':
        input_shape = (1, 3, 299, 299)
        mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
    elif name == 'mxnet':
        # an example for mxnet model
        from mxnet.gluon.model_zoo.vision import get_model
        block = get_model('resnet18_v1', pretrained=True)
        mod, params = relay.frontend.from_mxnet(block, shape={input_name: input_shape}, dtype=dtype)
        net = mod["main"]
        net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
        mod = relay.Module.from_expr(net)
    else:
        raise ValueError("Unsupported network: " + name)

    return mod, params, input_shape, output_shape

[details="Summary"]
This text will be hidden
[/details]

target = "llvm -mcpu=core-avx2"

batch_size = 1
dtype = "float32"
model_name = "mobilenet"
log_file = "%s.log" % model_name
graph_opt_sch_file = "%s_graph_opt.log" % model_name

# Set the input name of the graph
# For ONNX models, it is typically "0".
input_name = "data"

# Set number of threads used for tuning based on the number of
# physical CPU cores on your machine.
num_threads = 4
os.environ["TVM_NUM_THREADS"] = str(num_threads)

tuning_option = {
    'log_filename': log_file,
    'tuner': 'random',
    'early_stopping': None,

    'measure_option': autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=10, repeat=1,
                                   min_repeat_ms=1000),
    ),
}


# You can skip the implementation of this function for this tutorial.
def tune_kernels(tasks,
                 measure_option,
                 tuner='gridsearch',
                 early_stopping=None,
                 log_filename='tuning.log'):

    for i, tsk in enumerate(tasks):
        prefix = "[Task %2d/%2d] " % (i+1, len(tasks))

        # converting conv2d tasks to conv2d_NCHWc tasks
        op_name = tsk.workload[0]
        if op_name == 'conv2d':
            func_create = 'topi_x86_conv2d_NCHWc'
        elif op_name == 'depthwise_conv2d_nchw':
            func_create = 'topi_x86_depthwise_conv2d_NCHWc_from_nchw'
        else:
            raise ValueError("Tuning {} is not supported on x86".format(op_name))

        task = autotvm.task.create(func_create, args=tsk.args,
                                   target=target, template_key='direct')
        task.workload = tsk.workload

        # create tuner
        if tuner == 'xgb' or tuner == 'xgb-rank':
            tuner_obj = XGBTuner(task, loss_type='rank')
        elif tuner == 'ga':
            tuner_obj = GATuner(task, pop_size=50)
        elif tuner == 'random':
            tuner_obj = RandomTuner(task)
        elif tuner == 'gridsearch':
            tuner_obj = GridSearchTuner(task)
        else:
            raise ValueError("Invalid tuner: " + tuner)

        # do tuning
        n_trial = len(task.config_space)
        tuner_obj.tune(n_trial=n_trial,
                       early_stopping=early_stopping,
                       measure_option=measure_option,
                       callbacks=[
                           autotvm.callback.progress_bar(n_trial, prefix=prefix),
                           autotvm.callback.log_to_file(log_filename)])


# Use graph tuner to achieve graph level optimal schedules
# Set use_DP=False if it takes too long to finish.
def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True):
    target_op = [relay.nn.conv2d]
    Tuner = DPTuner if use_DP else PBQPTuner
    executor = Tuner(graph, {input_name: dshape}, records, target_op, target)
    executor.benchmark_layout_transform(min_exec_num=2000)
    executor.run()
    executor.write_opt_sch2record_file(opt_sch_file)


########################################################################
# Finally, we launch tuning jobs and evaluate the end-to-end performance.

def tune_and_evaluate(tuning_opt):
    # extract workloads from relay program
    print("Extract tasks...")
    mod, params, data_shape, out_shape = get_network(model_name, batch_size)
    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                              params=params, ops=(relay.op.nn.conv2d,))

    # run tuning tasks
    print("Tuning...")
    # tune_kernels(tasks, **tuning_opt)
    tune_graph(mod["main"], data_shape, log_file, graph_opt_sch_file)

    # compile kernels with graph-level best records
    with autotvm.apply_graph_best(graph_opt_sch_file):
        print("Compile...")
        with relay.build_config(opt_level=3):
            graph, lib, params = relay.build_module.build(
                mod, target=target, params=params)

        # upload parameters to device
        ctx = tvm.cpu()
        data_tvm = tvm.nd.array((np.random.uniform(size=data_shape)).astype(dtype))
        module = runtime.create(graph, lib, ctx)
        module.set_input(input_name, data_tvm)
        module.set_input(**params)

        # evaluate
        print("Evaluate inference time cost...")
        ftimer = module.module.time_evaluator("run", ctx, number=100, repeat=3)
        prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
              (np.mean(prof_res), np.std(prof_res)))

# We do not run the tuning in our webpage server since it takes too long.
# Uncomment the following line to run it by yourself.

tune_and_evaluate(tuning_option)

and the running output is as below:

Extract tasks...
Tuning...
Cannot find config for target=llvm -device=tracing, workload=('conv2d', (1, 3, 224, 224, 'float32'), (32, 3, 3, 3, 'float32'), (2, 2), (1, 1), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('depthwise_conv2d_nchw', (1, 32, 112, 112, 'float32'), (32, 1, 3, 3, 'float32'), (1, 1), (1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('conv2d', (1, 32, 112, 112, 'float32'), (64, 32, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('depthwise_conv2d_nchw', (1, 64, 112, 112, 'float32'), (64, 1, 3, 3, 'float32'), (2, 2), (1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('conv2d', (1, 64, 56, 56, 'float32'), (128, 64, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('depthwise_conv2d_nchw', (1, 128, 56, 56, 'float32'), (128, 1, 3, 3, 'float32'), (1, 1), (1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('conv2d', (1, 128, 56, 56, 'float32'), (128, 128, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('depthwise_conv2d_nchw', (1, 128, 56, 56, 'float32'), (128, 1, 3, 3, 'float32'), (2, 2), (1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('conv2d', (1, 128, 28, 28, 'float32'), (256, 128, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('depthwise_conv2d_nchw', (1, 256, 28, 28, 'float32'), (256, 1, 3, 3, 'float32'), (1, 1), (1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('conv2d', (1, 256, 28, 28, 'float32'), (256, 256, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('depthwise_conv2d_nchw', (1, 256, 28, 28, 'float32'), (256, 1, 3, 3, 'float32'), (2, 2), (1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('conv2d', (1, 256, 14, 14, 'float32'), (512, 256, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('depthwise_conv2d_nchw', (1, 512, 14, 14, 'float32'), (512, 1, 3, 3, 'float32'), (1, 1), (1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('conv2d', (1, 512, 14, 14, 'float32'), (512, 512, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('depthwise_conv2d_nchw', (1, 512, 14, 14, 'float32'), (512, 1, 3, 3, 'float32'), (2, 2), (1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('conv2d', (1, 512, 7, 7, 'float32'), (1024, 512, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('depthwise_conv2d_nchw', (1, 1024, 7, 7, 'float32'), (1024, 1, 3, 3, 'float32'), (1, 1), (1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -device=tracing, workload=('conv2d', (1, 1024, 7, 7, 'float32'), (1024, 1024, 1, 1, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Traceback (most recent call last):

  File "/home/bokyliu/Project/TVM/tune_relay_x86.py", line 222, in <module>
    tune_and_evaluate(tuning_option)

  File "/home/bokyliu/Project/TVM/tune_relay_x86.py", line 196, in tune_and_evaluate
    tune_graph(mod["main"], data_shape, log_file, graph_opt_sch_file)

  File "/home/bokyliu/Project/TVM/tune_relay_x86.py", line 177, in tune_graph
    executor = Tuner(graph, {input_name: dshape}, records, target_op, target)

  File "/home/bokyliu/Project/incubator-tvm/python/tvm/autotvm/graph_tuner/dynamic_programming_tuner.py", line 43, in __init__
    super(DPTuner, self).__init__(*args, **kwargs)

  File "/home/bokyliu/Project/incubator-tvm/python/tvm/autotvm/graph_tuner/base_graph_tuner.py", line 157, in __init__
    self._fetch_cfg()

  File "/home/bokyliu/Project/incubator-tvm/python/tvm/autotvm/graph_tuner/base_graph_tuner.py", line 217, in _fetch_cfg
    for record in cfg_dict[workload]:

KeyError: ('conv2d', (1, 3, 224, 224, 'float32'), (32, 3, 3, 3, 'float32'), (2, 2), (1, 1), (1, 1), 'NCHW', 'float32')


Process finished with exit code 1

I didn’t find any similar bugs on https://discuss.tvm.ai/; if anyone has a solution, please correct me.

I followed this tutorial intending to tune my ONNX model, but the tutorial itself failed. For my ONNX model, as a first step I used this graph:

graph(%0 : Float(1, 3, 112, 112)
      %1 : Float(64, 3, 3, 3)
      %2 : Float(64)
      %3 : Float(64)
      %4 : Float(64)
      %5 : Float(64)
      %6 : Long()
      %7 : Float(64)
      %8 : Float(64, 1, 3, 3)
      %9 : Float(64)
      %10 : Float(64)
      %11 : Float(64)
      %12 : Float(64)
      %13 : Long()
      %14 : Float(64)) {
  %15 : Float(1, 64, 56, 56) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[2, 2]](%0, %1), scope: AirFace/Conv_block[conv1]/Conv2d[conv]
  %16 : Float(1, 64, 56, 56) = onnx::BatchNormalization[epsilon=1e-05, momentum=1](%15, %2, %3, %4, %5), scope: AirFace/Conv_block[conv1]/BatchNorm2d[bn]
  %17 : Float(1, 64, 56, 56) = onnx::PRelu(%16, %7), scope: AirFace/Conv_block[conv1]/PReLU[prelu]
  return (%17);
}
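
For reference, this is roughly how I import the model into Relay before tuning (a sketch; the file name is a placeholder, and the input name "0" matches %0 in the graph above):

import onnx
import tvm
from tvm import relay

onnx_model = onnx.load("airface.onnx")  # placeholder path to my model
shape_dict = {"0": (1, 3, 112, 112)}    # ONNX input name is typically "0"
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)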

and it tuned successfully! But after I added three layers as below:

graph(%0 : Float(1, 3, 112, 112)
      %1 : Float(64, 3, 3, 3)
      %2 : Float(64)
      %3 : Float(64)
      %4 : Float(64)
      %5 : Float(64)
      %6 : Long()
      %7 : Float(64)
      %8 : Float(64, 1, 3, 3)
      %9 : Float(64)
      %10 : Float(64)
      %11 : Float(64)
      %12 : Float(64)
      %13 : Long()
      %14 : Float(64)) {
  %15 : Float(1, 64, 56, 56) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[2, 2]](%0, %1), scope: AirFace/Conv_block[conv1]/Conv2d[conv]
  %16 : Float(1, 64, 56, 56) = onnx::BatchNormalization[epsilon=1e-05, momentum=1](%15, %2, %3, %4, %5), scope: AirFace/Conv_block[conv1]/BatchNorm2d[bn]
  %17 : Float(1, 64, 56, 56) = onnx::PRelu(%16, %7), scope: AirFace/Conv_block[conv1]/PReLU[prelu]
  %18 : Float(1, 64, 56, 56) = onnx::Conv[dilations=[1, 1], group=64, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%17, %8), scope: AirFace/Conv_block[conv2_dw]/Conv2d[conv]
  %19 : Float(1, 64, 56, 56) = onnx::BatchNormalization[epsilon=1e-05, momentum=1](%18, %9, %10, %11, %12), scope: AirFace/Conv_block[conv2_dw]/BatchNorm2d[bn]
  %20 : Float(1, 64, 56, 56) = onnx::PRelu(%19, %14), scope: AirFace/Conv_block[conv2_dw]/PReLU[prelu]
  return (%20);
}

the tuning output is

  [bt] (8) /home/bokyliu/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::schedule::MakePipeline(tvm::Stage const&, std::unordered_map<tvm::IterVar, tvm::Range, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<std::pair<tvm::IterVar const, tvm::Range> > > const&, tvm::Stmt, bool)+0x5a) [0x7f970825fe3a]
  [bt] (7) /home/bokyliu/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::ComputeOpNode::BuildProvide(tvm::Stage const&, std::unordered_map<tvm::IterVar, tvm::Range, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<std::pair<tvm::IterVar const, tvm::Range> > > const&, bool) const+0x115) [0x7f97080a1355]
  [bt] (6) /home/bokyliu/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::MakeComputeStmt(tvm::ComputeOpNode const*, tvm::Stage const&, std::unordered_map<tvm::IterVar, tvm::Range, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<std::pair<tvm::IterVar const, tvm::Range> > > const&, bool)+0x4f) [0x7f97080a095f]
  [bt] (5) /home/bokyliu/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::ComputeLoopNest::make(tvm::BaseComputeOpNode const*, tvm::Stage const&, std::unordered_map<tvm::IterVar, tvm::Range, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<std::pair<tvm::IterVar const, tvm::Range> > > const&, bool)+0x2a3) [0x7f970809fad3]
  [bt] (4) /home/bokyliu/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::op::MakeLoopNest(tvm::Stage const&, std::unordered_map<tvm::IterVar, tvm::Range, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<std::pair<tvm::IterVar const, tvm::Range> > > const&, unsigned long, bool, std::unordered_set<tvm::IterVar, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<tvm::IterVar> > const&, std::unordered_map<tvm::IterVar, tvm::Expr, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<std::pair<tvm::IterVar const, tvm::Expr> > >*, bool)+0x296d) [0x7f97080bcd9d]
  [bt] (3) /home/bokyliu/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::schedule::PassUpIndex(tvm::Stage const&, tvm::Map<tvm::IterVar, tvm::Range, void, void> const&, std::unordered_map<tvm::IterVar, tvm::Expr, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<std::pair<tvm::IterVar const, tvm::Expr> > >*, bool)+0x716) [0x7f970823d6d6]
  [bt] (2) /home/bokyliu/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::indexdiv(tvm::Expr, tvm::Expr)+0x4b) [0x7f97080593eb]
  [bt] (1) /home/bokyliu/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::floordiv(tvm::Expr, tvm::Expr)+0x248) [0x7f9708058bd8]
  [bt] (0) /home/bokyliu/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x32) [0x7f9707eb3af2]
  File "/home/bokyliu/Project/incubator-tvm/src/lang/../arithmetic/const_fold.h", line 225
TVMError: Check failed: pb->value != 0 (0 vs. 0) : Divide by zero
During handling of the above exception, another exception occurred:

TVMError: Check failed: pb->value != 0 (0 vs. 0) : Divide by zero
Error during compile function
-----------------------------
v0.0.4
fn (%p0: Tensor[(1, 3, 112, 112), float32], %p1: Tensor[(64, 3, 3, 3), float32], %p2: Tensor[(64), float32], %p3: Tensor[(64), float32], %p4: Tensor[(64), float32], %p5: Tensor[(64), float32], Primitive=1) -> Tensor[(1, 64, 56, 56), float32] {
  %0 = nn.conv2d(%p0, %p1, strides=[2, 2], padding=[1, 1], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %1 = expand_dims(%p2, axis=1, num_newaxis=2) /* ty=Tensor[(64, 1, 1), float32] */;
  %2 = multiply(%0, %1) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %3 = negative(%p3) /* ty=Tensor[(64), float32] */;
  %4 = multiply(%3, %p2) /* ty=Tensor[(64), float32] */;
  %5 = add(%4, %p4) /* ty=Tensor[(64), float32] */;
  %6 = expand_dims(%5, axis=1, num_newaxis=2) /* ty=Tensor[(64, 1, 1), float32] */;
  %7 = add(%2, %6) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  nn.prelu(%7, %p5) /* ty=Tensor[(1, 64, 56, 56), float32] */
}

I’m sure there is no division in my ONNX graph; the divide-by-zero apparently comes from inside TVM’s schedule lowering (note the tvm::floordiv frame in the backtrace above).

You cannot tune the graph without first tuning the kernels.
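
That is, in tune_and_evaluate the tune_kernels call has to run first, so that mobilenet.log already contains kernel records when tune_graph reads it. A minimal sketch, reusing the functions defined in the script above:

def tune_and_evaluate(tuning_opt):
    mod, params, data_shape, out_shape = get_network(model_name, batch_size)
    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                              params=params, ops=(relay.op.nn.conv2d,))
    # this must not be commented out: it writes the kernel records to log_file
    tune_kernels(tasks, **tuning_opt)
    # the graph tuner can then look those records up instead of raising KeyError
    tune_graph(mod["main"], data_shape, log_file, graph_opt_sch_file)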

Thank you for the reply. I did as you said, but the output is the same as when tuning the graph without first tuning the kernels:

  [bt] (8) /home/bokyliu/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::schedule::MakePipeline(tvm::Stage const&, std::unordered_map<tvm::IterVar, tvm::Range, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<std::pair<tvm::IterVar const, tvm::Range> > > const&, tvm::Stmt, bool)+0x5a) [0x7fcad1cb9e3a]
  [bt] (7) /home/bokyliu/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::ComputeOpNode::BuildProvide(tvm::Stage const&, std::unordered_map<tvm::IterVar, tvm::Range, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<std::pair<tvm::IterVar const, tvm::Range> > > const&, bool) const+0x115) [0x7fcad1afb355]
  [bt] (6) /home/bokyliu/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::MakeComputeStmt(tvm::ComputeOpNode const*, tvm::Stage const&, std::unordered_map<tvm::IterVar, tvm::Range, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<std::pair<tvm::IterVar const, tvm::Range> > > const&, bool)+0x4f) [0x7fcad1afa95f]
  [bt] (5) /home/bokyliu/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::ComputeLoopNest::make(tvm::BaseComputeOpNode const*, tvm::Stage const&, std::unordered_map<tvm::IterVar, tvm::Range, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<std::pair<tvm::IterVar const, tvm::Range> > > const&, bool)+0x2a3) [0x7fcad1af9ad3]
  [bt] (4) /home/bokyliu/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::op::MakeLoopNest(tvm::Stage const&, std::unordered_map<tvm::IterVar, tvm::Range, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<std::pair<tvm::IterVar const, tvm::Range> > > const&, unsigned long, bool, std::unordered_set<tvm::IterVar, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<tvm::IterVar> > const&, std::unordered_map<tvm::IterVar, tvm::Expr, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<std::pair<tvm::IterVar const, tvm::Expr> > >*, bool)+0x296d) [0x7fcad1b16d9d]
  [bt] (3) /home/bokyliu/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::schedule::PassUpIndex(tvm::Stage const&, tvm::Map<tvm::IterVar, tvm::Range, void, void> const&, std::unordered_map<tvm::IterVar, tvm::Expr, std::hash<tvm::IterVar>, std::equal_to<tvm::IterVar>, std::allocator<std::pair<tvm::IterVar const, tvm::Expr> > >*, bool)+0x716) [0x7fcad1c976d6]
  [bt] (2) /home/bokyliu/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::indexdiv(tvm::Expr, tvm::Expr)+0x4b) [0x7fcad1ab33eb]
  [bt] (1) /home/bokyliu/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::floordiv(tvm::Expr, tvm::Expr)+0x248) [0x7fcad1ab2bd8]
  [bt] (0) /home/bokyliu/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x32) [0x7fcad190daf2]
  File "/home/bokyliu/Project/incubator-tvm/src/lang/../arithmetic/const_fold.h", line 225
TVMError: Check failed: pb->value != 0 (0 vs. 0) : Divide by zero
During handling of the above exception, another exception occurred:

TVMError: Check failed: pb->value != 0 (0 vs. 0) : Divide by zero
Error during compile function
-----------------------------
v0.0.4
fn (%p0: Tensor[(1, 3, 112, 112), float32], %p1: Tensor[(64, 3, 3, 3), float32], %p2: Tensor[(64), float32], %p3: Tensor[(64), float32], %p4: Tensor[(64), float32], %p5: Tensor[(64), float32], Primitive=1) -> Tensor[(1, 64, 56, 56), float32] {
  %0 = nn.conv2d(%p0, %p1, strides=[2, 2], padding=[1, 1], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %1 = expand_dims(%p2, axis=1, num_newaxis=2) /* ty=Tensor[(64, 1, 1), float32] */;
  %2 = multiply(%0, %1) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %3 = negative(%p3) /* ty=Tensor[(64), float32] */;
  %4 = multiply(%3, %p2) /* ty=Tensor[(64), float32] */;
  %5 = add(%4, %p4) /* ty=Tensor[(64), float32] */;
  %6 = expand_dims(%5, axis=1, num_newaxis=2) /* ty=Tensor[(64, 1, 1), float32] */;
  %7 = add(%2, %6) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  nn.prelu(%7, %p5) /* ty=Tensor[(1, 64, 56, 56), float32] */
}

Have you tried to build the model without either kernel or graph tuning?
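
That is, a plain compile with no autotvm records applied, something like:

# sketch: build the model directly, without apply_graph_best or apply_history_best
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build_module.build(mod, target=target, params=params)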

Yes, I tried. The output remains:

TVMError: Check failed: pb->value != 0 (0 vs. 0) : Divide by zero
During handling of the above exception, another exception occurred:

TVMError: Check failed: pb->value != 0 (0 vs. 0) : Divide by zero
Error during compile function
-----------------------------
v0.0.4
fn (%p0: Tensor[(1, 3, 112, 112), float32], %p1: Tensor[(64, 3, 3, 3), float32], %p2: Tensor[(64), float32], %p3: Tensor[(64), float32], %p4: Tensor[(64), float32], %p5: Tensor[(64), float32], Primitive=1) -> Tensor[(1, 64, 56, 56), float32] {
  %0 = nn.conv2d(%p0, %p1, strides=[2, 2], padding=[1, 1], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %1 = expand_dims(%p2, axis=1, num_newaxis=2) /* ty=Tensor[(64, 1, 1), float32] */;
  %2 = multiply(%0, %1) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %3 = negative(%p3) /* ty=Tensor[(64), float32] */;
  %4 = multiply(%3, %p2) /* ty=Tensor[(64), float32] */;
  %5 = add(%4, %p4) /* ty=Tensor[(64), float32] */;
  %6 = expand_dims(%5, axis=1, num_newaxis=2) /* ty=Tensor[(64, 1, 1), float32] */;
  %7 = add(%2, %6) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  nn.prelu(%7, %p5) /* ty=Tensor[(1, 64, 56, 56), float32] */
}

Seems like a bug, as you said. Would you mind providing a minimal example that reproduces the problem so that we can investigate it further?
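
For instance, a self-contained Relay script along the following lines (a sketch of the two blocks from your graph; BatchNorm is omitted since it is folded away at opt_level=3, and I have not verified that this exact snippet triggers the error) would let everyone reproduce it without downloading the ONNX file:

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 3, 112, 112))
w1 = relay.var("w1", shape=(64, 3, 3, 3))
a1 = relay.var("a1", shape=(64,))
w2 = relay.var("w2", shape=(64, 1, 3, 3))  # depthwise weights, group=64
a2 = relay.var("a2", shape=(64,))

conv1 = relay.nn.conv2d(data, w1, strides=(2, 2), padding=(1, 1), kernel_size=(3, 3))
act1 = relay.nn.prelu(conv1, a1)
conv2 = relay.nn.conv2d(act1, w2, strides=(1, 1), padding=(1, 1),
                        kernel_size=(3, 3), groups=64)
act2 = relay.nn.prelu(conv2, a2)

func = relay.Function([data, w1, a1, w2, a2], act2)
mod = relay.Module.from_expr(func)
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build_module.build(mod, target="llvm -mcpu=core-avx2")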

Link: https://pan.baidu.com/s/1QUHpWmNwjc1No5kamBx_cA Extraction code: dmpl (Copy this content and open the Baidu Netdisk mobile app for easier handling.)

Please let me know if I did anything wrong. Thank you for your attention!

I don’t have a Baidu account to download it…

Which cloud drive service do you use?

You can put it on a public space such as https://gist.github.com/ so that everyone can access.

GitHub