Folding Conv2D+BiasAdd+BatchNorm

Dear all, I’m a bit confused about how TVM folds operations; I hope someone can help me. My question is: is there any way in TVM to fold the two add operations into one, since they both have one constant operand? Thanks a lot!

  %0 = nn.conv2d(%data, %conv1_1_weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 224, 224), float32] */;
  %1 = nn.bias_add(%0, %conv1_1_bias) /* ty=Tensor[(1, 64, 224, 224), float32] */;
  %2 = nn.batch_norm(%1, %bn1_1_gamma, %bn1_1_beta, %bn1_1_moving_mean, %bn1_1_moving_var) /* ty=(Tensor[(1, 64, 224, 224), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
  %3 = %2.0;
  %4 = nn.relu(%3) /* ty=Tensor[(1, 64, 224, 224), float32] */;

After performing FoldScaleAxis(), it looks like:

  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 3, 3), float32] */ /* ty=Tensor[(64, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 224, 224), float32] */;
  %1 = add(%0, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */ /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 224, 224), float32] */;
  %2 = add(%1, meta[relay.Constant][2] /* ty=Tensor[(64, 1, 1), float32] */ /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 224, 224), float32] */;
  %3 = nn.relu(%2) /* ty=Tensor[(1, 64, 224, 224), float32] */;

Code:

import tvm
from tvm import relay
from tvm.relay import transform
import tvm.relay.testing
from tvm.relay.build_module import bind_params_by_name

def fold_optimize(mod, params=None):
    # SimplifyInference decomposes batch_norm into scale/shift ops, and
    # FoldScaleAxis then folds the scale into the convolution weights.
    optimize = tvm.transform.Sequential([
        relay.transform.CanonicalizeOps(),
        relay.transform.SimplifyInference(),
        relay.transform.FoldScaleAxis(),
    ])
    # Bind the parameters as constants so the passes above can fold them.
    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)

    mod = optimize(mod)
    return mod

if __name__ == '__main__':
    mod, params = relay.testing.vgg.get_workload(1, batch_norm=True)

    print(mod.astext(show_meta_data=False))
    with tvm.transform.PassContext(opt_level=3):
        mod = fold_optimize(mod, params=params)
    print(mod.astext(show_meta_data=False))

I don’t think TVM has a pass that folds two consecutive add operators. From my understanding, they might be fused together by the FuseOps pass (I need to double check).

Hi @comaniac, thanks for your reply! It seems the FuseOps pass is realized at the TIR level by inlining ops. Currently I’m learning to develop a Relay codegen, like the DNNL compiler, to interface with an in-house library. I guess I need to develop a Relay pass for this case. Thanks for your help!
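
For reference, such a pass could be sketched with the Relay pattern language. This is only a rough sketch, not tested code: the class name MergeConsecutiveAdd is made up, and it assumes the two folded operands are both constants (as in the IR above).

import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import (
    DFPatternCallback, is_constant, is_op, rewrite, wildcard)

class MergeConsecutiveAdd(DFPatternCallback):
    """Rewrite add(add(x, c1), c2) into add(x, c1 + c2) when c1 and c2 are constants."""
    def __init__(self):
        super().__init__(require_type=True)
        self.x = wildcard()
        self.c1 = is_constant()
        self.c2 = is_constant()
        self.pattern = is_op("add")(is_op("add")(self.x, self.c1), self.c2)

    def callback(self, pre, post, node_map):
        x = node_map[self.x][0]
        c1 = node_map[self.c1][0]
        c2 = node_map[self.c2][0]
        # Leave the constant-only addition in the graph; FoldConstant collapses it.
        return relay.add(x, relay.add(c1, c2))

# Usage (after FoldScaleAxis):
#   mod["main"] = rewrite(MergeConsecutiveAdd(), mod["main"])
#   mod = relay.transform.FoldConstant()(mod)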

Ah, if you are working on a case like DNNL, which uses BYOC to offload a subgraph to an external codegen, then you can probably just use the BYOC passes to achieve your goal. If you only need the two add ops to be in one subgraph, you can run the BYOC passes; they will fuse all consecutive supported ops into one Relay function and invoke your codegen for it. In that case, you can implement the fusion logic inside your codegen.
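
Roughly, the relevant BYOC passes look like this (just a sketch; "dnnl" stands in for whatever compiler name you registered for your codegen, and partition_for_external is a made-up helper name):

import tvm
from tvm import relay

def partition_for_external(mod, compiler="dnnl"):
    seq = tvm.transform.Sequential([
        # Tag every operator the external codegen declared support for.
        relay.transform.AnnotateTarget(compiler),
        # Merge adjacent annotated operators into one region, so e.g. the
        # two consecutive adds end up in the same subgraph.
        relay.transform.MergeCompilerRegions(),
        # Lift each region into its own Relay function handled by the codegen.
        relay.transform.PartitionGraph(),
    ])
    with tvm.transform.PassContext(opt_level=3):
        return seq(mod)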

For more BYOC details, you are welcome to read our latest post that uses DNNL as an example to demonstrate how it works: https://tvm.apache.org/2020/07/15/how-to-bring-your-own-codegen-to-tvm

@comaniac This doc helps a lot. Appreciate your help!

Hi @comaniac. Thanks for your efforts in the TVM community; your comments always help me understand TVM better. Recently I faced a problem similar to @xttr0n's. I have been trying to use BYOC for my own codegen, but when I tried to apply op fusion for BN and ReLU, it did not work. I think the reason is the TupleGetItem: batch_norm has three outputs, and the first output feeds into ReLU. So when I used the GetRootCall() function, I had to call it with something like GetRootCall(callee->body.as<CallNode>(), 1, {"add", "GetTupleItem", "nn.relu"}); but it did not work. The error message was "AttributeError: Check failed: reg != nullptr: Operator relay.TupleGetItem is not registered". So my question is: what is the operator name for TupleGetItem?

GetRootCall only works for CallNode but not TupleGetItemNode (https://github.com/apache/tvm/blob/main/src/relay/backend/utils.h#L281). You are welcome to send a PR to support other nodes in the function.
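
Alternatively, if you only need BN and ReLU to end up in one composite function, you can describe the whole sequence, TupleGetItem included, with the pattern language and MergeComposite instead of walking the body yourself. A rough sketch (the composite name "mycodegen.bn_relu" is made up):

from tvm import relay
from tvm.relay.dataflow_pattern import is_op, is_tuple_get_item, wildcard

def make_bn_relu_pattern():
    # batch_norm takes (data, gamma, beta, moving_mean, moving_var) and
    # returns a tuple; element 0 is the normalized output fed into relu.
    bn = is_op("nn.batch_norm")(wildcard(), wildcard(), wildcard(),
                                wildcard(), wildcard())
    return is_op("nn.relu")(is_tuple_get_item(bn, 0))

pattern_table = [("mycodegen.bn_relu", make_bn_relu_pattern())]
# mod = relay.transform.MergeComposite(pattern_table)(mod)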

Hi @comaniac, I am trying to add a new pattern [conv, add, add, relu] in BYOC, which represents [conv2d, bias, bn, relu]. It seems that the pattern cannot be recognized during the partition-graph process. Do you have any idea what might be going wrong?