How to deploy two different TVM-compiled models in C++ statically?

Hi @srkreddy1238, @tqchen, is it possible to deploy two different TVM-compiled models, and generate C++ APIs for them, in the same application? For example, I have two TVM-compiled models, one for face detection and one for object detection. Can I run inference on both at the same time?
Note: the deployment must be static, not dynamic.

Check out

You will likely want to use the system module. What you want is possible, but it takes a bit of effort. Note that TVM's graph runtime takes two inputs:

  • graph_json: the graph JSON file
  • lib: the TVM module containing all the functions needed by the graph.

To deploy two models together, we need to combine the generated code into a single module that contains the functions needed by both models. Then we can create two graph runtimes, one for each model.
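To make the "one lib, two graphs" idea concrete, here is a toy Python sketch that does not use TVM at all. It assumes the graph JSON layout that TVM's graph runtime reads, where inputs are nodes with op "null" and each kernel is an op "tvm_op" node naming its function in attrs["func_name"]. A single dict stands in for the combined module, and both graphs resolve their kernels by name against it. The kernel bodies are placeholder integer arithmetic and the graphs are simple chains (real graphs are general DAGs), so this only illustrates the name lookup, not execution.

```python
import json

# Placeholder "compiled functions" standing in for one combined module.
# In a real deployment these would be fused kernels in a system lib.
SHARED_LIB = {
    "fused_nn_conv2d_add_nn_relu": lambda x: max(2 * x + 1, 0),
    "fused_nn_softmax": lambda x: x * 10,
    "fused_nn_dense_add": lambda x: x - 3,
}


def make_graph(func_names):
    """Build a chain-shaped graph JSON: one "null" input node, then one
    "tvm_op" node per kernel, each consuming the previous node's output 0."""
    nodes = [{"op": "null", "name": "data", "inputs": []}]
    for i, name in enumerate(func_names):
        nodes.append({
            "op": "tvm_op",
            "name": f"node_{i}",
            "attrs": {"func_name": name},
            "inputs": [[i, 0, 0]],
        })
    return json.dumps({"nodes": nodes})


def required_functions(graph_json):
    """Return the func_name of each tvm_op node, in node order."""
    names = []
    for node in json.loads(graph_json)["nodes"]:
        if node["op"] != "tvm_op":
            continue  # "null" nodes are inputs and parameters
        name = node["attrs"]["func_name"]
        if name != "__nop":  # TVM's graph runtime treats __nop as a no-op
            names.append(name)
    return names


def missing_functions(graph_json, lib):
    """Names a graph needs that the shared lib does not provide."""
    return [n for n in required_functions(graph_json) if n not in lib]


def run_chain(graph_json, lib, value):
    """Toy executor: look up each kernel by name in lib and apply it in order."""
    missing = missing_functions(graph_json, lib)
    if missing:
        raise KeyError(f"not in shared lib: {missing}")
    for name in required_functions(graph_json):
        value = lib[name](value)
    return value


# Two "models" whose graphs both reference the same shared lib.
FACE_GRAPH = make_graph(["fused_nn_conv2d_add_nn_relu", "fused_nn_softmax"])
OBJECT_GRAPH = make_graph(["fused_nn_conv2d_add_nn_relu", "fused_nn_dense_add"])

print(run_chain(FACE_GRAPH, SHARED_LIB, 4))    # 4 -> 9 -> 90
print(run_chain(OBJECT_GRAPH, SHARED_LIB, 4))  # 4 -> 9 -> 6
```

The design point is that the graphs carry only names; as long as the one module contains every name either graph asks for, two runtimes can share it.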

@tqchen Thank you for your suggestion. I am able to deploy a single TVM-compiled model both dynamically and statically, and the links shared above explain how to deploy one model in C++. Do you have any further suggestions or samples for deploying two models in one C++ program?

My comment above is for two models. You need to generate a module that contains the functions used by both models (in plain C code this could be as simple as linking everything together), plus the two graph JSON files.

Hi @tqchen, are there any samples that show how to generate a module containing the functions used by two different models?

@tqchen The symbol names of the model in the .o file are tvm_runtime_create, tvm_runtime_run…, so if there are two models, a.o and b.o, linking a.o, b.o, runtime.o and main.o into the final executable produces duplicate symbol names and a link error.

Use case: using mobilenetv1 and mobilenetv2 in the same app. How do we tell them apart? By their JSON and params files? There should be some way to identify them at the .o or .so level.
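One way to spot the clash before linking is to compare the global symbols each object defines. The sketch below is a hedged helper, not part of TVM: it parses the text printed by nm for each object and reports names defined by more than one of them. It assumes nm's usual three-column layout for defined symbols and treats only the uppercase T, D, B, R and S types as strong definitions. The sample nm text is made up, reusing the symbol names from the post above.

```python
from collections import defaultdict

# nm symbol types treated as strong global definitions. Lowercase types are
# local, U is undefined, and W/V (weak) or C (common) symbols usually do not
# trigger duplicate-definition errors, so they are ignored here.
STRONG_GLOBAL_TYPES = {"T", "D", "B", "R", "S"}


def defined_globals(nm_output):
    """Strong global symbols defined in one object's nm output."""
    names = set()
    for line in nm_output.splitlines():
        parts = line.split()
        if len(parts) != 3:
            continue  # undefined symbols ("U name") have no address column
        _address, sym_type, name = parts
        if sym_type in STRONG_GLOBAL_TYPES:
            names.add(name)
    return names


def duplicate_symbols(nm_outputs):
    """Map each symbol defined by two or more objects to those objects."""
    owners = defaultdict(list)
    for obj_name, text in nm_outputs.items():
        for symbol in defined_globals(text):
            owners[symbol].append(obj_name)
    return {sym: sorted(objs) for sym, objs in owners.items() if len(objs) > 1}


# Made-up nm output for two model objects produced by the same deploy flow.
A_O = """\
0000000000000000 T tvm_runtime_create
0000000000000040 T tvm_runtime_run
0000000000000080 t model_a_helper
                 U TVMBackendAllocWorkspace
"""
B_O = """\
0000000000000000 T tvm_runtime_create
0000000000000040 T tvm_runtime_run
0000000000000080 t model_b_helper
                 U TVMBackendAllocWorkspace
"""

for symbol, objs in sorted(duplicate_symbols({"a.o": A_O, "b.o": B_O}).items()):
    print(symbol, "defined in", ", ".join(objs))
```

To use it on real objects, capture the text with something like subprocess.run(["nm", "a.o"], capture_output=True, text=True).stdout. On macOS, nm prints C symbols with a leading underscore (for example _tvm_runtime_create). This only reports the collision; it does not rename anything.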


tvm.__version__: '0.8.dev0'

Compile and test steps:

  1. Prepare the Relay models.

  2. Disable TVM's compile-engine cache clearing.

  3. Load the Relay models and collect all of their outputs.

  4. Create a new IRModule from those outputs: the Big model.

  5. relay.build the Big model. At this point the compile engine has cached the functions for all models. Save only the lib (which contains the functions).

  6. relay.build each model individually to get its graph and params.

  7. Load the Big model's lib into the TVM runtime.

  8. Test each model using its own graph and params.
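Steps 5 and 6 only work if every function name a per-model build refers to also exists in the Big model's lib. Assuming a build log shaped like the one posted later in this thread (a "compile model <name>" header followed by "CreateCall: <function>" lines, which look like extra debug logging added to codegen_cpu.cc rather than stock TVM output), a small Python check could group the names per build and report any that the combined build lacks. It compares names only; it cannot tell whether two functions with the same name have the same body.

```python
import re

HEADER = re.compile(r"^compile model (\S+)$")
CALL = re.compile(r"CreateCall: (\S+)$")


def functions_by_build(log_text):
    """Group CreateCall names under the most recent 'compile model <name>' header."""
    builds = {}
    current = None
    for raw_line in log_text.splitlines():
        line = raw_line.strip()
        header = HEADER.match(line)
        if header:
            current = header.group(1)
            builds.setdefault(current, set())
            continue
        call = CALL.search(line)
        if call and current is not None:
            builds[current].add(call.group(1))
    return builds


def missing_from_combined(builds, combined="collection"):
    """For each per-model build, the names the combined build never produced."""
    combined_names = builds[combined]
    report = {}
    for model, names in builds.items():
        missing = names - combined_names
        if model != combined and missing:
            report[model] = sorted(missing)
    return report


# Excerpt in the same format as the build log in this thread.
SAMPLE_LOG = """\
compile model collection
 codegen_cpu.cc:827: CreateCall: __tvm_module_ctx
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_11
 codegen_cpu.cc:827: CreateCall: fused_nn_softmax_1
export_library main lib
compile model fd
 codegen_cpu.cc:827: CreateCall: __tvm_module_ctx
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_11
compile model fv
 codegen_cpu.cc:827: CreateCall: __tvm_module_ctx
 codegen_cpu.cc:827: CreateCall: fused_nn_softmax_1
"""

print(missing_from_combined(functions_by_build(SAMPLE_LOG)))  # {}
```

An empty result means every per-model graph should find its functions in the saved Big model lib; any entry points at a model whose build produced names the combined lib does not have.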

In the TVM source, in RelayBuildModule::Build (src/relay/backend/build_module.cc:224), comment out the line CompileEngine::Global()->Clear(); so that the compile engine keeps its function cache between relay.build calls:

  /*!
   * \brief Build relay IRModule for graph runtime
   *
   * \param mod Relay IRModule
   * \param target Target device
   * \param target_host Host target device
   */
  void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) {
    targets_ = targets;
    target_host_ = target_host;
    BuildRelay(mod, params_);
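    // Clear compile engine so that tuning schedules can be changed between runs. See issue #6096.
    // Commented out so that functions lowered by an earlier relay.build call stay cached
    // and keep the same names in later builds.
    // CompileEngine::Global()->Clear();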
  }
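Why keeping the cache matters can be shown with a toy model. This is plain Python, not TVM's CompileEngine: kernels are keyed by a made-up signature string, and each new kernel gets a unique name (base, base_1, base_2 and so on), the way the names in the build log appear to be numbered. With the cache kept, the per-model build (step 6) reuses the names from the Big model build (step 5); with the cache cleared, the per-model build gets fresh names that the saved lib does not contain. The toy keeps its name counter across clears; that is an assumption of the sketch, not a claim about TVM's internals.

```python
class ToyCompileEngine:
    """Toy compile cache that hands out unique names and reuses them on hits."""

    def __init__(self):
        self.cache = {}         # kernel signature -> assigned function name
        self.name_counter = {}  # base name -> how many times it has been used

    def lower(self, signature, base_name):
        if signature in self.cache:
            return self.cache[signature]  # cache hit: reuse the earlier name
        n = self.name_counter.get(base_name, 0)
        self.name_counter[base_name] = n + 1
        name = base_name if n == 0 else f"{base_name}_{n}"
        self.cache[signature] = name
        return name

    def clear(self):
        # Plays the role of CompileEngine::Global()->Clear(): drop cached
        # kernels. The name counter is kept, so re-lowering yields new names.
        self.cache.clear()


def build(engine, kernels, clear_cache):
    """'Compile' a model: lower each (signature, base_name) pair."""
    names = {engine.lower(sig, base) for sig, base in kernels}
    if clear_cache:
        engine.clear()
    return names


# Hypothetical kernel signatures for two small models.
FD = [("conv2d[1x3x384x224]", "fused_nn_conv2d"), ("softmax[fd]", "fused_nn_softmax")]
FV = [("conv2d[1x3x48x48]", "fused_nn_conv2d"), ("softmax[fv]", "fused_nn_softmax")]


def per_model_names_in_combined(clear_cache):
    engine = ToyCompileEngine()
    combined = build(engine, FD + FV, clear_cache)  # step 5: the Big model
    fv_only = build(engine, FV, clear_cache)        # step 6: one model alone
    return fv_only <= combined


print(per_model_names_in_combined(clear_cache=False))  # True
print(per_model_names_in_combined(clear_cache=True))   # False
```

With the cache cleared, the second build asks for fused_nn_conv2d_2 and fused_nn_softmax_2, which the combined lib never defined.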

import os
import platform

import numpy as np
import tvm
import tvm.relay.testing
from tvm import ir
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay import analysis
from tvm.relay import expr as _expr
from tvm.relay import function as _function

work_root = '/Users/rqg/PycharmProjects/TvmProject'

infos = {
    'fd': {
        'network': 'fd_landscape',
        'input_shape': (1, 3, 384, 224),
        'input_name': 'data',
        'output_name': ['mbox_conf_softmax'],
        'data_blobs_path': 'data/fd_landscape_blobs.npz',
        'model_json_file': 'models/fd_landscape.tar',
        'image_dir': '/root/PG_LANDSCAPE',
        'mean': 127.5,
        'std': 1.0,
    },
    'fv': {
        'network': 'fv',
        'input_shape': (1, 3, 48, 48),
        'input_name': 'data',
        'output_name': ['cls_softmax', 'bbox_pred'],
        'data_blobs_path': 'data/fv_blobs.npz',
        'model_json_file': 'models/fv.tar',
        'image_dir': '/root/PG/',
        'mean': 127.5,
        'std': 127.5,
    },
    'landmark': {
        'network': 'landmark_nobn',
        'input_shape': (1, 3, 112, 112),
        'input_name': 'data1',
        'output_name': ['fc1'],
        'data_blobs_path': 'data/landmark_nobn_blobs.npz',
        'model_json_file': 'models/landmark_nobn.tar',
        'image_dir': '/root/LANDMARK/',
        'mean': 127.5,
        'std': 127.5,
    },
    'pose': {
        'network': 'facepose_forshake',
        'input_shape': (1, 3, 112, 112),
        'input_name': 'data1',
        'output_name': ['roll_prob', 'yaw_prob', 'pitch_prob'],
        'data_blobs_path': 'data/facepose_forshake_blobs.npz',
        'model_json_file': 'models/facepose_forshake.tar',
        'image_dir': '/root/LANDMARK/',
        'mean': 127.5,
        'std': 127.5,
    },
    'spoof': {
        'network': 'face_spoof',
        'input_shape': (1, 3, 32, 32),
        'input_name': 'data1',
        'output_name': ['softmax1'],
        'data_blobs_path': 'data/face_spoof_blobs.npz',
        'model_json_file': 'models/face_spoof.tar',
        'image_dir': '/root/LANDMARK/',
        'mean': 127.5,
        'std': 127.5,
    },
}

target = 'llvm --system-lib'


def load_model(json_path):
    with open(os.path.join(work_root, json_path)) as json_file:
        return ir.load_json(json_file.read())


def print_result(data, module, num_output, module_output):
    for i in range(num_output):
        ret_out = module.get_output(i).asnumpy()

        diss = []
        for k in module_output:
            if data[k].shape == ret_out.shape:
                diss.append(np.linalg.norm(ret_out - data[k]))

        if diss:
            print(min(diss))
        else:
            print(ret_out)


def test_model(mg, mp, fk, sysLib, ctx):
    print(fk, '----' * 30)
    info = infos[fk]
    module = graph_runtime.create(mg, sysLib, ctx)
    module.load_params(relay.save_param_dict(mp))
    data = np.load(os.path.join(work_root, info['data_blobs_path']))
    module.set_input(info['input_name'], tvm.nd.array(data[info['input_name']]))
    module.run()
    num_output = module.get_num_outputs()
    print_result(data, module, num_output, info['output_name'])


if __name__ == '__main__':
    models = {}
    funcs = {}

    outputs = []
    for fk in infos:
        m = load_model(infos[fk]['model_json_file'])
        out = m['main'].body
        # func = _function.Function(analysis.free_vars(out), out)
        # models[fk] = ir.IRModule.from_expr(func)
        models[fk] = m
        if isinstance(out, tvm.relay.expr.Tuple):
            for i in out:
                outputs.append(i)
        else:
            outputs.append(out)

    outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
    func = _function.Function(analysis.free_vars(outputs), outputs)
    mod = ir.IRModule.from_expr(func)

    # for fk in infos:
    #     mod[fk] = models[fk]['main'].body

    print("compile model collection")
    # AlterOpLayout pass will cause different build result on second relay.build
    with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}):
        graph, lib, params = relay.build_module.build(mod, target=target)

    if platform.system() == 'Darwin':
        lib_name = 'main.dylib'
    elif platform.system() == 'Linux':
        lib_name = 'main.so'
    elif platform.system() == 'Windows':
        lib_name = 'main.dll'
    else:
        raise Exception('unknown system ' + platform.system())

    print("export_library main lib")
    lib.export_library(lib_name)


    model_bin = {}
    for fk in infos:
        print("compile model " + fk)
        # AlterOpLayout pass will cause different build result on second relay.build
        with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}):
            mg, mb, mp = relay.build_module.build(models[fk], target=target)
        model_bin[fk] = [mg, mp]

    print("load main lib")
    mod = tvm.runtime.load_module(lib_name)

    print("prepare env")
    # sysLib = tvm.get_global_func('runtime.SystemLib')()
    sysLib = mod
    ctx = tvm.cpu(0)

    print("start test model")

    for fk in model_bin:
        mg, mp = model_bin[fk]
        test_model(mg, mp, fk, sysLib, ctx)


Here is the build log:

compile model collection
 codegen_cpu.cc:827: CreateCall: __tvm_module_ctx
 codegen_cpu.cc:827: CreateCall: fused_nn_dense_add_sigmoid_3
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_41
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_11
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_42
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_58
 codegen_cpu.cc:827: CreateCall: fused_reshape_multiply_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_5
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_9
 codegen_cpu.cc:827: CreateCall: fused_nn_dilate_nn_pad_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_43
 codegen_cpu.cc:827: CreateCall: fused_reshape_5
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_3
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_6
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_19
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_22
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_20
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_69
 codegen_cpu.cc:827: CreateCall: fused_nn_dense_add_nn_relu_3
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_8
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_15
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_33
 codegen_cpu.cc:827: CreateCall: fused_vision_non_max_suppression
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_add_nn_relu_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_45
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_nn_relu_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_68
 codegen_cpu.cc:827: CreateCall: fused_nn_avg_pool2d_3
 codegen_cpu.cc:827: CreateCall: fused_nn_dense_add
 codegen_cpu.cc:827: CreateCall: fused_image_resize_add_image_resize_add_nn_relu
 codegen_cpu.cc:827: CreateCall: fused_nn_global_avg_pool2d
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_64
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_41
 codegen_cpu.cc:827: CreateCall: fused_nn_dense_add_nn_relu_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_16
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_25
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_4
 codegen_cpu.cc:827: CreateCall: fused_nn_avg_pool2d_2
 codegen_cpu.cc:827: CreateCall: fused_reshape
 codegen_cpu.cc:827: CreateCall: fused_nn_batch_flatten_reshape_transpose
 codegen_cpu.cc:827: CreateCall: fused_nn_dense_add_sigmoid_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_8
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_36
 codegen_cpu.cc:827: CreateCall: fused_nn_global_avg_pool2d_4
 codegen_cpu.cc:827: CreateCall: fused_image_resize
 codegen_cpu.cc:827: CreateCall: fused_nn_dense_add_nn_relu_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_21
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_38
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_70
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_38
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_30
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_3
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_59
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_40
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_49
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_61
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_72
 codegen_cpu.cc:827: CreateCall: fused_nn_avg_pool2d
 codegen_cpu.cc:827: CreateCall: fused_transpose_nn_batch_flatten_transpose_nn_batch_flatten_transpose_nn_batch_f_9232982656378421941_
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_71
 codegen_cpu.cc:827: CreateCall: fused_reshape_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_75
 codegen_cpu.cc:827: CreateCall: fused_concatenate_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_7
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_23
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_7
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_nn_relu_3
 codegen_cpu.cc:827: CreateCall: fused_concatenate_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_35
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_4
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_10
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_24
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_36
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_65
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_33
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_37
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_60
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_3
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_45
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_12
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_51
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_17
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_62
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_6
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_6
 codegen_cpu.cc:827: CreateCall: fused_nn_dilate_nn_pad_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_32
 codegen_cpu.cc:827: CreateCall: fused_nn_dilate_nn_pad_5
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_18
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_add_add_nn_relu_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_77
 codegen_cpu.cc:827: CreateCall: fused_nn_dense_add_nn_relu
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_5
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_48
 codegen_cpu.cc:827: CreateCall: fused_image_resize_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_31
 codegen_cpu.cc:827: CreateCall: fused_take
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_47
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_31
 codegen_cpu.cc:827: CreateCall: fused_nn_dense_add_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_4
 codegen_cpu.cc:827: CreateCall: fused_reshape_3
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_13
 codegen_cpu.cc:827: CreateCall: fused_nn_dense_add_sigmoid_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_add
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_79
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_44
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_26
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_81
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_55
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_11
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_14
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_14
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_51
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_12
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_13
 codegen_cpu.cc:827: CreateCall: fused_reshape_4
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_nn_relu
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_48
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_nn_relu_2
 codegen_cpu.cc:827: CreateCall: fused_concatenate_3
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_15
 codegen_cpu.cc:827: CreateCall: fused_nn_softmax
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_53
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_63
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_8
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_76
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_16
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_25
 codegen_cpu.cc:827: CreateCall: fused_nn_dilate_nn_pad
 codegen_cpu.cc:827: CreateCall: fused_vision_multibox_transform_loc
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_66
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_13
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_46
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_28
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_20
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_80
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_52
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_12
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_34
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_5
 codegen_cpu.cc:827: CreateCall: fused_nn_global_avg_pool2d_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_add_add_nn_relu_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_35
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_7
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_67
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_54
 codegen_cpu.cc:827: CreateCall: fused_reshape_multiply_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_18
 codegen_cpu.cc:827: CreateCall: fused_nn_dilate_nn_pad_3
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_40
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_10
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_24
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_add_add
 codegen_cpu.cc:827: CreateCall: fused_nn_dense_add_sigmoid
 codegen_cpu.cc:827: CreateCall: fused_nn_avg_pool2d_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_9
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_11
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_22
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_27
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_nn_relu_4
 codegen_cpu.cc:827: CreateCall: fused_nn_softmax_3
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_56
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_add_add_nn_relu
 codegen_cpu.cc:827: CreateCall: fused_reshape_squeeze_expand_dims_multiply
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_15
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_43
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_nn_relu_5
 codegen_cpu.cc:827: CreateCall: fused_nn_dilate_nn_pad_4
 codegen_cpu.cc:827: CreateCall: fused_nn_softmax_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_39
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_50
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_19
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_37
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_23
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_30
 codegen_cpu.cc:827: CreateCall: fused_image_resize_add_nn_relu
 codegen_cpu.cc:827: CreateCall: fused_nn_softmax_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_34
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_29
 codegen_cpu.cc:827: CreateCall: fused_concatenate
 codegen_cpu.cc:827: CreateCall: fused_nn_avg_pool2d_add_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_26
 codegen_cpu.cc:827: CreateCall: fused_reshape_multiply
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_39
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_74
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_46
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_27
 codegen_cpu.cc:827: CreateCall: fused_transpose_nn_batch_flatten_transpose_nn_batch_flatten_transpose_nn_batch_f_5308005036864095291_
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_57
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_49
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_29
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_17
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_32
 codegen_cpu.cc:827: CreateCall: fused_image_resize_2
 codegen_cpu.cc:827: CreateCall: fused_nn_global_avg_pool2d_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_78
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_82
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_28
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_42
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_add_nn_relu
 codegen_cpu.cc:827: CreateCall: fused_nn_global_avg_pool2d_3
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_73
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_21
 codegen_cpu.cc:827: CreateCall: fused_nn_avg_pool2d_add_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_10
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_14
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_17
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_16
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_50
 codegen_cpu.cc:827: CreateCall: fused_nn_avg_pool2d_add
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_44
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_9
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_47
 codegen_cpu.cc:827: CreateCall: fused_image_resize_add_image_resize_add_image_resize_add_nn_relu_image_resize_im_13337369837761315479_
 codegen_cpu.cc:827: CreateCall: fused_reshape_1
export_library main lib
/Users/rqg/PycharmProjects/CoreML/bundle_compile_model_lib.py:149: DeprecationWarning: legacy graph runtime behavior of producing json / lib / params will be removed in the next release. Please see documents of tvm.contrib.graph_runtime.GraphModule for the  new recommended usage.
  graph, lib, params = relay.build_module.build(mod, target=target)
compile model fd
 codegen_cpu.cc:827: CreateCall: __tvm_module_ctx
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_11
 codegen_cpu.cc:827: CreateCall: fused_transpose_nn_batch_flatten_transpose_nn_batch_flatten_transpose_nn_batch_f_9232982656378421941_
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_9
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_6
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_8
 codegen_cpu.cc:827: CreateCall: fused_vision_non_max_suppression
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_3
 codegen_cpu.cc:827: CreateCall: fused_nn_dilate_nn_pad_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_15
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_4
 codegen_cpu.cc:827: CreateCall: fused_nn_batch_flatten_reshape_transpose
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_3
 codegen_cpu.cc:827: CreateCall: fused_concatenate_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_7
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_4
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_3
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_6
 codegen_cpu.cc:827: CreateCall: fused_nn_dilate_nn_pad_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_5
 codegen_cpu.cc:827: CreateCall: fused_take
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_4
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_13
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_11
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_12
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_10
 codegen_cpu.cc:827: CreateCall: fused_concatenate_2
 codegen_cpu.cc:827: CreateCall: fused_nn_softmax
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_8
 codegen_cpu.cc:827: CreateCall: fused_vision_multibox_transform_loc
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_1
 codegen_cpu.cc:827: CreateCall: fused_nn_dilate_nn_pad
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_12
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_5
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_7
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_10
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_13
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_14
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_16
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_15
 codegen_cpu.cc:827: CreateCall: fused_concatenate
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_17
 codegen_cpu.cc:827: CreateCall: fused_transpose_nn_batch_flatten_transpose_nn_batch_flatten_transpose_nn_batch_f_5308005036864095291_
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_9
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_14
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_2
compile model fv
/Users/rqg/PycharmProjects/CoreML/bundle_compile_model_lib.py:167: DeprecationWarning: legacy graph runtime behavior of producing json / lib / params will be removed in the next release. Please see documents of tvm.contrib.graph_runtime.GraphModule for the  new recommended usage.
  mg, mb, mp = relay.build_module.build(models[fk], target=target)
 codegen_cpu.cc:827: CreateCall: __tvm_module_ctx
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_5
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_22
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_33
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_20
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_19
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_8
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_21
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_30
 codegen_cpu.cc:827: CreateCall: fused_nn_avg_pool2d
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_24
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_23
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_6
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_18
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_31
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_7
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_17
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_23
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_18
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_20
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_25
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_24
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_22
 codegen_cpu.cc:827: CreateCall: fused_nn_softmax_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_19
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_26
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_27
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_29
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_32
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_28
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_21
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_16
compile model landmark
 codegen_cpu.cc:827: CreateCall: __tvm_module_ctx
 codegen_cpu.cc:827: CreateCall: fused_nn_global_avg_pool2d
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_add_nn_relu_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_nn_relu_1
 codegen_cpu.cc:827: CreateCall: fused_nn_dense_add
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_25
 codegen_cpu.cc:827: CreateCall: fused_reshape
 codegen_cpu.cc:827: CreateCall: fused_image_resize_add_image_resize_add_nn_relu
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_36
 codegen_cpu.cc:827: CreateCall: fused_image_resize
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_38
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_41
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_nn_relu_3
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_33
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_32
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_add_add_nn_relu_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_45
 codegen_cpu.cc:827: CreateCall: fused_image_resize_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_47
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_31
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_26
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_nn_relu
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_nn_relu_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_28
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_35
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_add_add_nn_relu_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_34
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_40
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_46
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_27
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_nn_relu_4
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_nn_relu_5
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_39
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_37
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_30
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_29
 codegen_cpu.cc:827: CreateCall: fused_image_resize_add_nn_relu
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_43
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_add_add_nn_relu
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_34
 codegen_cpu.cc:827: CreateCall: fused_image_resize_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_42
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_add_nn_relu
 codegen_cpu.cc:827: CreateCall: fused_image_resize_add_image_resize_add_image_resize_add_nn_relu_image_resize_im_13337369837761315479_
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_44
compile model pose
 codegen_cpu.cc:827: CreateCall: __tvm_module_ctx
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_42
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_58
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_64
 codegen_cpu.cc:827: CreateCall: fused_nn_avg_pool2d_3
 codegen_cpu.cc:827: CreateCall: fused_nn_avg_pool2d_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_38
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_49
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_40
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_59
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_61
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_35
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_65
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_36
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_13
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_51
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_62
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_37
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_12
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_48
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_60
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_55
 codegen_cpu.cc:827: CreateCall: fused_concatenate_3
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_53
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_63
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_52
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_66
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_67
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_54
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_9
 codegen_cpu.cc:827: CreateCall: fused_nn_avg_pool2d_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_11
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_56
 codegen_cpu.cc:827: CreateCall: fused_nn_softmax_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_39
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_57
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_10
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_50
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_41
compile model spoof
 codegen_cpu.cc:827: CreateCall: __tvm_module_ctx
 codegen_cpu.cc:827: CreateCall: fused_nn_dense_add_sigmoid_3
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_43
 codegen_cpu.cc:827: CreateCall: fused_reshape_5
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_81
 codegen_cpu.cc:827: CreateCall: fused_nn_dense_add_nn_relu_3
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_68
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_45
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_16
 codegen_cpu.cc:827: CreateCall: fused_nn_dense_add_nn_relu_1
 codegen_cpu.cc:827: CreateCall: fused_reshape_multiply_1
 codegen_cpu.cc:827: CreateCall: fused_nn_dense_add_sigmoid_1
 codegen_cpu.cc:827: CreateCall: fused_nn_global_avg_pool2d_4
 codegen_cpu.cc:827: CreateCall: fused_nn_dense_add_nn_relu_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_70
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_72
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_69
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_71
 codegen_cpu.cc:827: CreateCall: fused_reshape_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_75
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d
 codegen_cpu.cc:827: CreateCall: fused_nn_dense_add_sigmoid_2
 codegen_cpu.cc:827: CreateCall: fused_reshape_3
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_79
 codegen_cpu.cc:827: CreateCall: fused_nn_dilate_nn_pad_5
 codegen_cpu.cc:827: CreateCall: fused_nn_dense_add_nn_relu
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_77
 codegen_cpu.cc:827: CreateCall: fused_nn_dense_add_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_44
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_14
 codegen_cpu.cc:827: CreateCall: fused_reshape_4
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_48
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_add
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_15
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_76
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_51
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_80
 codegen_cpu.cc:827: CreateCall: fused_nn_global_avg_pool2d_1
 codegen_cpu.cc:827: CreateCall: fused_reshape_multiply_2
 codegen_cpu.cc:827: CreateCall: fused_nn_dilate_nn_pad_3
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_add_add
 codegen_cpu.cc:827: CreateCall: fused_nn_dense_add_sigmoid
 codegen_cpu.cc:827: CreateCall: fused_nn_softmax_3
 codegen_cpu.cc:827: CreateCall: fused_nn_dilate_nn_pad_4
 codegen_cpu.cc:827: CreateCall: fused_reshape_squeeze_expand_dims_multiply
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_46
 codegen_cpu.cc:827: CreateCall: fused_reshape_multiply
 codegen_cpu.cc:827: CreateCall: fused_nn_avg_pool2d_add_1
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_74
 codegen_cpu.cc:827: CreateCall: fused_nn_avg_pool2d_add_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_47
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_49
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_50
 codegen_cpu.cc:827: CreateCall: fused_nn_global_avg_pool2d_2
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_82
 codegen_cpu.cc:827: CreateCall: fused_nn_global_avg_pool2d_3
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_73
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_nn_relu_78
 codegen_cpu.cc:827: CreateCall: fused_nn_conv2d_add_add_17
 codegen_cpu.cc:827: CreateCall: fused_nn_avg_pool2d_add
 codegen_cpu.cc:827: CreateCall: fused_reshape_1
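The `CreateCall` log above may explain why linking several models into one library worked here. The numeric suffixes on the fused kernels keep increasing from one model to the next instead of restarting. For example, `pose` emits `fused_nn_conv2d_add_nn_relu_48` through `_67`, and `spoof` continues from `_68`. This is presumably the effect of step 2 (not clearing the compile engine cache between builds). The result is that no two models define the same kernel name.

As a quick sanity check, a small script can scan such a log and report any kernel name emitted by more than one model. `find_cross_model_symbols` is a hypothetical helper, not part of TVM. It assumes the `compile model <name>` header and `CreateCall:` line formats shown in this log. It skips `__tvm_module_ctx`, which every model references, and it ignores calls logged before the first header.

```python
import re
from collections import defaultdict

# Per-model header printed by the test script, e.g. "compile model pose" (format assumed from this log).
MODEL_RE = re.compile(r"^compile model (\S+)")
# Codegen debug line, e.g. " codegen_cpu.cc:827: CreateCall: fused_nn_softmax_2".
CALL_RE = re.compile(r"CreateCall:\s*(\S+)")


def find_cross_model_symbols(log_text, ignore=("__tvm_module_ctx",)):
    """Hypothetical helper: return {symbol: sorted model names} for symbols emitted by more than one model."""
    owners = defaultdict(set)
    model = None  # calls seen before the first header are not attributed to any model
    for line in log_text.splitlines():
        header = MODEL_RE.match(line.strip())
        if header:
            model = header.group(1)
            continue
        call = CALL_RE.search(line)
        if call and model is not None and call.group(1) not in ignore:
            owners[call.group(1)].add(model)
    return {sym: sorted(models) for sym, models in owners.items() if len(models) > 1}


# Usage on a trimmed excerpt of the log above.
log = """compile model pose
 codegen_cpu.cc:827: CreateCall: __tvm_module_ctx
 codegen_cpu.cc:827: CreateCall: fused_nn_softmax_2
compile model spoof
 codegen_cpu.cc:827: CreateCall: __tvm_module_ctx
 codegen_cpu.cc:827: CreateCall: fused_nn_softmax_3
"""
print(find_cross_model_symbols(log))
```

An empty result means the names in the log are disjoint across models. It says nothing about other symbols that the log does not print.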
load main lib
prepare env
start test model
fd ------------------------------------------------------------------------------------------------------------------------
[[[ 0.          0.9999993   0.13859159  0.48088723  0.71894324  0.8581148 ]
  [-1.         -1.         -1.         -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.         -1.        ]
  [-1.         -1.         -1.         -1.         -1.         -1.        ]]]
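The `fd` output looks like a fixed-size detection tensor in which unused slots are padded with rows of `-1`. The post does not state the row layout. The sketch below assumes each row is `[class_id, score, x1, y1, x2, y2]` and treats rows with a negative class id as padding. `valid_detections` and its `score_threshold` are illustrative names, not a TVM API.

```python
def valid_detections(dets, score_threshold=0.5):
    """Keep rows that are not -1 padding and whose score passes the threshold.

    Assumes dets has shape [batch][num_boxes][6] with rows laid out as
    [class_id, score, x1, y1, x2, y2] (an assumption; the layout is not stated).
    """
    results = []
    for batch in dets:
        kept = [row for row in batch if row[0] >= 0 and row[1] >= score_threshold]
        results.append(kept)
    return results


# The fd output printed above, as nested lists.
fd_output = [[
    [0.0, 0.9999993, 0.13859159, 0.48088723, 0.71894324, 0.8581148],
    [-1.0] * 6,
    [-1.0] * 6,
    [-1.0] * 6,
    [-1.0] * 6,
]]

for class_id, score, x1, y1, x2, y2 in valid_detections(fd_output)[0]:
    print(f"class={int(class_id)} score={score:.4f} box=({x1:.3f}, {y1:.3f}, {x2:.3f}, {y2:.3f})")
```

Under this assumption, only the first row is a real detection, and the other four slots are padding.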
fv ------------------------------------------------------------------------------------------------------------------------
8.15762e-07
8.1490725e-10
landmark ------------------------------------------------------------------------------------------------------------------------
2.9493126e-06
pose ------------------------------------------------------------------------------------------------------------------------
5.4237614e-07
2.0453282e-07
5.892721e-07
spoof ------------------------------------------------------------------------------------------------------------------------
1.5498118e-06


Did you find any way to achieve this?