[Graph fuse] x86 compilation fails with some schedules

With some schedule combinations, x86 CNN model compilation fails with:

Traceback (most recent call last):
  File "test_tvm_e2e.py", line 80, in <module>
    tm= end2end_benchmark(model, target, batch_size)
  File "test_tvm_e2e.py", line 47, in end2end_benchmark
    graph, lib, params = nnvm.compiler.build(net, target=target, shape={"data": data_shape}, params=params)
  File "/home/ubuntu/tvm/nnvm/python/nnvm/compiler/build_module.py", line 305, in build
    graph = graph.apply("GraphCompile")
  File "/home/ubuntu/tvm/nnvm/python/nnvm/graph.py", line 234, in apply
    check_call(_LIB.NNGraphApplyPasses(self.handle, npass, cpass, ctypes.byref(ghandle)))
  File "/home/ubuntu/tvm/nnvm/python/nnvm/_base.py", line 75, in check_call
    raise NNVMError(py_str(_LIB.NNGetLastError()))
nnvm._base.NNVMError: [20:40:19] include/nnvm/op.h:530: Check failed: op != nullptr 

@masahi mentioned this might be related to graph fuse and layout transformation. I changed “if tag.is_broadcast(op.tag)” to “if tag.is_injective(op.tag)” in the conv2d_NCHWc schedule (see the sketch after the traceback below), but the same error still pops up. I also tried rolling graph_fuse.cc back to https://github.com/dmlc/tvm/pull/1548, but got another error:

Traceback (most recent call last):
  File "test_tvm_e2e.py", line 80, in <module>
    tm= end2end_benchmark(model, target, batch_size)
  File "test_tvm_e2e.py", line 47, in end2end_benchmark
    graph, lib, params = nnvm.compiler.build(net, target=target, shape={"data": data_shape}, params=params)
  File "/home/ubuntu/tvm/nnvm/python/nnvm/compiler/build_module.py", line 288, in build
    graph, params = precompute_prune(graph, params)
  File "/home/ubuntu/tvm/nnvm/python/nnvm/compiler/build_module.py", line 407, in precompute_prune
    out_arrs = _run_graph(pre_graph, params)
  File "/home/ubuntu/tvm/nnvm/python/nnvm/compiler/build_module.py", line 356, in _run_graph
    graph, libmod, _ = build(graph, target, shape, dtype)
  File "/home/ubuntu/tvm/nnvm/python/nnvm/compiler/build_module.py", line 305, in build
    graph = graph.apply("GraphCompile")
  File "/home/ubuntu/tvm/nnvm/python/nnvm/graph.py", line 234, in apply
    check_call(_LIB.NNGraphApplyPasses(self.handle, npass, cpass, ctypes.byref(ghandle)))
  File "/home/ubuntu/tvm/nnvm/python/nnvm/_base.py", line 75, in check_call
    raise NNVMError(py_str(_LIB.NNGetLastError()))
nnvm._base.NNVMError: [20:39:37] src/core/pass.cc:30: Check failed: reg != nullptr Cannot find pass GraphCompile in the registry

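For reference, here is roughly what I changed in the conv2d_NCHWc schedule. This is a paraphrased sketch of the traversal callback in topi/x86 (names and surrounding details may differ from the actual code); only the condition was changed:

def traverse(op):
    # `s` is the Schedule and `tag` is topi.tag in the enclosing schedule function.
    # originally: if tag.is_broadcast(op.tag):
    if tag.is_injective(op.tag):
        # Inline elementwise/injective ops into their consumers, except graph outputs.
        if op not in s.outputs:
            s[op].compute_inline()
        for tensor in op.input_tensors:
            if isinstance(tensor.op, tvm.tensor.ComputeOp):
                traverse(tensor.op)
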
Any ideas on this issue?

The last error should be fixed by applying this commit: https://github.com/dmlc/tvm/pull/1564

I am probably responsible for the first error. I thought it was due to autotvm being involved, but your error log suggests it also occurs outside of autotvm.

Can you make a standalone script to reproduce this? Or at least tell me which models are giving this error?

Use this branch: https://github.com/kevinthesun/tvm/tree/AutoTVMx86
Then create a file named “resnet_best.log” containing the pre-tuned schedules for resnet50, and copy the following schedules into it (an optional sanity check for the log file follows the entries):

{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 3, 224, 224], "float32"], ["TENSOR", [64, 3, 7, 7], "float32"], 3, [7, 7], [2, 2], [3, 3], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 3, 224, 224, "float32"], [64, 3, 7, 7, "float32"], [2, 2], [3, 3], "NCHW", "NCHW", "float32"], {"i": 66, "c": null, "e": [["tile_ic", "sp", [3, 1]], ["tile_oc", "sp", [2, 32]], ["tile_ow", "sp", [14, 8]], ["unroll_kw", "ot", true]], "t": ""}], "r": [[9.992109883790107e-05], 0, 2.8028311729431152, 1537216965.302669], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 64, 56, 56], "float32"], ["TENSOR", [64, 64, 1, 1], "float32"], 64, [1, 1], [1, 1], [0, 0], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 64, 56, 56, "float32"], [64, 64, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "NCHW", "float32"], {"i": 183, "c": null, "e": [["tile_ic", "sp", [32, 2]], ["tile_oc", "sp", [2, 32]], ["tile_ow", "sp", [8, 7]], ["tile_oh", "ot", 1]], "t": ""}], "r": [[1.3869206422909004e-05], 0, 2.7181451320648193, 1537219666.662246], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 64, 56, 56], "float32"], ["TENSOR", [64, 64, 3, 3], "float32"], 64, [3, 3], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 64, 56, 56, "float32"], [64, 64, 3, 3, "float32"], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {"i": 234, "c": null, "e": [["tile_ic", "sp", [8, 8]], ["tile_oc", "sp", [2, 32]], ["tile_ow", "sp", [7, 8]], ["unroll_kw", "ot", true]], "t": ""}], "r": [[0.00010850694040478143], 0, 2.7787768840789795, 1537225952.185851], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 64, 56, 56], "float32"], ["TENSOR", [256, 64, 1, 1], "float32"], 64, [1, 1], [1, 1], [0, 0], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 64, 56, 56, "float32"], [256, 64, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "NCHW", "float32"], {"i": 226, "c": null, "e": [["tile_ic", "sp", [16, 4]], ["tile_oc", "sp", [8, 32]], ["tile_ow", "sp", [8, 7]], ["tile_oh", "ot", 1]], "t": ""}], "r": [[4.323142436866296e-05], 0, 2.7467689514160156, 1537232053.491833], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 256, 56, 56], "float32"], ["TENSOR", [64, 256, 1, 1], "float32"], 256, [1, 1], [1, 1], [0, 0], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 256, 56, 56, "float32"], [64, 256, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "NCHW", "float32"], {"i": 299, "c": null, "e": [["tile_ic", "sp", [64, 4]], ["tile_oc", "sp", [2, 32]], ["tile_ow", "sp", [7, 8]], ["tile_oh", "ot", 1]], "t": ""}], "r": [[4.733017510668026e-05], 0, 2.7558913230895996, 1537240898.426819], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 256, 56, 56], "float32"], ["TENSOR", [128, 256, 1, 1], "float32"], 256, [1, 1], [2, 2], [0, 0], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 256, 56, 56, "float32"], [128, 256, 1, 1, "float32"], [2, 2], [0, 0], "NCHW", "NCHW", "float32"], {"i": 201, "c": null, "e": [["tile_ic", "sp", [32, 8]], ["tile_oc", "sp", [2, 64]], ["tile_ow", "sp", [7, 4]], ["tile_oh", "ot", 1]], "t": ""}], "r": [[2.623872126210638e-05], 0, 2.7740578651428223, 1537248249.332708], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 128, 28, 28], "float32"], ["TENSOR", [128, 128, 3, 3], "float32"], 128, [3, 3], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 128, 28, 28, "float32"], [128, 128, 3, 3, "float32"], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {"i": 239, "c": null, "e": [["tile_ic", "sp", [1, 128]], ["tile_oc", "sp", [4, 32]], ["tile_ow", "sp", [4, 7]], ["unroll_kw", "ot", true]], "t": ""}], "r": [[0.00010607904158415841], 0, 2.796386957168579, 1537255370.085277], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 256, 56, 56], "float32"], ["TENSOR", [512, 256, 1, 1], "float32"], 256, [1, 1], [2, 2], [0, 0], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 256, 56, 56, "float32"], [512, 256, 1, 1, "float32"], [2, 2], [0, 0], "NCHW", "NCHW", "float32"], {"i": 242, "c": null, "e": [["tile_ic", "sp", [1, 256]], ["tile_oc", "sp", [8, 64]], ["tile_ow", "sp", [7, 4]], ["tile_oh", "ot", 1]], "t": ""}], "r": [[9.01399304818586e-05], 0, 2.761781930923462, 1537261334.850996], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 128, 28, 28], "float32"], ["TENSOR", [512, 128, 1, 1], "float32"], 128, [1, 1], [1, 1], [0, 0], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 128, 28, 28, "float32"], [512, 128, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "NCHW", "float32"], {"i": 282, "c": null, "e": [["tile_ic", "sp", [32, 4]], ["tile_oc", "sp", [16, 32]], ["tile_ow", "sp", [4, 7]], ["tile_oh", "ot", 1]], "t": ""}], "r": [[4.2513385772775525e-05], 0, 2.768828868865967, 1537271043.926714], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 512, 28, 28], "float32"], ["TENSOR", [128, 512, 1, 1], "float32"], 512, [1, 1], [1, 1], [0, 0], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 512, 28, 28, "float32"], [128, 512, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "NCHW", "float32"], {"i": 229, "c": null, "e": [["tile_ic", "sp", [1, 512]], ["tile_oc", "sp", [2, 64]], ["tile_ow", "sp", [7, 4]], ["tile_oh", "ot", 1]], "t": ""}], "r": [[4.85694837788484e-05], 0, 2.734384298324585, 1537278536.618539], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 512, 28, 28], "float32"], ["TENSOR", [256, 512, 1, 1], "float32"], 512, [1, 1], [2, 2], [0, 0], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 512, 28, 28, "float32"], [256, 512, 1, 1, "float32"], [2, 2], [0, 0], "NCHW", "NCHW", "float32"], {"i": 518, "c": null, "e": [["tile_ic", "sp", [2, 256]], ["tile_oc", "sp", [4, 64]], ["tile_ow", "sp", [7, 2]], ["tile_oh", "ot", 2]], "t": ""}], "r": [[2.5853111848282553e-05], 0, 2.7517669200897217, 1537288225.718235], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 256, 14, 14], "float32"], ["TENSOR", [256, 256, 3, 3], "float32"], 256, [3, 3], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 256, 14, 14, "float32"], [256, 256, 3, 3, "float32"], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {"i": 211, "c": null, "e": [["tile_ic", "sp", [16, 16]], ["tile_oc", "sp", [8, 32]], ["tile_ow", "sp", [2, 7]], ["unroll_kw", "ot", true]], "t": ""}], "r": [[0.00010395260962855365], 0, 2.7702410221099854, 1537291380.227396], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 512, 28, 28], "float32"], ["TENSOR", [1024, 512, 1, 1], "float32"], 512, [1, 1], [2, 2], [0, 0], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 512, 28, 28, "float32"], [1024, 512, 1, 1, "float32"], [2, 2], [0, 0], "NCHW", "NCHW", "float32"], {"i": 619, "c": null, "e": [["tile_ic", "sp", [1, 512]], ["tile_oc", "sp", [16, 64]], ["tile_ow", "sp", [7, 2]], ["tile_oh", "ot", 2]], "t": ""}], "r": [[9.182530663044275e-05], 0, 2.7436482906341553, 1537299864.213318], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 256, 14, 14], "float32"], ["TENSOR", [1024, 256, 1, 1], "float32"], 256, [1, 1], [1, 1], [0, 0], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 256, 14, 14, "float32"], [1024, 256, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "NCHW", "float32"], {"i": 552, "c": null, "e": [["tile_ic", "sp", [32, 8]], ["tile_oc", "sp", [16, 64]], ["tile_ow", "sp", [7, 2]], ["tile_oh", "ot", 2]], "t": ""}], "r": [[4.598596980410023e-05], 0, 2.778485059738159, 1537306466.367181], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 1024, 14, 14], "float32"], ["TENSOR", [256, 1024, 1, 1], "float32"], 1024, [1, 1], [1, 1], [0, 0], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 1024, 14, 14, "float32"], [256, 1024, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "NCHW", "float32"], {"i": 571, "c": null, "e": [["tile_ic", "sp", [1, 1024]], ["tile_oc", "sp", [4, 64]], ["tile_ow", "sp", [7, 2]], ["tile_oh", "ot", 2]], "t": ""}], "r": [[4.798035886826427e-05], 0, 2.7370760440826416, 1537312795.076411], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 1024, 14, 14], "float32"], ["TENSOR", [512, 1024, 1, 1], "float32"], 1024, [1, 1], [2, 2], [0, 0], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 1024, 14, 14, "float32"], [512, 1024, 1, 1, "float32"], [2, 2], [0, 0], "NCHW", "NCHW", "float32"], {"i": 173, "c": null, "e": [["tile_ic", "sp", [4, 256]], ["tile_oc", "sp", [16, 32]], ["tile_ow", "sp", [1, 7]], ["tile_oh", "ot", 1]], "t": ""}], "r": [[2.58854292077731e-05], 0, 2.73170804977417, 1537315776.299752], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 512, 7, 7], "float32"], ["TENSOR", [512, 512, 3, 3], "float32"], 512, [3, 3], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 512, 7, 7, "float32"], [512, 512, 3, 3, "float32"], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {"i": 154, "c": null, "e": [["tile_ic", "sp", [32, 16]], ["tile_oc", "sp", [16, 32]], ["tile_ow", "sp", [1, 7]], ["unroll_kw", "ot", true]], "t": ""}], "r": [[0.00010541035865884879], 0, 2.8359010219573975, 1537319029.493087], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 1024, 14, 14], "float32"], ["TENSOR", [2048, 1024, 1, 1], "float32"], 1024, [1, 1], [2, 2], [0, 0], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 1024, 14, 14, "float32"], [2048, 1024, 1, 1, "float32"], [2, 2], [0, 0], "NCHW", "NCHW", "float32"], {"i": 197, "c": null, "e": [["tile_ic", "sp", [1, 1024]], ["tile_oc", "sp", [64, 32]], ["tile_ow", "sp", [1, 7]], ["tile_oh", "ot", 1]], "t": ""}], "r": [[9.394007035269554e-05], 0, 2.7450110912323, 1537322535.663796], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 512, 7, 7], "float32"], ["TENSOR", [2048, 512, 1, 1], "float32"], 512, [1, 1], [1, 1], [0, 0], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 512, 7, 7, "float32"], [2048, 512, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "NCHW", "float32"], {"i": 172, "c": null, "e": [["tile_ic", "sp", [128, 4]], ["tile_oc", "sp", [64, 32]], ["tile_ow", "sp", [1, 7]], ["tile_oh", "ot", 1]], "t": ""}], "r": [[4.154470212064789e-05], 0, 2.755155086517334, 1537326796.077581], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 2048, 7, 7], "float32"], ["TENSOR", [512, 2048, 1, 1], "float32"], 2048, [1, 1], [1, 1], [0, 0], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 2048, 7, 7, "float32"], [512, 2048, 1, 1, "float32"], [1, 1], [0, 0], "NCHW", "NCHW", "float32"], {"i": 182, "c": null, "e": [["tile_ic", "sp", [512, 4]], ["tile_oc", "sp", [16, 32]], ["tile_ow", "sp", [1, 7]], ["tile_oh", "ot", 1]], "t": ""}], "r": [[4.545244831071326e-05], 0, 2.73813796043396, 1537330727.336169], "v": 0.1}

Run this script:

import nnvm
import tvm
import mxnet as mx
import numpy as np
import time
import argparse

from tvm.contrib import graph_runtime
from mxnet.gluon.model_zoo.vision import get_model

parser = argparse.ArgumentParser(description='End-to-end model benchmark.')
parser.add_argument('--model', type=str, required=True,
                    help="Pretrained model from gluon model zoo.")
parser.add_argument('--opt', type=int, required=True,
                    help="Opt level")

run_times = 1000

def end2end_benchmark(model, target, batch_size):
    print("Testing %s" % (model))
    num_classes = 1000
    image_shape = (3, 299, 299) if "inception" in model else (3, 224, 224)
    data_shape = (batch_size,) + image_shape
    out_shape = (batch_size, num_classes)

    block = get_model(model, pretrained=True)
    net, params = nnvm.frontend.from_mxnet(block)

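    # Point AutoTVM's dispatch context at the pre-tuned schedules so build() uses them.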
    tvm.autotvm.task.DispatchContext.current = tvm.autotvm.apply_history_best("resnet_best.log")
    ctx = tvm.cpu()
    opt_level = args.opt
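    # Compile the model; with these schedules, the GraphCompile pass fails here.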
    with nnvm.compiler.build_config(opt_level=opt_level):
        graph, lib, params = nnvm.compiler.build(net, target=target, shape={"data": data_shape}, params=params)

    module = graph_runtime.create(graph, lib, ctx)
    module.set_input(**params)

    data_array = np.random.uniform(0, 255, size=data_shape).astype("float32")
    input_data = tvm.nd.array(data_array, ctx=ctx)
    mx_data = mx.nd.array(data_array)
    module.set_input('data', input_data)

    # Warmup
    for _ in range(100):
        module.run()

    s = time.time()
    for _ in range(run_times):
        module.run()
    tvm_time = time.time() - s
    print("TVM %s inference time for batch size of %d: %f" % (model, batch_size, tvm_time * 1000/run_times))
    tvm_out = module.get_output(0, out=tvm.nd.empty(out_shape))
    mx_out = block(mx_data)
    np.testing.assert_array_almost_equal(tvm_out.asnumpy(), mx_out.asnumpy(), decimal=3)
    return tvm_time/run_times * 1000


if __name__ == "__main__":
    args = parser.parse_args()
    model = args.model
    batch_size = 1
    target = "llvm -mcpu=skylake-avx512"
    tm = end2end_benchmark(model, target, batch_size)

with this command:

python test_tvm_e2e.py --model resnet50_v1 --opt 3

This was tested on an AWS C5 instance.

Does your branch contain graph-level autotuning, where layout transforms happen in many places?

Graph-level tuning is not involved here, since the schedules have already been tuned; I just load them and compile. Layout transformation ops will still appear.
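
If it helps, the inserted layout transform ops should show up when printing the graph IR after build, e.g.:

# After nnvm.compiler.build(...), look for __layout_transform__ nodes in the IR dump.
print(graph.ir())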

Thanks, I can repro your error. I’ll have a look.

One quick way to fix this is to change the layout_transform op to a broadcast op here.

You can replace the existing registration with:

reg.register_schedule("__layout_transform__", _fschedule_broadcast)
reg.register_pattern("__layout_transform__", OpPattern.BROADCAST)
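
For context, this swaps out the existing injective registration, which roughly looks like the following (from nnvm/python/nnvm/top/transform.py on master; the exact location on your branch may differ):

reg.register_schedule("__layout_transform__", _fschedule_injective)
reg.register_pattern("__layout_transform__", OpPattern.INJECTIVE)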

@kevinthesun can you try this patch?

I figured out the root cause of the error and the patch above should fix it.

Thank you for the quick response! I’ll try it.