Relay op strategy

You can follow the example in vta to overwrite the implementation for a specific target. https://github.com/apache/incubator-tvm/blob/master/vta/python/vta/top/op.py#L63

Thank you very much, I have another problem. I want to implement BN op with te.extern, like cuDNN. And I don’t want to unpack BN. Is there any exmaple for this or how should I do? Please help me :grinning:@haichen

I have created new strategy for BN, as follows:

@override_native_generic_func("batch_norm_strategy")
def batch_norm_strategy(attrs, inputs, out_type, target):
    """batch_norm ssnpu strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_batch_norm(topi.nn.batch_norm),
        wrap_topi_schedule(topi.generic.schedule_injective),
        name="batch_norm.generic")
    return strategy

@batch_norm_strategy.register("ssnpu")
def batch_norm_strategy_ssnpu(attrs, inputs, out_type, target):
    """batch_norm ssnpu strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_batch_norm(topi.ssnpu.batch_norm_vp),
        wrap_topi_schedule(topi.ssnpu.schedule_batch_norm_vp),
        name="batch_norm.ssnpu",
        plevel=15)
    return strategy

def batch_norm_vp(data, gamma, beta, mean, variance):
    """Batch_norm operator on ssnpu"""
    return te.extern(
        data.shape,
        [data, gamma, beta, mean, variance],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.ssnpu.batch_norm.forward",
            ins[0], ins[1], ins[2], ins[3], ins[4], outs[0]), dtype=data.dtype, name="C")

but I got some probleam when I run the net

tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /home/wda/tvm_0622/build/libtvm.so(tvm::runtime::GraphRuntime::Run()+0x79) [0x7fc3ce6ed8b7]
  [bt] (7) /home/wda/tvm_0622/build/libtvm.so(std::function<void ()>::operator()() const+0x32) [0x7fc3cdb83460]
  [bt] (6) /home/wda/tvm_0622/build/libtvm.so(+0x2862878) [0x7fc3ce6f3878]
  [bt] (5) /home/wda/tvm_0622/build/libtvm.so(+0x285ff21) [0x7fc3ce6f0f21]
  [bt] (4) /home/wda/tvm_0622/build/libtvm.so(tvm::runtime::PackedFunc::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const+0x30) [0x7fc3cda28c4e]
  [bt] (3) /home/wda/tvm_0622/build/libtvm.so(std::function<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const+0x5a) [0x7fc3cd9d6d48]
  [bt] (2) /home/wda/tvm_0622/build/libtvm.so(+0x27ebdc6) [0x7fc3ce67cdc6]
  [bt] (1) /home/wda/tvm_0622/build/libtvm.so(+0x27eaa3d) [0x7fc3ce67ba3d]
  [bt] (0) /home/wda/tvm_0622/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x37) [0x7fc3cd96d191]
  File "/home/wda/tvm_0622/src/runtime/library_module.cc", line 78
TVMError: Check failed: ret == 0 (-1 vs. 0) : Assert fail: (num_args == 6), fused_nn_batch_norm: num_args should be 6

The problem has been solved. My batch normal’s compute is wrong. :joy:

return te.extern(
        #TODO data.shape[1] maybe need modify
        [data.shape, data.shape[1], data.shape[1]],
        [data, gamma, beta, mean, variance],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.ssnpu.batch_norm.forward",
            ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], outs[1], outs[2]), dtype=[dtype, dtype, dtype], name="C")