conv2d
|
max pool
/ \
conv2d conv2d
\ /
concat
is transformed into
conv2d
/ \
max pool max pool
| |
conv2d conv2d
\ /
concat
The network:
from tvm import relay
from tvm.relay.testing import layers
def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None, suffix=''):
    """Conv2d -> BatchNorm (inference mode) -> ReLU building block.

    Layer names are derived from `name`/`suffix` ('<name><suffix>_conv1',
    '<name><suffix>_bn').
    """
    conv_out = layers.conv2d(
        data=data,
        channels=int(num_filter),
        kernel_size=kernel,
        strides=stride,
        padding=pad,
        name='%s%s_conv1' % (name, suffix))
    normed = layers.batch_norm_infer(data=conv_out, epsilon=2e-5, name='%s%s_bn' % (name, suffix))
    return relay.nn.relu(data=normed)
def Pooling(data, kernel, stride, pad, pool_type, name):
    """Dispatch to 2-D max or average pooling.

    Raises ValueError for any pool_type other than 'max' or 'avg'.
    (`name` is accepted for interface symmetry but unused, as in the original.)
    """
    common = dict(data=data, pool_size=kernel, strides=stride, padding=pad)
    if pool_type == 'max':
        return relay.nn.max_pool2d(**common)
    if pool_type == 'avg':
        # count_include_pad matches the original average-pool semantics.
        return relay.nn.avg_pool2d(count_include_pad=True, **common)
    raise ValueError("Invalid pooling type: " + pool_type)
def get_net(batch_size,
            num_classes,
            image_shape,
            dtype):
    """Build the diamond-shaped repro network:

        conv -> max_pool -> (conv1, conv2) -> concat

    (`num_classes` is accepted for interface symmetry but unused, as in the
    original repro script.)
    """
    data = relay.var("data",
                     shape=(batch_size,) + image_shape,
                     dtype=dtype)
    stem = Conv(data, 32, kernel=(3, 3), stride=(2, 2), name="conv")
    pooled = Pooling(data=stem, kernel=(3, 3), stride=(2, 2), pool_type="max", pad=(0, 0),
                     name="pool")
    # Two branches share the single pooled tensor; the issue is that fuse_ops
    # ends up duplicating the pooling computation for each consumer.
    branch_a = Conv(pooled, 32, kernel=(3, 3), stride=(2, 2), name="conv1")
    branch_b = Conv(pooled, 32, kernel=(3, 3), stride=(2, 2), name="conv2")
    joined = relay.concatenate((branch_a, branch_b), axis=1)
    return relay.Function(relay.ir_pass.free_vars(joined), joined)
IR after fuse_ops (obtained by calling relay.build with opt_pass_level = 3). Note that max_pool2d is computed twice, in %9 and %19:
fn (%data: Tensor[(1, 3, 299, 299), float32])
-> Tensor[(1, 64, 36, 36), float32] {
%0 = meta.relay.Constant(id=0) # ty=Tensor[(32, 3, 3, 3), float32]
%1 = meta.relay.Constant(id=1) # ty=Tensor[(32, 1, 1), float32]
%2 = fn(%p0: Tensor[(1, 3, 299, 299), float32],
%p1: Tensor[(32, 3, 3, 3), float32],
%p2: Tensor[(32, 1, 1), float32])
-> Tensor[(1, 32, 149, 149), float32] {
%3 = nn.conv2d(%p0, %p1, strides=[2, 2], channels=32, kernel_size=[3, 3]) # ty=Tensor[(1, 32, 149, 149), float32]
%4 = add(%3, %p2) # ty=Tensor[(1, 32, 149, 149), float32]
%5 = nn.relu(%4) # ty=Tensor[(1, 32, 149, 149), float32]
%5
}
%6 = %2(%data, %0, %1) # ty=Tensor[(1, 32, 149, 149), float32]
%7 = fn(%p01: Tensor[(1, 32, 149, 149), float32])
-> Tensor[(1, 32, 74, 74), float32] {
%8 = nn.max_pool2d(%p01, pool_size=[3, 3], strides=[2, 2]) # ty=Tensor[(1, 32, 74, 74), float32]
%8
}
%9 = %7(%6) # ty=Tensor[(1, 32, 74, 74), float32]
%10 = meta.relay.Constant(id=2) # ty=Tensor[(32, 32, 3, 3), float32]
%11 = meta.relay.Constant(id=3) # ty=Tensor[(32, 1, 1), float32]
%12 = fn(%p02: Tensor[(1, 32, 74, 74), float32],
%p11: Tensor[(32, 32, 3, 3), float32],
%p21: Tensor[(32, 1, 1), float32])
-> Tensor[(1, 32, 36, 36), float32] {
%13 = nn.conv2d(%p02, %p11, strides=[2, 2], channels=32, kernel_size=[3, 3]) # ty=Tensor[(1, 32, 36, 36), float32]
%14 = add(%13, %p21) # ty=Tensor[(1, 32, 36, 36), float32]
%15 = nn.relu(%14) # ty=Tensor[(1, 32, 36, 36), float32]
%15
}
%16 = %12(%9, %10, %11) # ty=Tensor[(1, 32, 36, 36), float32]
%17 = fn(%p03: Tensor[(1, 32, 149, 149), float32])
-> Tensor[(1, 32, 74, 74), float32] {
%18 = nn.max_pool2d(%p03, pool_size=[3, 3], strides=[2, 2]) # ty=Tensor[(1, 32, 74, 74), float32]
%18
}
%19 = %17(%6) # ty=Tensor[(1, 32, 74, 74), float32]
%20 = meta.relay.Constant(id=4) # ty=Tensor[(32, 32, 3, 3), float32]
%21 = meta.relay.Constant(id=5) # ty=Tensor[(32, 1, 1), float32]
%22 = fn(%p04: Tensor[(1, 32, 74, 74), float32],
%p12: Tensor[(32, 32, 3, 3), float32],
%p22: Tensor[(32, 1, 1), float32])
-> Tensor[(1, 32, 36, 36), float32] {
%23 = nn.conv2d(%p04, %p12, strides=[2, 2], channels=32, kernel_size=[3, 3]) # ty=Tensor[(1, 32, 36, 36), float32]
%24 = add(%23, %p22) # ty=Tensor[(1, 32, 36, 36), float32]
%25 = nn.relu(%24) # ty=Tensor[(1, 32, 36, 36), float32]
%25
}
%26 = %22(%19, %20, %21) # ty=Tensor[(1, 32, 36, 36), float32]
%27 = (%16, %26)
%28 = fn(%p05: Tuple[Tensor[(1, 32, 36, 36), float32], Tensor[(1, 32, 36, 36), float32]])
-> Tensor[(1, 64, 36, 36), float32] {
%29 = concatenate(%p05, axis=1) # ty=Tensor[(1, 64, 36, 36), float32]
%29
}
%30 = %28(%27) # ty=Tensor[(1, 64, 36, 36), float32]
%30
}
# meta data omitted. you can use show_meta_data=True to include meta-data
cc @tqchen