[VTA][ONNX] Error while loading ResNet18 model

I plan to deploy pre-trained models exported to ONNX on PYNQ boards with VTA. After studying VTA and following the tutorials, I tried to deploy my own ResNet18 model from ONNX with the same configuration as in the tutorial, but I get the following error after quantizing the network with Relay:

Traceback (most recent call last):

  File "deploy_classification.py", line 136, in <module>
    stop_name=pack_dict[model][1])

  File "/home/tvm/vta/python/vta/top/graphpack.py", line 466, in graph_pack
    expr = packer.visit(expr)

  File "/home/tvm/python/tvm/relay/expr_functor.py", line 44, in visit
    res = self.visit_function(expr)

  File "/home/tvm/python/tvm/relay/expr_functor.py", line 200, in visit_function
    new_body = self.visit(fn.body)

  File "/home/tvm/python/tvm/relay/expr_functor.py", line 46, in visit
    res = self.visit_call(expr)

  File "/home/tvm/vta/python/vta/top/graphpack.py", line 207, in visit_call
    args = [self.visit(arg) for arg in call.args]

  ...

  File "/home/tvm/python/tvm/relay/expr_functor.py", line 46, in visit
    res = self.visit_call(expr)

  File "/home/tvm/vta/python/vta/top/graphpack.py", line 235, in visit_call
    self.cfactor)

  File "/home/tvm/vta/python/vta/top/graphpack.py", line 78, in _weight_shape_match
    channels_pad = int(channels) % cfactor_out

TypeError: int() argument must be a string, a bytes-like object or a number, not 'NoneType'
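The last frame fails on `int(channels) % cfactor_out`. A simplified, hypothetical stand-in for that channel-padding step (the name `pad_channels` is made up; this is not the graphpack.py code) shows why a missing channel count surfaces as exactly this TypeError: the arithmetic rounds the count up to a multiple of the VTA block size, and `int(None)` fails before any rounding happens. The block size 16 below is only an example value.

```python
def pad_channels(channels, cfactor_out):
    """Round a channel count up to the next multiple of cfactor_out.

    Hypothetical simplification of the padding arithmetic that
    _weight_shape_match performs in vta/top/graphpack.py.
    """
    # Same expression as the failing line: int(None) raises TypeError here.
    channels_pad = int(channels) % cfactor_out
    if channels_pad != 0:
        channels = int(channels) + (cfactor_out - channels_pad)
    return int(channels)


print(pad_channels(64, 16))  # prints 64 (already a multiple of 16)
print(pad_channels(3, 16))   # prints 16 (padded up)
try:
    pad_channels(None, 16)   # what happens when call.attrs.channels is None
except TypeError as err:
    print("TypeError:", err)
```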

As far as I can track it, call.attrs.channels in graphpack.py returns None, so the Relay expression may be missing this parameter, or graph_pack can't find it. Correct me if I'm wrong. The dumped model is the following:

def @main(%data: Tensor[(1, 3, 224, 224), float32]) -> Tensor[(1, 1000), float32] {
  %0 = add(%data, meta[relay.Constant][0] /* ty=Tensor[(3, 1, 1), float32] */ /* ty=Tensor[(3, 1, 1), float32] */) /* ty=Tensor[(1, 3, 224, 224), float32] */;
  %1 = nn.conv2d(%0, meta[relay.Constant][1] /* ty=Tensor[(64, 3, 7, 7), float32] */ /* ty=Tensor[(64, 3, 7, 7), float32] */, strides=[2, 2], padding=[3, 3, 3, 3], kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %2 = add(%1, meta[relay.Constant][2] /* ty=Tensor[(64, 1, 1), float32] */ /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %3 = nn.relu(%2) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %4 = nn.max_pool2d(%3, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %5 = add(%4, meta[relay.Constant][3] /* ty=Tensor[(64, 1, 1), float32] */ /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %6 = nn.relu(%5) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %7 = annotation.stop_fusion(%6) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %8 = multiply(%7, 16f /* ty=float32 */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %9 = round(%8) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %10 = clip(%9, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %11 = cast(%10, dtype="int8") /* ty=Tensor[(1, 64, 56, 56), int8] */;
  %12 = nn.conv2d(%11, meta[relay.Constant][4] /* ty=Tensor[(64, 64, 3, 3), int8] */ /* ty=Tensor[(64, 64, 3, 3), int8] */, padding=[1, 1, 1, 1], kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %13 = add(%12, meta[relay.Constant][5] /* ty=Tensor[(64, 1, 1), int32] */ /* ty=Tensor[(64, 1, 1), int32] */) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %14 = nn.relu(%13) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %15 = add(%14, 4 /* ty=int32 */) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %16 = right_shift(%15, 3 /* ty=int32 */) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %17 = clip(%16, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %18 = cast(%17, dtype="int8") /* ty=Tensor[(1, 64, 56, 56), int8] */;
  %19 = annotation.stop_fusion(%18) /* ty=Tensor[(1, 64, 56, 56), int8] */;
  %20 = nn.conv2d(%19, meta[relay.Constant][6] /* ty=Tensor[(64, 64, 3, 3), int8] */ /* ty=Tensor[(64, 64, 3, 3), int8] */, padding=[1, 1, 1, 1], kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %21 = add(%20, 64 /* ty=int32 */) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %22 = right_shift(%21, 7 /* ty=int32 */) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %23 = clip(%22, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %24 = cast(%23, dtype="int8") /* ty=Tensor[(1, 64, 56, 56), int8] */;
  %25 = annotation.stop_fusion(%24) /* ty=Tensor[(1, 64, 56, 56), int8] */;
  %26 = cast(%25, dtype="int32") /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %27 = annotation.stop_fusion(%4) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %28 = multiply(%27, 16f /* ty=float32 */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %29 = round(%28) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %30 = clip(%29, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %31 = cast(%30, dtype="int32") /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %32 = add(%26, %31) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %33 = clip(%32, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %34 = cast(%33, dtype="int8") /* ty=Tensor[(1, 64, 56, 56), int8] */;
  %35 = annotation.stop_fusion(%34) /* ty=Tensor[(1, 64, 56, 56), int8] */;
  %36 = cast(%35, dtype="int32") /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %37 = left_shift(%36, 3 /* ty=int32 */) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %38 = add(%37, meta[relay.Constant][7] /* ty=Tensor[(64, 1, 1), int32] */ /* ty=Tensor[(64, 1, 1), int32] */) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %39 = nn.relu(%38) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %40 = add(%39, 4 /* ty=int32 */) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %41 = right_shift(%40, 3 /* ty=int32 */) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %42 = clip(%41, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %43 = cast(%42, dtype="int8") /* ty=Tensor[(1, 64, 56, 56), int8] */;
  %44 = nn.conv2d(%43, meta[relay.Constant][8] /* ty=Tensor[(64, 64, 3, 3), int8] */ /* ty=Tensor[(64, 64, 3, 3), int8] */, padding=[1, 1, 1, 1], kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %45 = add(%44, meta[relay.Constant][9] /* ty=Tensor[(64, 1, 1), int32] */ /* ty=Tensor[(64, 1, 1), int32] */) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %46 = nn.relu(%45) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %47 = add(%46, 16 /* ty=int32 */) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %48 = right_shift(%47, 5 /* ty=int32 */) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %49 = clip(%48, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %50 = cast(%49, dtype="int8") /* ty=Tensor[(1, 64, 56, 56), int8] */;
  %51 = annotation.stop_fusion(%50) /* ty=Tensor[(1, 64, 56, 56), int8] */;
  %52 = nn.conv2d(%51, meta[relay.Constant][10] /* ty=Tensor[(64, 64, 3, 3), int8] */ /* ty=Tensor[(64, 64, 3, 3), int8] */, padding=[1, 1, 1, 1], kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %53 = add(%52, 128 /* ty=int32 */) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %54 = right_shift(%53, 8 /* ty=int32 */) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %55 = clip(%54, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %56 = cast(%55, dtype="int8") /* ty=Tensor[(1, 64, 56, 56), int8] */;
  %57 = annotation.stop_fusion(%56) /* ty=Tensor[(1, 64, 56, 56), int8] */;
  %58 = cast(%57, dtype="int32") /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %59 = cast(%33, dtype="int8") /* ty=Tensor[(1, 64, 56, 56), int8] */;
  %60 = annotation.stop_fusion(%59) /* ty=Tensor[(1, 64, 56, 56), int8] */;
  %61 = cast(%60, dtype="int32") /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %62 = add(%58, %61) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %63 = left_shift(%62, 2 /* ty=int32 */) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %64 = add(%63, meta[relay.Constant][11] /* ty=Tensor[(64, 1, 1), int32] */ /* ty=Tensor[(64, 1, 1), int32] */) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %65 = nn.relu(%64) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %66 = add(%65, 2 /* ty=int32 */) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %67 = right_shift(%66, 2 /* ty=int32 */) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %68 = clip(%67, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 64, 56, 56), int32] */;
  %69 = cast(%68, dtype="int8") /* ty=Tensor[(1, 64, 56, 56), int8] */;
  %70 = nn.conv2d(%69, meta[relay.Constant][12] /* ty=Tensor[(128, 64, 3, 3), int8] */ /* ty=Tensor[(128, 64, 3, 3), int8] */, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %71 = add(%70, meta[relay.Constant][13] /* ty=Tensor[(128, 1, 1), int32] */ /* ty=Tensor[(128, 1, 1), int32] */) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %72 = nn.relu(%71) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %73 = add(%72, 128 /* ty=int32 */) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %74 = right_shift(%73, 8 /* ty=int32 */) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %75 = clip(%74, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %76 = cast(%75, dtype="int8") /* ty=Tensor[(1, 128, 28, 28), int8] */;
  %77 = annotation.stop_fusion(%76) /* ty=Tensor[(1, 128, 28, 28), int8] */;
  %78 = nn.conv2d(%77, meta[relay.Constant][14] /* ty=Tensor[(128, 128, 3, 3), int8] */ /* ty=Tensor[(128, 128, 3, 3), int8] */, padding=[1, 1, 1, 1], kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %79 = add(%78, 128 /* ty=int32 */) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %80 = right_shift(%79, 8 /* ty=int32 */) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %81 = clip(%80, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %82 = cast(%81, dtype="int8") /* ty=Tensor[(1, 128, 28, 28), int8] */;
  %83 = annotation.stop_fusion(%82) /* ty=Tensor[(1, 128, 28, 28), int8] */;
  %84 = cast(%68, dtype="int8") /* ty=Tensor[(1, 64, 56, 56), int8] */;
  %85 = nn.conv2d(%84, meta[relay.Constant][15] /* ty=Tensor[(128, 64, 1, 1), int8] */ /* ty=Tensor[(128, 64, 1, 1), int8] */, strides=[2, 2], padding=[0, 0, 0, 0], kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %86 = add(%85, 128 /* ty=int32 */) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %87 = right_shift(%86, 8 /* ty=int32 */) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %88 = clip(%87, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %89 = cast(%88, dtype="int8") /* ty=Tensor[(1, 128, 28, 28), int8] */;
  %90 = annotation.stop_fusion(%89) /* ty=Tensor[(1, 128, 28, 28), int8] */;
  %91 = add(%83, %90) /* ty=Tensor[(1, 128, 28, 28), int8] */;
  %92 = clip(%91, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 128, 28, 28), int8] */;
  %93 = cast(%92, dtype="int32") /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %94 = left_shift(%93, 2 /* ty=int32 */) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %95 = add(%94, meta[relay.Constant][16] /* ty=Tensor[(128, 1, 1), int32] */ /* ty=Tensor[(128, 1, 1), int32] */) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %96 = nn.relu(%95) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %97 = add(%96, 2 /* ty=int32 */) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %98 = right_shift(%97, 2 /* ty=int32 */) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %99 = clip(%98, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %100 = cast(%99, dtype="int8") /* ty=Tensor[(1, 128, 28, 28), int8] */;
  %101 = nn.conv2d(%100, meta[relay.Constant][17] /* ty=Tensor[(128, 128, 3, 3), int8] */ /* ty=Tensor[(128, 128, 3, 3), int8] */, padding=[1, 1, 1, 1], kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %102 = add(%101, meta[relay.Constant][18] /* ty=Tensor[(128, 1, 1), int32] */ /* ty=Tensor[(128, 1, 1), int32] */) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %103 = nn.relu(%102) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %104 = add(%103, 64 /* ty=int32 */) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %105 = right_shift(%104, 7 /* ty=int32 */) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %106 = clip(%105, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %107 = cast(%106, dtype="int8") /* ty=Tensor[(1, 128, 28, 28), int8] */;
  %108 = annotation.stop_fusion(%107) /* ty=Tensor[(1, 128, 28, 28), int8] */;
  %109 = nn.conv2d(%108, meta[relay.Constant][19] /* ty=Tensor[(128, 128, 3, 3), int8] */ /* ty=Tensor[(128, 128, 3, 3), int8] */, padding=[1, 1, 1, 1], kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %110 = add(%109, 128 /* ty=int32 */) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %111 = right_shift(%110, 8 /* ty=int32 */) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %112 = clip(%111, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %113 = cast(%112, dtype="int8") /* ty=Tensor[(1, 128, 28, 28), int8] */;
  %114 = annotation.stop_fusion(%113) /* ty=Tensor[(1, 128, 28, 28), int8] */;
  %115 = cast(%92, dtype="int8") /* ty=Tensor[(1, 128, 28, 28), int8] */;
  %116 = annotation.stop_fusion(%115) /* ty=Tensor[(1, 128, 28, 28), int8] */;
  %117 = cast(%116, dtype="int8") /* ty=Tensor[(1, 128, 28, 28), int8] */;
  %118 = add(%114, %117) /* ty=Tensor[(1, 128, 28, 28), int8] */;
  %119 = cast(%118, dtype="int32") /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %120 = left_shift(%119, 2 /* ty=int32 */) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %121 = add(%120, meta[relay.Constant][20] /* ty=Tensor[(128, 1, 1), int32] */ /* ty=Tensor[(128, 1, 1), int32] */) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %122 = nn.relu(%121) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %123 = add(%122, 2 /* ty=int32 */) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %124 = right_shift(%123, 2 /* ty=int32 */) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %125 = clip(%124, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 128, 28, 28), int32] */;
  %126 = cast(%125, dtype="int8") /* ty=Tensor[(1, 128, 28, 28), int8] */;
  %127 = nn.conv2d(%126, meta[relay.Constant][21] /* ty=Tensor[(256, 128, 3, 3), int8] */ /* ty=Tensor[(256, 128, 3, 3), int8] */, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %128 = add(%127, meta[relay.Constant][22] /* ty=Tensor[(256, 1, 1), int32] */ /* ty=Tensor[(256, 1, 1), int32] */) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %129 = nn.relu(%128) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %130 = add(%129, 128 /* ty=int32 */) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %131 = right_shift(%130, 8 /* ty=int32 */) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %132 = clip(%131, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %133 = cast(%132, dtype="int8") /* ty=Tensor[(1, 256, 14, 14), int8] */;
  %134 = annotation.stop_fusion(%133) /* ty=Tensor[(1, 256, 14, 14), int8] */;
  %135 = nn.conv2d(%134, meta[relay.Constant][23] /* ty=Tensor[(256, 256, 3, 3), int8] */ /* ty=Tensor[(256, 256, 3, 3), int8] */, padding=[1, 1, 1, 1], kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %136 = add(%135, 128 /* ty=int32 */) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %137 = right_shift(%136, 8 /* ty=int32 */) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %138 = clip(%137, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %139 = cast(%138, dtype="int8") /* ty=Tensor[(1, 256, 14, 14), int8] */;
  %140 = annotation.stop_fusion(%139) /* ty=Tensor[(1, 256, 14, 14), int8] */;
  %141 = cast(%125, dtype="int8") /* ty=Tensor[(1, 128, 28, 28), int8] */;
  %142 = nn.conv2d(%141, meta[relay.Constant][24] /* ty=Tensor[(256, 128, 1, 1), int8] */ /* ty=Tensor[(256, 128, 1, 1), int8] */, strides=[2, 2], padding=[0, 0, 0, 0], kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %143 = add(%142, 256 /* ty=int32 */) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %144 = right_shift(%143, 9 /* ty=int32 */) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %145 = clip(%144, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %146 = cast(%145, dtype="int8") /* ty=Tensor[(1, 256, 14, 14), int8] */;
  %147 = annotation.stop_fusion(%146) /* ty=Tensor[(1, 256, 14, 14), int8] */;
  %148 = add(%140, %147) /* ty=Tensor[(1, 256, 14, 14), int8] */;
  %149 = clip(%148, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 256, 14, 14), int8] */;
  %150 = cast(%149, dtype="int32") /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %151 = left_shift(%150, 3 /* ty=int32 */) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %152 = add(%151, meta[relay.Constant][25] /* ty=Tensor[(256, 1, 1), int32] */ /* ty=Tensor[(256, 1, 1), int32] */) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %153 = nn.relu(%152) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %154 = add(%153, 4 /* ty=int32 */) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %155 = right_shift(%154, 3 /* ty=int32 */) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %156 = clip(%155, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %157 = cast(%156, dtype="int8") /* ty=Tensor[(1, 256, 14, 14), int8] */;
  %158 = nn.conv2d(%157, meta[relay.Constant][26] /* ty=Tensor[(256, 256, 3, 3), int8] */ /* ty=Tensor[(256, 256, 3, 3), int8] */, padding=[1, 1, 1, 1], kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %159 = add(%158, meta[relay.Constant][27] /* ty=Tensor[(256, 1, 1), int32] */ /* ty=Tensor[(256, 1, 1), int32] */) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %160 = nn.relu(%159) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %161 = add(%160, 32 /* ty=int32 */) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %162 = right_shift(%161, 6 /* ty=int32 */) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %163 = clip(%162, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %164 = cast(%163, dtype="int8") /* ty=Tensor[(1, 256, 14, 14), int8] */;
  %165 = annotation.stop_fusion(%164) /* ty=Tensor[(1, 256, 14, 14), int8] */;
  %166 = nn.conv2d(%165, meta[relay.Constant][28] /* ty=Tensor[(256, 256, 3, 3), int8] */ /* ty=Tensor[(256, 256, 3, 3), int8] */, padding=[1, 1, 1, 1], kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %167 = add(%166, 128 /* ty=int32 */) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %168 = right_shift(%167, 8 /* ty=int32 */) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %169 = clip(%168, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %170 = cast(%169, dtype="int8") /* ty=Tensor[(1, 256, 14, 14), int8] */;
  %171 = annotation.stop_fusion(%170) /* ty=Tensor[(1, 256, 14, 14), int8] */;
  %172 = cast(%149, dtype="int8") /* ty=Tensor[(1, 256, 14, 14), int8] */;
  %173 = annotation.stop_fusion(%172) /* ty=Tensor[(1, 256, 14, 14), int8] */;
  %174 = cast(%173, dtype="int8") /* ty=Tensor[(1, 256, 14, 14), int8] */;
  %175 = add(%171, %174) /* ty=Tensor[(1, 256, 14, 14), int8] */;
  %176 = cast(%175, dtype="int32") /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %177 = left_shift(%176, 2 /* ty=int32 */) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %178 = add(%177, meta[relay.Constant][29] /* ty=Tensor[(256, 1, 1), int32] */ /* ty=Tensor[(256, 1, 1), int32] */) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %179 = nn.relu(%178) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %180 = add(%179, 2 /* ty=int32 */) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %181 = right_shift(%180, 2 /* ty=int32 */) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %182 = clip(%181, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 256, 14, 14), int32] */;
  %183 = cast(%182, dtype="int8") /* ty=Tensor[(1, 256, 14, 14), int8] */;
  %184 = nn.conv2d(%183, meta[relay.Constant][30] /* ty=Tensor[(512, 256, 3, 3), int8] */ /* ty=Tensor[(512, 256, 3, 3), int8] */, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %185 = add(%184, meta[relay.Constant][31] /* ty=Tensor[(512, 1, 1), int32] */ /* ty=Tensor[(512, 1, 1), int32] */) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %186 = nn.relu(%185) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %187 = add(%186, 128 /* ty=int32 */) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %188 = right_shift(%187, 8 /* ty=int32 */) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %189 = clip(%188, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %190 = cast(%189, dtype="int8") /* ty=Tensor[(1, 512, 7, 7), int8] */;
  %191 = annotation.stop_fusion(%190) /* ty=Tensor[(1, 512, 7, 7), int8] */;
  %192 = nn.conv2d(%191, meta[relay.Constant][32] /* ty=Tensor[(512, 512, 3, 3), int8] */ /* ty=Tensor[(512, 512, 3, 3), int8] */, padding=[1, 1, 1, 1], kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %193 = add(%192, 128 /* ty=int32 */) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %194 = right_shift(%193, 8 /* ty=int32 */) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %195 = clip(%194, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %196 = cast(%195, dtype="int8") /* ty=Tensor[(1, 512, 7, 7), int8] */;
  %197 = annotation.stop_fusion(%196) /* ty=Tensor[(1, 512, 7, 7), int8] */;
  %198 = cast(%182, dtype="int8") /* ty=Tensor[(1, 256, 14, 14), int8] */;
  %199 = nn.conv2d(%198, meta[relay.Constant][33] /* ty=Tensor[(512, 256, 1, 1), int8] */ /* ty=Tensor[(512, 256, 1, 1), int8] */, strides=[2, 2], padding=[0, 0, 0, 0], kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %200 = add(%199, 128 /* ty=int32 */) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %201 = right_shift(%200, 8 /* ty=int32 */) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %202 = clip(%201, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %203 = cast(%202, dtype="int8") /* ty=Tensor[(1, 512, 7, 7), int8] */;
  %204 = annotation.stop_fusion(%203) /* ty=Tensor[(1, 512, 7, 7), int8] */;
  %205 = add(%197, %204) /* ty=Tensor[(1, 512, 7, 7), int8] */;
  %206 = clip(%205, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 512, 7, 7), int8] */;
  %207 = cast(%206, dtype="int32") /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %208 = left_shift(%207, 4 /* ty=int32 */) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %209 = add(%208, meta[relay.Constant][34] /* ty=Tensor[(512, 1, 1), int32] */ /* ty=Tensor[(512, 1, 1), int32] */) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %210 = nn.relu(%209) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %211 = add(%210, 8 /* ty=int32 */) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %212 = right_shift(%211, 4 /* ty=int32 */) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %213 = clip(%212, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %214 = cast(%213, dtype="int8") /* ty=Tensor[(1, 512, 7, 7), int8] */;
  %215 = nn.conv2d(%214, meta[relay.Constant][35] /* ty=Tensor[(512, 512, 3, 3), int8] */ /* ty=Tensor[(512, 512, 3, 3), int8] */, padding=[1, 1, 1, 1], kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %216 = add(%215, meta[relay.Constant][36] /* ty=Tensor[(512, 1, 1), int32] */ /* ty=Tensor[(512, 1, 1), int32] */) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %217 = nn.relu(%216) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %218 = add(%217, 32 /* ty=int32 */) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %219 = right_shift(%218, 6 /* ty=int32 */) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %220 = clip(%219, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %221 = cast(%220, dtype="int8") /* ty=Tensor[(1, 512, 7, 7), int8] */;
  %222 = annotation.stop_fusion(%221) /* ty=Tensor[(1, 512, 7, 7), int8] */;
  %223 = nn.conv2d(%222, meta[relay.Constant][37] /* ty=Tensor[(512, 512, 3, 3), int8] */ /* ty=Tensor[(512, 512, 3, 3), int8] */, padding=[1, 1, 1, 1], kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %224 = add(%223, 256 /* ty=int32 */) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %225 = right_shift(%224, 9 /* ty=int32 */) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %226 = clip(%225, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %227 = cast(%226, dtype="int8") /* ty=Tensor[(1, 512, 7, 7), int8] */;
  %228 = annotation.stop_fusion(%227) /* ty=Tensor[(1, 512, 7, 7), int8] */;
  %229 = cast(%206, dtype="int8") /* ty=Tensor[(1, 512, 7, 7), int8] */;
  %230 = annotation.stop_fusion(%229) /* ty=Tensor[(1, 512, 7, 7), int8] */;
  %231 = cast(%230, dtype="int8") /* ty=Tensor[(1, 512, 7, 7), int8] */;
  %232 = add(%228, %231) /* ty=Tensor[(1, 512, 7, 7), int8] */;
  %233 = cast(%232, dtype="int32") /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %234 = multiply(%233, meta[relay.Constant][38] /* ty=Tensor[(512, 1, 1), int32] */ /* ty=Tensor[(512, 1, 1), int32] */) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %235 = add(%234, meta[relay.Constant][39] /* ty=Tensor[(512, 1, 1), int32] */ /* ty=Tensor[(512, 1, 1), int32] */) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %236 = nn.relu(%235) /* ty=Tensor[(1, 512, 7, 7), int32] */;
  %237 = nn.global_avg_pool2d(%236) /* ty=Tensor[(1, 512, 1, 1), int32] */;
  %238 = cast(%237, dtype="float32") /* ty=Tensor[(1, 512, 1, 1), float32] */;
  %239 = multiply(%238, 0.0078125f /* ty=float32 */) /* ty=Tensor[(1, 512, 1, 1), float32] */;
  %240 = reshape(%239, newshape=[0, -1]) /* ty=Tensor[(1, 512), float32] */;
  %241 = nn.batch_flatten(%240) /* ty=Tensor[(1, 512), float32] */;
  %242 = multiply(1f /* ty=float32 */, %241) /* ty=Tensor[(1, 512), float32] */;
  %243 = nn.dense(%242, meta[relay.Constant][40] /* ty=Tensor[(1000, 512), float32] */ /* ty=Tensor[(1000, 512), float32] */, units=1000) /* ty=Tensor[(1, 1000), float32] */;
  add(%243, meta[relay.Constant][41] /* ty=Tensor[(1000), float32] */ /* ty=Tensor[(1000), float32] */) /* ty=Tensor[(1, 1000), float32] */
}

Can anybody help me? Is there any solution I can try? I'm also interested in deploying low-precision models on FPGAs with VTA and would like to contribute.

Thanks in advance

Did you invoke graph_pack with the start/stop points set to the max pool (%4) and the average pool (%237)? Everything looks OK in your Relay function. Did you skip the first conv2d?

Yes, I provided the start/stop pack indices for the max pool and the average pool, and I also skipped the first convolution.

Hi, were you able to resolve this problem?

I think the error is caused by the following: the graph_pack implementation assumes that call.attrs always carries a value for 'channels', but when ONNX models are converted to Relay functions this assumption is violated (although it holds for the MXNet models in the tutorial). The dump above is consistent with this: none of its nn.conv2d calls has a channels= attribute.
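To check a given model for this before calling graph_pack, one option is to scan the printed Relay text (the same text as the dump above) for nn.conv2d bindings that lack channels=. The helper below is a hypothetical, text-based sketch in plain Python; it assumes each binding is printed on a single line, as in the dump, and does not use any TVM API.

```python
import re

# Matches a Relay text line that binds a conv2d call, e.g. "%12 = nn.conv2d(...".
_CONV_RE = re.compile(r"^\s*(%\d+)\s*=\s*nn\.conv2d\(")


def conv2d_without_channels(relay_text):
    """Return the %-names of nn.conv2d calls whose printed attrs lack channels=."""
    missing = []
    for line in relay_text.splitlines():
        match = _CONV_RE.match(line)
        if match and "channels=" not in line:
            missing.append(match.group(1))
    return missing


sample = """\
  %12 = nn.conv2d(%11, %w0, padding=[1, 1, 1, 1], kernel_size=[3, 3], out_dtype="int32");
  %13 = add(%12, %b0);
  %14 = nn.conv2d(%13, %w1, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]);"""

print(conv2d_without_channels(sample))  # prints ['%12']
```

Run on the full dump above, every nn.conv2d binding would be reported, since none of them prints a channels= attribute.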

To support graph packing of models from different frontends in VTA, I think we should obtain the value of "channels" some other way instead of reading it directly from call.attrs.
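One possible way to follow this suggestion, sketched under stated assumptions: keep using the channels attribute when it is set, and otherwise read the output-channel count from the weight tensor's shape. For the OIHW weights in the dump above that is axis 0, e.g. 128 for `Tensor[(128, 64, 3, 3), int8]`. The helper `resolve_conv_channels` is hypothetical plain Python, not existing VTA code; wiring it into graphpack.py would additionally require taking the weight shape from the conv2d's second argument after type inference and the layout string from the conv2d's kernel_layout attribute, which is an assumption about how such a fix could look.

```python
def resolve_conv_channels(attr_channels, weight_shape, kernel_layout="OIHW"):
    """Return a conv2d's output-channel count.

    Prefer the explicit channels attribute; if the frontend left it unset
    (None), fall back to the output-channel axis of the weight shape.
    Only plain (unpacked) layouts such as OIHW or HWIO are handled.
    """
    if attr_channels is not None:
        return int(attr_channels)
    if "O" not in kernel_layout or len(kernel_layout) != len(weight_shape):
        raise ValueError(f"cannot read output channels from layout {kernel_layout!r}")
    return int(weight_shape[kernel_layout.index("O")])


# MXNet-style conv2d: the attribute is set and used as-is.
print(resolve_conv_channels(64, (64, 64, 3, 3)))     # prints 64
# ONNX-style conv2d from the dump: attribute unset, fall back to the weight shape.
print(resolve_conv_channels(None, (128, 64, 3, 3)))  # prints 128
```

Reading the shape from the weight rather than from the conv2d output keeps the value tied to the kernel that graph_pack actually pads and packs.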
