[External codegen] How to prevent batch norm from being decomposed?

Continuing from the previous thread, it seems batch norm is always decomposed during build_module.optimize(). Is this expected? How can I preserve batch norm until codegen?

cc @comaniac

Before opt

fn (%data: Tensor[(1, 3, 224, 224), float32], %layer1_weight: Tensor[(16, 3, 3, 3), float32], %layer1_bn_gamma: Tensor[(16), float32], %layer1_bn_beta: Tensor[(16), float32], %layer1_bn_mean: Tensor[(16), float32], %layer1_bn_var: Tensor[(16), float32]) -> Tensor[(1, 16, 224, 224), float32] {
  %3 = fn (%dnnl_input0: Tensor[(1, 3, 224, 224), float32], %dnnl_input1: Tensor[(16, 3, 3, 3), float32], %dnnl_input2: Tensor[(16), float32], %dnnl_input3: Tensor[(16), float32], %dnnl_input4: Tensor[(16), float32], %dnnl_input5: Tensor[(16), float32], Compiler="dnnl", ExternalSymbol="dnnl_0", Primitive=1) -> Tensor[(1, 16, 224, 224), float32] {
    %0 = nn.conv2d(%dnnl_input0, %dnnl_input1, padding=[1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 224, 224), float32] */;
    %1 = nn.batch_norm(%0, %dnnl_input2, %dnnl_input3, %dnnl_input4, %dnnl_input5) /* ty=(Tensor[(1, 16, 224, 224), float32], Tensor[(16), float32], Tensor[(16), float32]) */;
    %2 = %1.0;
    nn.relu(%2) /* ty=Tensor[(1, 16, 224, 224), float32] */
  };
  %3(%data, %layer1_weight, %layer1_bn_gamma, %layer1_bn_beta, %layer1_bn_mean, %layer1_bn_var) /* ty=Tensor[(1, 16, 224, 224), float32] */
}

After

fn (%data: Tensor[(1, 3, 224, 224), float32], %layer1_weight: Tensor[(16, 3, 3, 3), float32], %layer1_bn_gamma: Tensor[(16), float32], %layer1_bn_beta: Tensor[(16), float32], %layer1_bn_mean: Tensor[(16), float32], %layer1_bn_var: Tensor[(16), float32]) -> Tensor[(1, 16, 224, 224), float32] {
  %12 = fn (%dnnl_input0: Tensor[(1, 3, 224, 224), float32], %dnnl_input1: Tensor[(16, 3, 3, 3), float32], %dnnl_input2: Tensor[(16), float32], %dnnl_input3: Tensor[(16), float32], %dnnl_input4: Tensor[(16), float32], %dnnl_input5: Tensor[(16), float32], Compiler="dnnl", ExternalSymbol="dnnl_0", Primitive=1) -> Tensor[(1, 16, 224, 224), float32] {
    %0 = nn.conv2d(%dnnl_input0, %dnnl_input1, padding=[1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 224, 224), float32] */;
    %1 = add(%dnnl_input5, 1e-05f /* ty=float32 */) /* ty=Tensor[(16), float32] */;
    %2 = sqrt(%1) /* ty=Tensor[(16), float32] */;
    %3 = divide(1f /* ty=float32 */, %2) /* ty=Tensor[(16), float32] */;
    %4 = multiply(%3, %dnnl_input2) /* ty=Tensor[(16), float32] */;
    %5 = expand_dims(%4, axis=1, num_newaxis=2) /* ty=Tensor[(16, 1, 1), float32] */;
    %6 = multiply(%0, %5) /* ty=Tensor[(1, 16, 224, 224), float32] */;
    %7 = negative(%dnnl_input4) /* ty=Tensor[(16), float32] */;
    %8 = multiply(%7, %4) /* ty=Tensor[(16), float32] */;
    %9 = add(%8, %dnnl_input3) /* ty=Tensor[(16), float32] */;
    %10 = expand_dims(%9, axis=1, num_newaxis=2) /* ty=Tensor[(16, 1, 1), float32] */;
    %11 = add(%6, %10) /* ty=Tensor[(1, 16, 224, 224), float32] */;
    nn.relu(%11) /* ty=Tensor[(1, 16, 224, 224), float32] */
  };
  %12(%data, %layer1_weight, %layer1_bn_gamma, %layer1_bn_beta, %layer1_bn_mean, %layer1_bn_var) /* ty=Tensor[(1, 16, 224, 224), float32] */
}
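For readers comparing the two dumps: the ops in the "After" version are the usual inference-time batch norm rewritten as a per-channel scale and shift. Below is a minimal pure-Python sketch of that equivalence for a single channel. The function names and sample values are made up for illustration and are not TVM APIs; the comments map each step to the `%N` values in the dump above.

```python
import math

def batch_norm_inference(x, gamma, beta, mean, var, eps=1e-5):
    """Reference form for one channel: normalize, then scale and shift."""
    return [(v - mean) / math.sqrt(var + eps) * gamma + beta for v in x]

def decomposed_batch_norm(x, gamma, beta, mean, var, eps=1e-5):
    """Same computation, following the op sequence in the 'After' dump."""
    scale = (1.0 / math.sqrt(var + eps)) * gamma  # %1..%4: add, sqrt, divide, multiply
    shift = (-mean) * scale + beta                # %7..%9: negative, multiply, add
    return [v * scale + shift for v in x]         # %6 and %11: multiply, add

# Illustrative values only (one channel, four activations)
x = [0.5, -1.25, 3.0, 0.0]
gamma, beta, mean, var = 1.5, -0.2, 0.3, 2.0

ref = batch_norm_inference(x, gamma, beta, mean, var)
dec = decomposed_batch_norm(x, gamma, beta, mean, var)
max_abs_diff = max(abs(a - b) for a, b in zip(ref, dec))
```

Since `scale` and `shift` depend only on gamma, beta, mean and var, they can be computed at compile time whenever those four tensors are constants.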

Good point. We actually encountered this issue before and added a workaround, but as you can see we removed it in that PR because it was too ad hoc. @zhiics, we should revisit this issue and set up a "no touch"-like mechanism for ops that are being offloaded.

A bit off topic, but if BN is decomposed, then with opt_level=3 I would expect FoldScaleAxis and FoldConstant to remove all the math that can be done at compile time, so that in the dnnl codegen I can focus only on detecting conv + bias add + relu.

But when I run the snippet below I still get all batch norm related math in the subgraph as shown in the above post. Am I missing something? @comaniac

    mod["main"] = ConvBNReluAnnotator("dnnl").visit(mod["main"])
    mod = transform.PartitionGraph()(mod)
    with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
        opt_mod, params = relay.build_module.optimize(mod, "llvm")
        print(opt_mod["main"].astext())

I think that’s because the decomposed BN has been partitioned into a subgraph for an external function. Optimizations will not be applied to external functions because we cannot be sure whether an optimization will generate unsupported ops or patterns.

@zhiics may say more about this part.

Do you mean the decomposed batch norm ops form another subgraph? I don’t think so, because as shown in the top post all decomposed ops are inlined into the original subgraph.

I meant all math ops from the original BN are in a subgraph marked with your annotation. That subgraph becomes an external function as your top post. Starting from there, other passes will not traverse inside external functions, so constant folding and other optimizations did not apply.

So if I want to remove BN and its compile time math, should I apply FoldScaleAxis + FoldConstant before partitioning?

That’s an interesting proposal. I believe if you could decompose BN and apply constant folding before partitioning, then you will get what you want, but it’s possible that the part you want to keep may be changed as well, so be careful.

In general, our intention is to leave all annotated Relay ops in place so that we can have a bigger subgraph (assuming we have an ideal annotator). If we decompose BN, then even if we have simplified its math, it still breaks up subgraphs. In this case, the benefit from math simplification may be offset by the overhead of transferring data and invoking kernels.

@comaniac sorry I was wrong about FoldScaleAxis + FoldConstant. I’ll update my test script and try again.

ok I updated my test.

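import numpy as np
import tvm
import tvm.relay.testing
from tvm import relay
from tvm.relay import transform

# ConvBNReluAnnotator is the annotator class posted further down in this thread.
def test_fuse():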
    def get_layers(prefix, data, in_channel, out_channel):
        weight = relay.const(np.random.randn(out_channel, in_channel, 3, 3).astype(np.float32))
        bn_gamma = relay.const(np.random.randn(out_channel).astype(np.float32))
        bn_beta = relay.const(np.random.randn(out_channel).astype(np.float32))
        bn_mmean = relay.const(np.random.randn(out_channel).astype(np.float32))
        bn_mvar = relay.const(np.random.randn(out_channel).astype(np.float32))

        layer = relay.nn.conv2d(data=data, weight=weight,
                                kernel_size=(3,3), channels=out_channel, padding=(1, 1))
        layer = relay.nn.batch_norm(layer, bn_gamma, bn_beta, bn_mmean, bn_mvar)[0]
        layer = relay.nn.relu(layer)
        return layer

    data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
    layer1 = get_layers("layer1_", data, 3, 16)
    layer2 = get_layers("layer2_", layer1, 16, 16)
    last = layer1
    net = relay.Function(relay.analysis.free_vars(last), last)

    ishape = (1, 3, 224, 224)
    mod, params = tvm.relay.testing.create_workload(net)
    mod["main"] = ConvBNReluAnnotator("dnnl").visit(mod["main"])
    mod = transform.PartitionGraph()(mod)
    # print(mod["main"].astext(show_meta_data=False))
    with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
        opt_mod, params = relay.build_module.optimize(mod, "llvm")
        print(opt_mod["main"].astext(show_meta_data=False))

The result is interesting. There is one multiply op which shouldn’t exist. If I remove annot + partitioning, I get the correct result, conv + add + relu.

v0.0.4
fn (%data: Tensor[(1, 3, 224, 224), float32]) -> Tensor[(1, 16, 224, 224), float32] {
  %3 = fn (%dnnl_input0: Tensor[(1, 3, 224, 224), float32], %dnnl_input1: Tensor[(16, 3, 3, 3), float32], Compiler="dnnl", ExternalSymbol="dnnl_0", Primitive=1) -> Tensor[(1, 16, 224, 224), float32] {
    %0 = multiply(%dnnl_input1, meta[relay.Constant][1] /* ty=Tensor[(16, 1, 1, 1), float32] */ /* ty=Tensor[(16, 1, 1, 1), float32] */) /* ty=Tensor[(16, 3, 3, 3), float32] */;
    %1 = nn.conv2d(%dnnl_input0, %0, padding=[1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 224, 224), float32] */;
    %2 = add(%1, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */ /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 224, 224), float32] */;
    nn.relu(%2) /* ty=Tensor[(1, 16, 224, 224), float32] */
  };
  %3(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */ /* ty=Tensor[(16, 3, 3, 3), float32] */) /* ty=Tensor[(1, 16, 224, 224), float32] */
}
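My reading of the dump above, offered as an assumption rather than a confirmed diagnosis: the leftover `multiply` on `%dnnl_input1` looks like the BN scale that FoldScaleAxis moved onto the conv weights. Inside the partitioned function the weight is a parameter rather than a constant, so constant folding has nothing to evaluate there. The identity behind that rewrite can be sketched in pure Python with a 1x1 convolution at a single pixel; all names and values here are illustrative only.

```python
def conv1x1(x, w):
    """1x1 convolution at one pixel. x: [in_ch] values, w: [out_ch][in_ch] weights."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

# Illustrative values: 3 input channels, 2 output channels
x = [1.0, -2.0, 0.5]
w = [[0.2, 0.4, -0.1],
     [1.0, 0.0, 3.0]]
scale = [1.5, -0.5]  # per-output-channel BN scale
shift = [0.1, 0.2]   # per-output-channel BN shift

# Unfolded: conv, then per-channel multiply and add (the decomposed BN)
unfolded = [y * s + b for y, s, b in zip(conv1x1(x, w), scale, shift)]

# Folded: the scale is moved onto the weights, leaving only conv + add
w_scaled = [[wi * s for wi in row] for row, s in zip(w, scale)]
folded = [y + b for y, b in zip(conv1x1(x, w_scaled), shift)]

max_abs_diff = max(abs(a - b) for a, b in zip(unfolded, folded))
```

If the weights are known constants when the fold happens, `w_scaled` can be precomputed and the `multiply` disappears from the graph. If they are only visible as function parameters, the multiplication has to stay as a runtime op.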

It is possible that this is due to another bug in my annotator.

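from tvm import relay
from tvm.relay.annotation import compiler_begin, compiler_end
from tvm.relay.expr_functor import ExprMutator

class ConvBNReluAnnotator(ExprMutator):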

    def __init__(self, backend):
        super(ConvBNReluAnnotator, self).__init__()
        self.in_compiler = 0
        self.backend = backend

    def annotate_call(self, call):
        new_args = []
        for arg in call.args:
            new_arg = super().visit(arg)
            if call.op.name == "nn.conv2d" or isinstance(new_arg, relay.expr.Var):
                new_arg = compiler_begin(new_arg, self.backend)
            new_args.append(new_arg)
        return relay.Call(call.op, new_args, call.attrs, call.type_args)

    def visit_call(self, call):
        if call.op.name == "nn.conv2d":  # Annotate begin at args
            if self.in_compiler == 1:
                self.in_compiler = 2
                return self.annotate_call(call)
        elif call.op.name == "nn.batch_norm":
            if self.in_compiler == 1:
                return self.annotate_call(call)
        elif call.op.name == "nn.relu":  # Annotate end at output
            self.in_compiler = 1
            op = self.annotate_call(call)
            # if self.in_compiler == 2:
            op = compiler_end(op, self.backend)
            self.in_compiler = 0
            return op
        return super().visit_call(call)

If the function is annotated as external and is lifted out as a global function, all optimizations will be skipped.

We do need to revisit how we can better interact with passes for different backends.

Hmm, functions created during partitioning seem to be already marked as external:

Since these functions are by definition not global functions, that explains why optimizations are being applied. But then I don’t see any way to prevent optimization, since these functions cannot be lifted out as globals.

Please correct me if I am wrong.

ok applying FoldScaleAxis + FoldConstant before partitioning gives the desired output.

    seq = transform.Sequential([
        relay.transform.InferType(),
        relay.transform.SimplifyInference(),
        relay.transform.FoldConstant(),
        relay.transform.FoldScaleAxis(),
    ])

    with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
        mod = seq(mod)
        mod["main"] = ConvBNReluAnnotator("dnnl").visit(mod["main"])
        mod = transform.PartitionGraph()(mod)
        opt_mod, params = relay.build_module.optimize(mod, "llvm")
        print(opt_mod["main"].astext(show_meta_data=False))

fn (%data: Tensor[(1, 3, 224, 224), float32]) -> Tensor[(1, 16, 224, 224), float32] {
  %2 = fn (%dnnl_input0: Tensor[(1, 3, 224, 224), float32], %dnnl_input1: Tensor[(16, 3, 3, 3), float32], Compiler="dnnl", ExternalSymbol="dnnl_0", Primitive=1) -> Tensor[(1, 16, 224, 224), float32] {
    %0 = nn.conv2d(%dnnl_input0, %dnnl_input1, padding=[1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 224, 224), float32] */;
    %1 = add(%0, meta[relay.Constant][1] /* ty=Tensor[(16, 1, 1), float32] */ /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 224, 224), float32] */;
    nn.relu(%1) /* ty=Tensor[(1, 16, 224, 224), float32] */
  };
  %2(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */ /* ty=Tensor[(16, 3, 3, 3), float32] */) /* ty=Tensor[(1, 16, 224, 224), float32] */
}

So all optimizations are performed on external, partitioned subgraphs no matter what, but for my use case this does the job. It would still be great to have a way to prevent certain optimizations from happening when there is such a need.

@masahi yes, they will be applied on the case you show here as they are actually closures. We cannot skip them from the pass manager directly. We need to have followup mechanisms to skip optimizations that are not needed for external closures.

One simple, straightforward way is probably to check directly whether the Function being visited uses the default compiler (UseDefaultCompiler). If it does not, the pass should skip it. Some passes can probably still be applied (e.g. constant folding, CSE, etc.), but others (like fusion, SimplifyInference, etc.) should be skipped. We may want to have an RFC for this with some possible designs/approaches.
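To make the idea concrete, here is a toy pure-Python model of such a check. Everything in it (`ToyFunction`, `uses_default_compiler`, `SAFE_FOR_EXTERNAL`, `run_passes`) is a hypothetical stand-in invented for illustration, not TVM's pass manager or any existing API. The split into "safe" and "unsafe" passes simply mirrors the examples named in the post above.

```python
from dataclasses import dataclass, field

# Hypothetical stand-ins for illustration only; none of these are TVM APIs.
@dataclass
class ToyFunction:
    name: str
    attrs: dict = field(default_factory=dict)
    applied: list = field(default_factory=list)  # names of passes that ran on it

def uses_default_compiler(func):
    """True when the function carries no external 'Compiler' attribute."""
    return "Compiler" not in func.attrs

# Assumption: passes considered harmless for external closures.
SAFE_FOR_EXTERNAL = {"FoldConstant", "EliminateCommonSubexpr"}

def run_passes(funcs, passes):
    """Apply each pass unless the function is external and the pass is not marked safe."""
    for func in funcs:
        for pass_name in passes:
            if uses_default_compiler(func) or pass_name in SAFE_FOR_EXTERNAL:
                func.applied.append(pass_name)
    return funcs

main = ToyFunction("main")
dnnl_0 = ToyFunction("dnnl_0", attrs={"Compiler": "dnnl"})
run_passes([main, dnnl_0], ["SimplifyInference", "FoldConstant", "FuseOps"])
```

In this toy run, `main` receives all three passes, while `dnnl_0` only receives `FoldConstant`, so a `nn.batch_norm` inside it would survive until codegen. A real design would also need a way for each backend to declare which passes it tolerates, which is the kind of question an RFC could settle.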
