[VTA] Error when quantizing MxNet model

I am trying to run an MxNet model on an FPGA using VTA.

from tvm import relay

gluon_model = ...  # some MxNet model
# shape_dict / dtype_dict map graph input names to shapes / dtypes,
# e.g. {"data": (1, 3, 224, 224)} and {"data": "float32"} here
mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict)

shape_dict.update({k: v.shape for k, v in params.items()})
dtype_dict.update({k: str(v.dtype) for k, v in params.items()})

# Perform quantization in Relay
with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]):
    relay_prog = relay.quantize.quantize(mod['main'], params=params)

When I run the above code, I get the error:

Traceback (most recent call last):
  File "resNet.py", line 148, in <module>
    relay_prog = relay.quantize.quantize(mod['main'], params=params)
  File "/home/youn/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/relay/quantize/quantize.py", line 366, in quantize
    mod = quantize_seq(mod)
  File "/home/youn/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/relay/transform.py", line 185, in __call__
    return _transform.RunPass(self, mod)
  File "/home/youn/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/_ffi/_ctypes/function.py", line 209, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /home/youn/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>::VisitExpr(tvm::relay::Expr const&)+0xd2) [0x7f7b08c4cf62]
  [bt] (7) /home/youn/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(std::_Function_handler<tvm::relay::Expr (tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>*), tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>::InitVTable()::{lambda(tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>*)#6}>::_M_invoke(std::_Any_data const&, tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>*&&)+0x34) [0x7f7b08c487a4]
  [bt] (6) /home/youn/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::relay::ForwardRewriter::VisitExpr_(tvm::relay::CallNode const*)+0x2c4) [0x7f7b08e13a54]
  [bt] (5) /home/youn/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::relay::ForwardRewriter::GetTempExpr(tvm::relay::Expr const&)+0x42) [0x7f7b08e128b2]
  [bt] (4) /home/youn/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::relay::ExprMutator::VisitExpr(tvm::relay::Expr const&)+0x9e) [0x7f7b08c4641e]
  [bt] (3) /home/youn/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>::VisitExpr(tvm::relay::Expr const&)+0xd2) [0x7f7b08c4cf62]
  [bt] (2) /home/youn/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(std::_Function_handler<tvm::relay::Expr (tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>*), tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>::InitVTable()::{lambda(tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>*)#6}>::_M_invoke(std::_Any_data const&, tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::relay::Expr (tvm::relay::Expr const&)>*&&)+0x34) [0x7f7b08c487a4]
  [bt] (1) /home/youn/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(tvm::relay::ForwardRewriter::VisitExpr_(tvm::relay::CallNode const*)+0x5ec) [0x7f7b08e13d7c]
  [bt] (0) /home/youn/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/libtvm.so(+0x11c0b5b) [0x7f7b08f35b5b]
  File "/home/youn/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/_ffi/_ctypes/function.py", line 71, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/youn/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/relay/quantize/_annotate.py", line 112, in frewrite_with_guard
    return func(ref_call, new_args, ctx)
  File "/home/youn/.local/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/relay/quantize/_annotate.py", line 281, in add_rewrite
    raise ValueError()
TVMError: ValueError

I would really appreciate some help figuring out what could be wrong. Thanks!

Can you print the lhs_kind and rhs_kind in the add_rewrite function in _annotate.py that caused the error? Likely we have another unhandled case there.
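Something like this should do it (just a rough sketch; the spot is the raise ValueError() at line 281 of _annotate.py in your traceback, and the surrounding code may differ slightly in your build):

# tvm/relay/quantize/_annotate.py, inside add_rewrite, right before the raise
print("lhs_kind:", lhs_kind, "rhs_kind:", rhs_kind)
raise ValueError()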

Do you mind sharing what model you are trying to process? Is it from the gluon model zoo? It’s quite possible that this model is breaking the quantization pass.

Hey! Thanks for the reply. When I print lhs_kind and rhs_kind, I get:

lhs_kind: 1
rhs_kind: 3
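If I'm reading the kind constants in tvm/relay/quantize/quantize.py correctly, 1 should be QAnnotateKind.INPUT and 3 should be QAnnotateKind.ACTIVATION, so the case hitting the raise looks like INPUT + ACTIVATION. A quick sanity check of that reading (assuming QAnnotateKind is exported from tvm.relay.quantize in this build):

from tvm.relay import quantize as qtz

# Map the printed kind codes back to their names (sanity check only)
print(qtz.QAnnotateKind.INPUT, qtz.QAnnotateKind.ACTIVATION)  # expected: 1 3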

Hey! Thank you for the reply. It's not exactly a model-zoo model, but a modification of one.

This code should reproduce the model. It splits every convolution layer so that each one keeps only half the original number of filters:

import mxnet as mx
from mxnet import gluon, nd
from mxnet.gluon.model_zoo import vision


def splitWeights(weights):
    # Keep half the filters and half the input channels of a conv weight tensor.
    numFilters = weights.shape[0]
    numInputChannels = weights.shape[1]
    kernelSize = weights.shape[2]

    newWeights = nd.empty((numFilters // 2, numInputChannels // 2, kernelSize, kernelSize))
    for i in range(numFilters // 2):
        newWeights[i] = weights[i][:numInputChannels // 2]
    return newWeights

foundFirstConv = False
def splitNet(net):
    # Recursively replace every Conv2D / BatchNorm with a half-width copy.
    global foundFirstConv
    for key, layer in net._children.items():
        newLayer = None
        if isinstance(layer, gluon.nn.Conv2D):
            newLayer = gluon.nn.Conv2D(
                channels=layer._channels // 2,
                kernel_size=layer._kwargs['kernel'],
                strides=layer._kwargs['stride'],
                padding=layer._kwargs['pad'],
                in_channels=3 if not foundFirstConv else layer._in_channels // 2
            )
        elif isinstance(layer, gluon.nn.BatchNorm):
            newLayer = gluon.nn.BatchNorm(
                axis=layer._kwargs['axis'],
                epsilon=layer._kwargs['eps'],
                momentum=layer._kwargs['momentum'],
                scale=not layer._kwargs['fix_gamma'],
                use_global_stats=layer._kwargs['use_global_stats'],
                in_channels=0 if not hasattr(layer, 'in_channels') else layer.in_channels // 2
            )

        if newLayer is not None:
            with net.name_scope():
                if hasattr(net, key):
                    setattr(net, key, newLayer)
                net.register_child(newLayer, key)
            if isinstance(newLayer, gluon.nn.Conv2D):
                weights = layer.weight.data()
                if not foundFirstConv:
                    newWeights = weights[:newLayer._channels]
                    foundFirstConv = True
                else:
                    newWeights = splitWeights(weights)
                newLayer.collect_params().initialize(mx.initializer.Constant(newWeights))
            elif isinstance(newLayer, gluon.nn.BatchNorm):
                pdict = layer.collect_params()
                oldPrefix = layer._prefix
                for k in pdict.keys():
                    data = pdict[k].data()
                    data = data[:data.shape[0] // 2]  # keep only the first half of each BatchNorm parameter
                    newLayer.collect_params(newLayer._prefix + k[len(oldPrefix):]).initialize(mx.initializer.Constant(data))

        splitNet(layer)

# Get off-the-shelf gluon model ('model' holds the model-zoo name)
gluon_model = vision.get_model(model, pretrained=True)
features = gluon_model.features

# Trim dense layers
gluon_model = gluon.nn.HybridSequential()
gluon_model.add(features)

# split network (first half)
splitNet(gluon_model)

# one forward pass
dummy = nd.random_normal(shape=(1, 3, 224, 224))
output = gluon_model(dummy)
print(gluon_model)

This gives me the network:

HybridSequential(
  (0): HybridSequential(
    (0): Conv2D(3 -> 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=32)
    (2): Activation(relu)
    (3): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(1, 1), ceil_mode=False)
    (4): HybridSequential(
      (0): BasicBlockV1(
        (body): HybridSequential(
          (0): Conv2D(32 -> 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=32)
          (2): Activation(relu)
          (3): Conv2D(32 -> 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=32)
        )
      )
      (1): BasicBlockV1(
        (body): HybridSequential(
          (0): Conv2D(32 -> 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=32)
          (2): Activation(relu)
          (3): Conv2D(32 -> 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=32)
        )
      )
    )
    (5): HybridSequential(
      (0): BasicBlockV1(
        (body): HybridSequential(
          (0): Conv2D(32 -> 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
          (2): Activation(relu)
          (3): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
        )
        (downsample): HybridSequential(
          (0): Conv2D(32 -> 64, kernel_size=(1, 1), stride=(2, 2))
          (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
        )
      )
      (1): BasicBlockV1(
        (body): HybridSequential(
          (0): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
          (2): Activation(relu)
          (3): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
        )
      )
    )
    (6): HybridSequential(
      (0): BasicBlockV1(
        (body): HybridSequential(
          (0): Conv2D(64 -> 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
          (2): Activation(relu)
          (3): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
        )
        (downsample): HybridSequential(
          (0): Conv2D(64 -> 128, kernel_size=(1, 1), stride=(2, 2))
          (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
        )
      )
      (1): BasicBlockV1(
        (body): HybridSequential(
          (0): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
          (2): Activation(relu)
          (3): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
        )
      )
    )
    (7): HybridSequential(
      (0): BasicBlockV1(
        (body): HybridSequential(
          (0): Conv2D(128 -> 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (2): Activation(relu)
          (3): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
        )
        (downsample): HybridSequential(
          (0): Conv2D(128 -> 256, kernel_size=(1, 1), stride=(2, 2))
          (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
        )
      )
      (1): BasicBlockV1(
        (body): HybridSequential(
          (0): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
          (2): Activation(relu)
          (3): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
        )
      )
    )
    (8): GlobalAvgPool2D(size=(1, 1), stride=(1, 1), padding=(0, 0), ceil_mode=True)
  )
)
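From there, the model goes through the same conversion and quantization flow as in my first post, which is where the error above shows up. Repeating it here for completeness (the input name and shape are assumed from the dummy forward pass):

from tvm import relay

shape_dict = {"data": (1, 3, 224, 224)}   # input name/shape assumed from the dummy pass
dtype_dict = {"data": "float32"}
mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict)

with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]):
    relay_prog = relay.quantize.quantize(mod['main'], params=params)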

Try the patch for add_rewrite in this PR: https://github.com/dmlc/tvm/pull/3538/files#diff-aeb26fe114f5a7c8c16eb2c837cbe88aR278


This seems to have fixed the issue. Thanks, you’re awesome!