[RFC] Functionality of AlterOpLayout and Possible Refactoring


#1

As discussed in "[DISCUSS] and Documentation of AlterLayout Pass", I would like to open an RFC about how we could make AlterOpLayout better and easier to understand.

The proposal was discussed with @janimesh and our colleague Sejong. The idea of the Peephole pass comes mainly from @janimesh and Sejong.

@tqchen @merrymercy Please feel free to comment. I’m still not quite confident this is the best design though.

The section below describes the functionality of AlterOpLayout, as well as the motivation for refactoring it.

The Functionality of AlterOpLayout

Basically, AlterOpLayout tries to solve two problems:

1. Replace an operator with a set of operator(s).

For example, replace conv2d with conv2d_winograd. The replacement interface looks like this:

@conv2d_alter_layout.register(["arm_cpu"])
def _alter_conv2d_layout(attrs, inputs, tinfos, F):
    # attrs holds the attributes of the current conv2d
    # inputs are nnvm.symbol or Relay expressions
    # tinfos gives the shape and dtype of each input
    # tile_size and new_attrs are assumed to be derived from attrs (elided here)
    data, weight = inputs
    weight = F.nn.contrib_conv2d_winograd_weight_transform(weight,
                                                           tile_size=tile_size)
    # do some other processing for input & weight
    # return the replaced operator
    return F.nn.contrib_conv2d_winograd_without_weight_transform(
        data, weight, **new_attrs)

2. Correct the layout

If the replacement operator requires a different layout, the AlterOpLayout pass inserts a layout_transform op automatically. Here is an example that replaces every NHWC conv2d in the network with the NCHW implementation.

# Approach-2.1
@conv2d_alter_layout.register(["cpu"])
def _alter_conv2d_layout(attrs, inputs, tinfos, F):
    data, weight = inputs
    new_attrs = {k: attrs[k] for k in attrs.keys()}
    if attrs["data_layout"] == "NHWC":
        new_attrs["data_layout"] = "NCHW"  # output will also be NCHW
        return F.nn.conv2d(data, weight, **new_attrs)
    return None  # None means remain the same

Since now the required data layout is different from the original layout, the pass needs to insert a layout_transform operator:
Before AlterOpLayout
data[NHWC] → conv2d_NHWC →
After AlterOpLayout
data[NHWC] → layout_transform (src=NHWC, dst=NCHW) → conv2d_NCHW → layout_transform (src=NCHW, dst=NHWC) (transform back to the original layout) →

Now people may wonder why it is AlterOpLayout’s responsibility to fix the layout. Isn’t it easier (and clearer) to have users explicitly take care of the layout transform in the conv2d_alter_layout function? E.g.,

# Approach-2.2
# this can generate the same result as above
@conv2d_alter_layout.register(["cpu"])
def _alter_conv2d_layout(attrs, inputs, tinfos, F):
    data, weight = inputs
    new_attrs = {k: attrs[k] for k in attrs.keys()}
    if attrs["data_layout"] == "NHWC":
        new_attrs["data_layout"] = "NCHW"
        data = F.layout_transform(data, src_layout="NHWC", dst_layout="NCHW")
        output = F.nn.conv2d(data, weight, **new_attrs)
        # transform back so the output layout stays NHWC
        return F.layout_transform(output, src_layout="NCHW", dst_layout="NHWC")
    return None  # None means remain the same

The reason is that we wanted to solve another problem: eliminating as many layout_transform operators as possible.
Suppose we have two consecutive convolutions. With Approach-2.2 we end up with:
Step 0:
data[NHWC] → conv2d_NHWC → conv2d_NHWC →
Step 1:
data[NHWC] → LT(NHWC→NCHW) → conv2d_NCHW → LT(NCHW→NHWC) → LT(NHWC→NCHW) → conv2d_NCHW → LT(NCHW→NHWC)

Note that the two adjacent layout_transform (LT) operators in the middle, LT(NCHW→NHWC) → LT(NHWC→NCHW), cancel each other and can be safely eliminated. Let’s take a look at how Approach-2.1 solves this problem:
Step 0:
data[NHWC] → conv2d_NHWC → conv2d_NHWC → [NHWC]
Step 1 - Replace the operator
data[NHWC] → conv2d_NCHW → conv2d_NCHW → [NHWC]
Step 2 - Infer Layout (shown in the square brackets):
data[NHWC] → [NCHW] conv2d_NCHW [NCHW] → [NCHW] conv2d_NCHW [NCHW] → [NHWC]
Step 3 - Fix layout mismatch by inserting layout_transform operators
data[NHWC] → LT(NHWC→NCHW) → [NCHW] conv2d_NCHW [NCHW] →(nothing needs to be inserted here)→ [NCHW] conv2d_NCHW [NCHW] → LT(NCHW→NHWC) → [NHWC]

This is basically what NNVM’s AlterOpLayout pass is doing today.
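
The three steps above can be sketched in plain Python on a linear chain of operators. This is only a toy model, not the NNVM implementation: the function name `alter_layout_chain` and the string representation of operators are assumptions made for illustration, and only conv2d is altered.

```python
def alter_layout_chain(ops, graph_layout, target_layout):
    """Toy model of Approach-2.1 on a linear chain of operator names.

    Every conv2d is rewritten to target_layout; all other ops keep
    graph_layout. Returns the rewritten chain as a list of strings.
    """
    # Step 1 + Step 2: replace the operator and infer the layout it works in.
    layouts = [target_layout if op == "conv2d" else graph_layout for op in ops]

    # Step 3: insert a layout_transform (LT) only where layouts disagree.
    result = []
    current = graph_layout  # layout of the tensor flowing along the chain
    for op, layout in zip(ops, layouts):
        if layout != current:
            result.append(f"LT({current}->{layout})")
            current = layout
        result.append(f"{op}_{layout}")
    if current != graph_layout:
        # Transform back so the chain's output layout is unchanged.
        result.append(f"LT({current}->{graph_layout})")
    return result


chain = alter_layout_chain(["conv2d", "conv2d"], "NHWC", "NCHW")
print(" -> ".join(chain))
```

Because the mismatch check runs after layout inference, no transform is inserted between the two convolutions, which matches Step 3 above.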

AlterOpLayout Pass in Relay

NNVM’s AlterOpLayout pass seems clean, so why is the one in Relay so different?

It is not due to Relay itself, but to the fact that Relay’s AlterOpLayout tries to enhance layout inference (Step 2 above) by providing the input shapes in addition to the input layouts. Take out = broadcast_add(input1, input2) as an example: given only that input1 is NCHW, broadcast_add cannot infer the layout of either input2 or out. But if we know that input1’s shape is [1, 3, 224, 224] and input2’s shape is [224, 224], then it is obvious that layout(input2) = HW and layout(out) = NCHW.
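
A minimal sketch of this shape-assisted inference is shown below. It assumes NumPy-style right-aligned broadcasting, one layout letter per dimension (no blocked layouts such as NCHW16c), and that input1 has the higher rank; the helper name `infer_broadcast_layouts` is hypothetical and not part of Relay.

```python
def infer_broadcast_layouts(layout1, shape1, shape2):
    """Infer (layout of input2, layout of out) for out = broadcast_add(input1, input2)."""
    if len(layout1) != len(shape1):
        raise ValueError("layout1 must have one letter per dimension of shape1")
    if len(shape2) > len(shape1):
        raise ValueError("this sketch assumes input1 has the higher rank")

    # Right-align input2 against input1: its axes are input1's trailing axes.
    offset = len(shape1) - len(shape2)
    layout2 = layout1[offset:]

    # Each aligned pair of dimensions must be equal, or one of them must be 1.
    for axis, d1, d2 in zip(layout2, shape1[offset:], shape2):
        if d1 != d2 and d1 != 1 and d2 != 1:
            raise ValueError(f"axis {axis}: cannot broadcast {d2} against {d1}")

    return layout2, layout1  # out keeps the layout of the higher-rank input


print(infer_broadcast_layouts("NCHW", [1, 3, 224, 224], [224, 224]))
```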

However, it is extremely difficult to propagate shape information in the AlterOpLayout pass. Recall the following two steps:
Step 1 - Replace the operator
data[NHWC] → conv2d_NCHW → conv2d_NCHW → [NHWC]
Step 1.1 (Infer Shape: IMPOSSIBLE)
Step 2 - Infer Layout (shown in the square brackets):
data[NHWC] → [NCHW] conv2d_NCHW [NCHW] → [NCHW] conv2d_NCHW [NCHW] → [NHWC]
We cannot run shape inference between Step 1 and Step 2: at that point data is still in NHWC layout while conv2d_NCHW expects NCHW, so InferShape sees a shape mismatch before the layout_transform operators are inserted. This is a circular dependency: InferShape depends on FixLayout, FixLayout depends on InferLayout, and InferLayout depends on InferShape.
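
The mismatch can be reproduced with a small stand-in for conv2d shape inference. This is a sketch under simplifying assumptions (stride 1, no padding, OIHW weights regardless of the data layout); `conv2d_out_shape` is a hypothetical helper, not Relay's InferShape.

```python
def conv2d_out_shape(data_shape, weight_shape, data_layout):
    """Output shape of a stride-1, unpadded conv2d with OIHW weights."""
    dims = dict(zip(data_layout, data_shape))  # e.g. {"N": 1, "H": 224, ...}
    out_c, in_c, kh, kw = weight_shape
    if dims["C"] != in_c:
        raise ValueError(
            f"channel mismatch: data has {dims['C']}, weight expects {in_c}")
    out = {"N": dims["N"], "C": out_c,
           "H": dims["H"] - kh + 1, "W": dims["W"] - kw + 1}
    return [out[axis] for axis in data_layout]


nhwc_data = [1, 224, 224, 3]  # the tensor is physically NHWC
weight = [64, 3, 7, 7]

# Before Step 1 the attribute agrees with the data, so inference succeeds.
print(conv2d_out_shape(nhwc_data, weight, "NHWC"))

# After Step 1 the op claims NCHW, but no layout_transform exists yet,
# so the NHWC shape is read as NCHW and the channel check fails.
try:
    conv2d_out_shape(nhwc_data, weight, "NCHW")
except ValueError as err:
    print("InferShape fails:", err)
```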

This is why Relay’s AlterOpLayout has some limitations at the moment. They are very hard to address, as this trial shows: https://github.com/yzhliu/tvm-1/blob/alter_layout_fix/src/relay/pass/alter_op_layout.cc

Refactoring Proposal

High-level idea

From a high-level perspective, we can split the current AlterOpLayout into two passes:

  1. OpAlteration - replace an operator with a set of other operators. The user needs to take care of the layout transform, i.e.,
    @conv2d_alter.register(["cpu"])
    def _alter_conv2d(attrs, inputs, tinfos, F):
        data, weight = inputs
        new_attrs = {k: attrs[k] for k in attrs.keys()}
        if attrs["data_layout"] == "NHWC":
            new_attrs["data_layout"] = "NCHW"
            data = F.layout_transform(data, src_layout="NHWC", dst_layout="NCHW")
            output = F.nn.conv2d(data, weight, **new_attrs)
            # transform back so the output layout stays NHWC
            return F.layout_transform(output, src_layout="NCHW", dst_layout="NHWC")
        return None  # None means remain the same
  2. PeepholeOptimizer - remove the layout_transform operators if possible.
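
To make the second pass concrete, here is a minimal sketch of a peephole rule that cancels back-to-back inverse transforms on a linear chain. The node representation (a tuple `("LT", src, dst)` for transforms, a plain string for other operators) and the function name are assumptions for illustration; a real pass would work on the Relay graph and would also need to handle branching.

```python
def eliminate_inverse_transforms(chain):
    """Remove adjacent pairs LT(a->b), LT(b->a) from a linear chain.

    A stack is used so that a pair exposed by an earlier removal is
    cancelled as well.
    """
    out = []
    for node in chain:
        prev = out[-1] if out else None
        is_inverse_pair = (
            isinstance(node, tuple) and isinstance(prev, tuple)
            and prev[0] == "LT" and node[0] == "LT"
            and prev[1] == node[2] and prev[2] == node[1]
        )
        if is_inverse_pair:
            out.pop()  # the two transforms compose to the identity
        else:
            out.append(node)
    return out


# The chain produced when every conv2d inserts its own transforms.
chain = [
    ("LT", "NHWC", "NCHW"), "conv2d_NCHW", ("LT", "NCHW", "NHWC"),
    ("LT", "NHWC", "NCHW"), "conv2d_NCHW", ("LT", "NCHW", "NHWC"),
]
print(eliminate_inverse_transforms(chain))
```

With this split, the alteration callback can stay simple and local, and the global cleanup is left to the peephole pass.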

Details

Take the following network as an example,

data → conv2d → broadcast_add → 
      bias    → 

Infer Shape and Infer Layout

data[NCHW] → conv2d[NCHW] → broadcast_add[NCHW] → 
               bias [CHW] → 

Determine the best layout for operators
This could be a customized analysis pass, or the graph-tuner that Yao proposed: https://github.com/dmlc/tvm/pull/2184

Alter the operator according to the previous step. The user needs to ensure that the input and output layouts remain the same.

data[NCHW] → LT(NCHW→NCHW16c) → conv2d[NCHW16c] → LT(NCHW16c→NCHW) → broadcast_add[NCHW] → 
                                                        bias [CHW] →

Alter operators according to pre-defined rules
This pass detects a layout_transform, passes its src_layout to the successor operator, and asks whether that operator can take this layout and do the computation. For the example above,
try_alter_broadcast(preferred_in_layout1 = NCHW16c, preferred_in_layout2 = CHW)
A pre-defined rule can choose to alter it as follows:

LT(NCHW16c→NCHW) → LT(NCHW→NCHW16c) → broadcast_add → LT(NCHW16c→NCHW)
      bias [CHW] → LT(CHW → CHW16c) → 
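
One possible shape for such a rule is sketched below. Only the name try_alter_broadcast and its two arguments come from the example above; the return convention, the helper `split_layout`, and the decision logic are assumptions. The rule accepts the blocked layout for the first input and derives a matching blocked layout for the second by keeping only the axes the second input actually has.

```python
import re


def split_layout(layout):
    """Split a layout string such as 'NCHW16c' into ['N', 'C', 'H', 'W', '16c']."""
    return re.findall(r"[A-Z]|\d+[a-z]", layout)


def try_alter_broadcast(preferred_in_layout1, preferred_in_layout2):
    """Hypothetical rule for broadcast ops.

    Returns (in_layout1, in_layout2, out_layout) if the op can compute in
    preferred_in_layout1, or None if the rule declines to alter the op.
    """
    tokens = split_layout(preferred_in_layout1)
    if "".join(tokens) != preferred_in_layout1:
        return None  # unrecognized layout string

    # Keep every axis (and sub-axis such as '16c') that input2 also has.
    axes2 = set(preferred_in_layout2)
    kept = [t for t in tokens if t[-1].upper() in axes2]

    # The primal axes that survive must match input2's layout exactly.
    if "".join(t for t in kept if t.isupper()) != preferred_in_layout2:
        return None

    return preferred_in_layout1, "".join(kept), preferred_in_layout1


print(try_alter_broadcast("NCHW16c", "CHW"))
```

For the inputs above the sketch yields CHW16c for the bias, which corresponds to the LT(CHW → CHW16c) in the diagram.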

Peephole pass
Eliminate LT(NCHW16c→NCHW) → LT(NCHW→NCHW16c) in the above example, which generates:

data → LT(NCHW→NCHW16c) → conv2d[NCHW16c] → broadcast_add[NCHW16c] → LT(NCHW16c→NCHW)
            bias [CHW] → LT(CHW → CHW16c) →

#2

Thanks @yzhliu for the RFC. Just to reiterate, full layout handling requires 4 passes:

  • InferLayout pass
  • FindBestLayout pass - Does not mutate the graph, just finds the best layout for each op
  • AlterOp pass - The input and output layouts remain the same as before, but the core op runs in the best layout determined by the FindBestLayout pass. As @yzhliu mentioned, a user can rewrite the op via a Python callback. Alternatively, I can envision a Relay pass that automatically inserts the transforms by comparing the original inferred layout with the determined best layout.
  • Peephole pass - Remove redundant memory operators (layout transforms, reshapes, transposes, etc.)