[Relay] Duplicated MaxPool in multiple branches

    conv2d                       
      |
   max pool
    /     \                   
 conv2d  conv2d
    \     /
    concat

is transformed into

       conv2d
     /       \
max pool  max pool
    |         |
  conv2d   conv2d
    \         /
       concat

The network:

from tvm import relay
from tvm.relay.testing import layers

def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None, suffix=''):
    conv = layers.conv2d(
        data=data,
        channels=int(num_filter),
        kernel_size=kernel,
        strides=stride,
        padding=pad,
        name='%s%s_conv1' % (name, suffix))

    bn = layers.batch_norm_infer(data=conv, epsilon=2e-5, name='%s%s_bn' % (name, suffix))
    act = relay.nn.relu(data=bn)
    return act

def Pooling(data, kernel, stride, pad, pool_type, name):
    if pool_type == 'max':
        return relay.nn.max_pool2d(data=data, pool_size=kernel, strides=stride, padding=pad)
    elif pool_type == 'avg':
        return relay.nn.avg_pool2d(data=data, pool_size=kernel, strides=stride, padding=pad,
                                   count_include_pad=True)
    else:
        raise ValueError("Invalid pooling type: " + pool_type)


def get_net(batch_size,
            num_classes,
            image_shape,
            dtype):
    data_shape = (batch_size,) + image_shape
    data = relay.var("data",
                     shape=data_shape,
                     dtype=dtype)

    conv = Conv(data, 32, kernel=(3, 3), stride=(2, 2), name="conv")
    pool1 = Pooling(data=conv, kernel=(3, 3), stride=(2, 2), pool_type="max", pad=(0, 0),
                    name="pool")
    conv1 = Conv(pool1, 32, kernel=(3, 3), stride=(2, 2), name="conv1")
    conv2 = Conv(pool1, 32, kernel=(3, 3), stride=(2, 2), name="conv2")
    concat = relay.concatenate((conv1, conv2), axis=1)
    args = relay.ir_pass.free_vars(concat)
    return relay.Function(args, concat)
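
For reference, a minimal sketch of how this network might be built at optimization level 3 to obtain the dump below (assuming the relay.build_config / relay.build API of that TVM version; the flag name may differ):

# Assumed 0.5-era API; purely a reproduction sketch.
net = get_net(batch_size=1, num_classes=1000,
              image_shape=(3, 299, 299), dtype="float32")
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(net, target="llvm")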

IR after fuse_ops (obtained by calling relay.build with opt_pass_level = 3; note that max_pool is called twice, in %9 and %19):

fn (%data: Tensor[(1, 3, 299, 299), float32])
    -> Tensor[(1, 64, 36, 36), float32] {
  %0 = meta.relay.Constant(id=0) # ty=Tensor[(32, 3, 3, 3), float32]
  %1 = meta.relay.Constant(id=1) # ty=Tensor[(32, 1, 1), float32]
  %2 = fn(%p0: Tensor[(1, 3, 299, 299), float32],
          %p1: Tensor[(32, 3, 3, 3), float32],
          %p2: Tensor[(32, 1, 1), float32])
          -> Tensor[(1, 32, 149, 149), float32] {
    %3 = nn.conv2d(%p0, %p1, strides=[2, 2], channels=32, kernel_size=[3, 3]) # ty=Tensor[(1, 32, 149, 149), float32]
    %4 = add(%3, %p2) # ty=Tensor[(1, 32, 149, 149), float32]
    %5 = nn.relu(%4) # ty=Tensor[(1, 32, 149, 149), float32]
    %5
  }
  %6 = %2(%data, %0, %1) # ty=Tensor[(1, 32, 149, 149), float32]
  %7 = fn(%p01: Tensor[(1, 32, 149, 149), float32])
          -> Tensor[(1, 32, 74, 74), float32] {
    %8 = nn.max_pool2d(%p01, pool_size=[3, 3], strides=[2, 2]) # ty=Tensor[(1, 32, 74, 74), float32]
    %8
  }
  %9 = %7(%6) # ty=Tensor[(1, 32, 74, 74), float32]
  %10 = meta.relay.Constant(id=2) # ty=Tensor[(32, 32, 3, 3), float32]
  %11 = meta.relay.Constant(id=3) # ty=Tensor[(32, 1, 1), float32]
  %12 = fn(%p02: Tensor[(1, 32, 74, 74), float32],
           %p11: Tensor[(32, 32, 3, 3), float32],
           %p21: Tensor[(32, 1, 1), float32])
           -> Tensor[(1, 32, 36, 36), float32] {
    %13 = nn.conv2d(%p02, %p11, strides=[2, 2], channels=32, kernel_size=[3, 3]) # ty=Tensor[(1, 32, 36, 36), float32]
    %14 = add(%13, %p21) # ty=Tensor[(1, 32, 36, 36), float32]
    %15 = nn.relu(%14) # ty=Tensor[(1, 32, 36, 36), float32]
    %15
  }
  %16 = %12(%9, %10, %11) # ty=Tensor[(1, 32, 36, 36), float32]
  %17 = fn(%p03: Tensor[(1, 32, 149, 149), float32])
           -> Tensor[(1, 32, 74, 74), float32] {
    %18 = nn.max_pool2d(%p03, pool_size=[3, 3], strides=[2, 2]) # ty=Tensor[(1, 32, 74, 74), float32]
    %18
  }
  %19 = %17(%6) # ty=Tensor[(1, 32, 74, 74), float32]
  %20 = meta.relay.Constant(id=4) # ty=Tensor[(32, 32, 3, 3), float32]
  %21 = meta.relay.Constant(id=5) # ty=Tensor[(32, 1, 1), float32]
  %22 = fn(%p04: Tensor[(1, 32, 74, 74), float32],
           %p12: Tensor[(32, 32, 3, 3), float32],
           %p22: Tensor[(32, 1, 1), float32])
           -> Tensor[(1, 32, 36, 36), float32] {
    %23 = nn.conv2d(%p04, %p12, strides=[2, 2], channels=32, kernel_size=[3, 3]) # ty=Tensor[(1, 32, 36, 36), float32]
    %24 = add(%23, %p22) # ty=Tensor[(1, 32, 36, 36), float32]
    %25 = nn.relu(%24) # ty=Tensor[(1, 32, 36, 36), float32]
    %25
  }
  %26 = %22(%19, %20, %21) # ty=Tensor[(1, 32, 36, 36), float32]
  %27 = (%16, %26)
  %28 = fn(%p05: Tuple[Tensor[(1, 32, 36, 36), float32], Tensor[(1, 32, 36, 36), float32]])
           -> Tensor[(1, 64, 36, 36), float32] {
    %29 = concatenate(%p05, axis=1) # ty=Tensor[(1, 64, 36, 36), float32]
    %29
  }
  %30 = %28(%27) # ty=Tensor[(1, 64, 36, 36), float32]
  %30
}
# meta data omitted. you can use show_meta_data=True to include meta-data

cc @tqchen

Hmm, this is something we need to resolve. @vinx13 can you look a bit into this?

Sure, I will take a look.

I think the problem is that the two conv2d ops are in different groups, and the parameters of the new fused function are allocated per group. So when GetOrAllocParam is called twice on the max_pool, the parameter corresponding to max_pool is allocated twice.

The fix should be to modify GetOrAllocParam to keep track of newly allocated params.

Actually, the max pool is duplicated during the FoldScaleAxis backward pass.

Hmm, interesting. What happens if you turn off FoldScaleAxis?

There is no problem if I turn off FoldScaleAxis.

This problem only occurs when conv2d is followed by batch_norm.

@tqchen
I think the problem is that the result of Transform(max_pool) (which is called by Conv2DBackwardTransform) is not memoized. The comment says it should only be called once; however, it is actually called twice when two conv2d ops consume the same max pool as input.

I see, the problem is that we use axes.defined() to check whether there is a chance of folding, but in many cases the axes array is actually empty (rather than null), and in that case we still cannot fold.

One possible solution is to change the axes.defined() check to axes.size() != 0. Let us also revisit the memoization assumption if that fix does not solve the problem.

I don't see why axes can be empty, or how that would solve this problem. I think the problem is that a foldable conv2d + bn (this bug does not occur if there is no bn) is indirectly referred to twice, e.g.

  conv1 + bn
        |
     max pool
   /          \
conv2+bn    conv3+bn

In this case, the max pool is transformed twice, which also transforms conv1 + bn twice. However, the result of Transform(conv1 + bn) is memoized (because Transform is called via ExprMutator::Mutate), so only the max pool is duplicated.
Anyway, memoization works.
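
To illustrate the memoization idea, here is a purely illustrative Python sketch (the real pass is the C++ FoldScaleAxis backward transformer; all names below are made up):

# Illustrative only: cache the transformed result per input node so that a
# node shared by two consumers (the max pool here) is rewritten exactly once.
class BackwardFolder:
    def __init__(self):
        self._memo = {}

    def transform(self, node):
        if id(node) not in self._memo:       # first consumer triggers the rewrite
            self._memo[id(node)] = self._rewrite(node)
        return self._memo[id(node)]          # later consumers reuse the cached result

    def _rewrite(self, node):
        # stand-in for the real scale-folding rewrite
        return ("folded", node)

folder = BackwardFolder()
shared_pool = ("max_pool", "conv1+bn")
branch_a = folder.transform(shared_pool)     # consumed by conv2
branch_b = folder.transform(shared_pool)     # consumed by conv3
assert branch_a is branch_b                  # the max pool is not duplicated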

I see, can you send a patch and a regression test?

Sure, will send a PR
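
For reference, a rough sketch of what such a regression test might look like (using the ir_pass-era API from the snippet above; the count_op helper is hypothetical and the eventual PR may differ):

# Hypothetical regression-test sketch: after SimplifyInference and the backward
# FoldScaleAxis pass, the shared max pool should still appear exactly once.
def count_op(expr, op_name):
    count = [0]

    def visit(node):
        if isinstance(node, relay.expr.Call) and \
                isinstance(node.op, relay.op.Op) and node.op.name == op_name:
            count[0] += 1

    relay.ir_pass.post_order_visit(expr, visit)
    return count[0]

net = get_net(batch_size=1, num_classes=1000,
              image_shape=(3, 299, 299), dtype="float32")
net = relay.ir_pass.infer_type(net)
net = relay.ir_pass.simplify_inference(net)
net = relay.ir_pass.infer_type(net)
net = relay.ir_pass.backward_fold_scale_axis(net)
assert count_op(net, "nn.max_pool2d") == 1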