TF Lite quantized conv2d operator conversion

Consider handling symmetric int8 kernels (fixed 0 offset) and uint8 with 0 offset as specializations.

This requires that the TFLite model provided by the user uses symmetric quantization, which is expected to be generated by the TF/TFLite tools. AFAIK, the official pre-trained MobileNetV1 is not symmetric. Assuming a 0 offset seems a bit aggressive?

Also, I believe the TFLite implementations downscale the quantized feature data to uint8 between layers. In TFLite this is implemented with integer multiply (mpy) and downshift operations. Some target devices don’t have floating point (fp), and especially not double precision (dp), so maybe consider supporting the integer implementation as an option.

Yes, enabling integer-only devices is one of the goals of gemmlowp. You are an expert on this :slight_smile: If we take TFLite’s approach, the computation inside q_conv2d is purely integer arithmetic, including the requantization that converts the accumulated INT32 result back to INT8. So, no need to worry about the integer implementation :slight_smile:

I wouldn’t say expert, but I ported the TFLite quantized Inception models to a RISC-V architecture in my last job.

Note that the TFLite implementation uses int64 operations in the integer downscale, while some targets do not support int64. So your integer implementation may need to allow for that.

The downscale integer operations will also need to do rounding and saturation.
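On targets without int64 support, the upper word of a 32x32-bit product can still be computed from 32-bit pieces. Below is a minimal Python sketch, assuming signed int32 operands, of the classic 16-bit-halves construction (the `mulhs` routine in Hacker's Delight): no intermediate product needs more than 32 bits in C. Python's unbounded ints are used only to mirror the C steps, and `mulhs32` is a hypothetical name, not a TFLite function.

```python
def mulhs32(u: int, v: int) -> int:
    """Upper 32 bits of the 64-bit product of two signed int32 values.

    Uses only 16x16-bit partial products, so in C every intermediate fits in
    32 bits (Python ints are unbounded; this just mirrors the C steps).
    """
    u0, u1 = u & 0xFFFF, u >> 16   # low half (unsigned), high half (signed, arithmetic shift)
    v0, v1 = v & 0xFFFF, v >> 16
    w0 = u0 * v0                   # unsigned low x low
    t = u1 * v0 + (w0 >> 16)       # high x low, plus the carry out of w0
    w1 = (t & 0xFFFF) + u0 * v1    # low x high, plus the low half of t
    w2 = t >> 16                   # high half of t (signed)
    return u1 * v1 + w2 + (w1 >> 16)
```

A hardware "multiply high" instruction, like the RISC-V one described below, produces the same value in a single step.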


That is interesting. So were you handling the int64 operations at the ISA level, or in C (with the compiler’s help to emulate int64)?

The downscale only stores 8-bit data taken from the upper word of the int64 result. Multiplying by the downscale constant moves the data into the upper word; then a rounding right shift is applied, and the result is saturated and stored as the uint8 feature value.

The RISC-V had a form of multiply instruction that keeps just the upper word of what would have been a 64-bit result, so that provided the part needed for the downscale.
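The downscale steps described above (multiply into the upper word, rounding right shift, saturate) can be sketched in Python as follows. This is a minimal model loosely following gemmlowp's `SaturatingRoundingDoublingHighMul` and `RoundingDivideByPOT`; the helper names, the clamp shortcut in `quantize_multiplier`, and the uint8 output range are assumptions for illustration. The `a * b` product is where a 64-bit multiply, or an upper-word multiply instruction, is needed.

```python
import math

INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1

def quantize_multiplier(real_multiplier: float):
    """Split a real multiplier in (0, 1) into a Q31 int32 significand and a right shift."""
    assert 0.0 < real_multiplier < 1.0
    significand, exponent = math.frexp(real_multiplier)  # 0.5 <= significand < 1
    q_fixed = round(significand * (1 << 31))
    if q_fixed > INT32_MAX:           # significand rounded up to 1.0: clamp (a simplification)
        q_fixed = INT32_MAX
    return q_fixed, -exponent         # exponent <= 0 here, so the shift is >= 0

def _trunc_div(x: int, d: int) -> int:
    """Integer division rounding toward zero, like C's '/'."""
    q = abs(x) // d
    return q if x >= 0 else -q

def saturating_rounding_doubling_high_mul(a: int, b: int) -> int:
    """Rounded (2 * a * b) >> 32, i.e. the upper word of the doubled product."""
    if a == INT32_MIN and b == INT32_MIN:
        return INT32_MAX              # the only input pair that overflows
    ab = a * b                        # needs 64 bits (or a multiply-high instruction)
    nudge = (1 << 30) if ab >= 0 else 1 - (1 << 30)
    return _trunc_div(ab + nudge, 1 << 31)

def rounding_divide_by_pot(x: int, exponent: int) -> int:
    """Arithmetic right shift by `exponent`, rounding to nearest (ties away from zero)."""
    mask = (1 << exponent) - 1
    remainder = x & mask
    threshold = (mask >> 1) + (1 if x < 0 else 0)
    return (x >> exponent) + (1 if remainder > threshold else 0)

def requantize(acc: int, real_multiplier: float, out_zero_point: int) -> int:
    """Downscale an int32 accumulator to a saturated uint8 feature value."""
    q_fixed, right_shift = quantize_multiplier(real_multiplier)
    scaled = rounding_divide_by_pot(
        saturating_rounding_doubling_high_mul(acc, q_fixed), right_shift)
    return max(0, min(255, scaled + out_zero_point))  # saturate to uint8
```

For example, with a real multiplier of 0.0123 and an output zero point of 128, an accumulator of 1000 maps to 140 (1000 × 0.0123 ≈ 12.3, rounded to 12, plus 128), while large accumulators saturate at 0 or 255.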

The TFLite implementations include an option for selecting reference code. I found it very useful for debugging to modify their reference code to printf the values before and after activation, rounding, and saturation, prefixing each output line with its c,h,w index values. The c,h,w prefix can be used to sort the lines so that you can match them up with your own implementation’s output, whatever its order (assuming it also tags and outputs its values). I converted six of the TFLite quantized models to RISC-V using this.

So, this is just a suggestion from my own experience: provide a similar way to dump the data while trying to match the TFLite reference per-layer data. Then you can do whatever other optimizations and parallel processing you want, and tune for performance.
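A small Python sketch of the sort-and-compare idea, assuming each dump line is prefixed with zero-padded c,h,w indices. The line format and the `tagged_dump` / `mismatches` helpers are hypothetical, not part of TFLite. Zero padding makes a plain lexicographic sort agree with numeric index order, so dumps emitted in different loop orders line up.

```python
def tagged_dump(values, stage):
    """Format debug lines like '0001,0000,0002 after_sat 140' (hypothetical format).

    values: dict mapping (c, h, w) -> int; stage: a label such as 'after_sat'.
    Zero-padded indices make a lexicographic sort match numeric (c, h, w) order.
    """
    return [f"{c:04d},{h:04d},{w:04d} {stage} {v}" for (c, h, w), v in values.items()]

def mismatches(reference_lines, our_lines):
    """Sort both dumps by their c,h,w prefix and return (reference, ours) pairs that differ."""
    if len(reference_lines) != len(our_lines):
        raise ValueError("dumps have different lengths")
    ref, ours = sorted(reference_lines), sorted(our_lines)
    return [(r, o) for r, o in zip(ref, ours) if r != o]
```

The reference can be dumped in c,h,w order while your kernel dumps in, say, h,w,c order; after sorting, only genuinely different values are reported.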

I modified the TFLite operations to support debug dumping for six models: the four Inceptions and the two MobileNets. This was back in March, but it might be of some help. This version was used on Ubuntu 18.04.
tflite c++ model printf debug


@jackwish @FrozenGene Thanks for explaining the tradeoffs. I got some time to think more deeply about this.

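Expanding the zero-point-corrected sum for one output point gives four terms:

$$
\begin{aligned}
\sum_{c,r,s} \big(QW(k,c,r,s) - zp_w\big)\big(QA(n,c,h+r,w+s) - zp_a\big)
  ={} & \sum_{c,r,s} QW(k,c,r,s)\, QA(n,c,h+r,w+s) && \text{(Term 1)} \\
  & - \sum_{c,r,s} zp_a\, QW(k,c,r,s) && \text{(Term 2)} \\
  & - \sum_{c,r,s} zp_w\, QA(n,c,h+r,w+s) && \text{(Term 3)} \\
  & + \sum_{c,r,s} zp_a\, zp_w && \text{(Term 4)}
\end{aligned}
$$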
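As a quick sanity check of the expansion, the Python sketch below flattens the C×R×S window for one output point (fixed k, n, h, w) into lists and compares the direct zero-point-corrected sum with Term 1 - Term 2 - Term 3 + Term 4. The function names are illustrative only.

```python
def conv_point_direct(qw, qa, zp_w, zp_a):
    """Direct form: sum over the (c, r, s) window of (QW - zp_w) * (QA - zp_a)."""
    return sum((w - zp_w) * (a - zp_a) for w, a in zip(qw, qa))

def conv_point_four_terms(qw, qa, zp_w, zp_a):
    """Same value computed as Term 1 - Term 2 - Term 3 + Term 4."""
    taps = len(qw)                                # C * R * S
    term1 = sum(w * a for w, a in zip(qw, qa))
    term2 = zp_a * sum(qw)
    term3 = zp_w * sum(qa)
    term4 = taps * zp_a * zp_w                    # the constant summed over every tap
    return term1 - term2 - term3 + term4

# Example: a 3x3 window over 3 channels (27 taps) of uint8 data.
qw = [(i * 37) % 256 for i in range(27)]
qa = [(i * 91 + 5) % 256 for i in range(27)]
assert conv_point_direct(qw, qa, zp_w=128, zp_a=3) == conv_point_four_terms(qw, qa, zp_w=128, zp_a=3)
```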

I came to the following conclusions (after reading FBGEMM and QNNPACK):

  • Partial output reuse (intermediate INT32 value reuse): we should strive to perform all the computation on an intermediate INT32 value before going to memory. @jackwish and @FrozenGene did this by fusing the core computation (Term 1 to Term 4 in the above equation) with requantize, so that the INT32 value always stays in the register file and never has to be spilled to memory and brought back.
  • Input reuse (quantized input matrix reuse): this is also essential. In the equation, Term 3 is basically a big offset matrix that has to be applied to Term 1, and both Term 1 and Term 3 read the same quantized input matrix. If we compute them in different operators without fusion, we will not be able to reuse the quantized matrix A.
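Both kinds of reuse can be illustrated with a scalar Python sketch of one output point in a fused (option 2 style) kernel: each input value is loaded once and feeds both Term 1 and the running input sum for Term 3, and the INT32 accumulator stays local until a single requantization at the end. `requant` is a placeholder callable (for example, an integer downscale like the one discussed earlier in the thread); a real kernel would be vectorized and tiled.

```python
def fused_qconv_point(qw, qa, zp_w, zp_a, requant):
    """One output point of a fused quantized conv (illustrative sketch).

    qw, qa: flattened (c, r, s) window of uint8 weights and inputs.
    """
    acc = 0     # Term 1 accumulator (would live in a register)
    sum_a = 0   # running input sum for Term 3, reusing the same loaded input
    sum_w = 0   # weight sum for Term 2 (weight-only, so it could be precomputed)
    for w, a in zip(qw, qa):
        acc += w * a
        sum_a += a
        sum_w += w
    acc += -zp_a * sum_w - zp_w * sum_a + len(qw) * zp_a * zp_w  # Terms 2, 3, 4
    return requant(acc)  # the INT32 value never round-trips through memory
```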

So, in its current form, Option 1 satisfies neither of the above requirements, potentially leaving significant performance on the table. As performance is the primary goal (at least in my case), I am also leaning towards Option 2 (with a dread of writing a tensorized schedule). However, I do have a couple of unresolved thoughts that are bothering me.

  1. For option 2, we are performing the computations for Term 2 and Term 4 at runtime. I don’t know how much of that penalty could be avoided by precomputing them at compile time. This is also a limitation of the current TVM infrastructure, where precompute/constant folding is limited to Relay.
  2. Making a final “weak” case for option 1 here: it might be possible to keep the 4 terms as separate Relay ops, precompute Term 2 and Term 4 using Relay passes (solving the problem in point 1), and then somehow perform both horizontal and vertical fusion to fuse everything into one giant op. Then, we can use compute_inline() and compute_at() to get both input and output reuse. Again, this is a “weak” case; it does not sound easy to do. (This is somewhat in line with FBGEMM, which has more granular APIs, some of which are executed at compile time.)

If we merge all four terms back into their original form, sum((A - a_zp) * (W - w_zp)), the subtraction happens only when loading input or weights from memory, which may not be the bottleneck in practice given a good schedule.

Precomputing is good, and QNNPACK’s approach is to precompute Term 2 and Term 4 and merge them into the bias before the operator runs (QNNPACK doesn’t have a compile time, but it does have a prepare stage). However, it won’t be easy to do similar work in TVM…
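A minimal Python sketch of the prepare-stage idea: Term 2 and Term 4 depend only on the weights and zero points, so they can be folded into a per-output-channel bias once, leaving only Term 1 and Term 3 for runtime. The function names and the flattened per-channel weight layout are assumptions for illustration, not QNNPACK's API.

```python
def fold_weight_terms_into_bias(weights, bias, zp_a, zp_w):
    """Prepare stage: bias'[k] = bias[k] - Term 2 + Term 4 for each output channel k.

    weights[k] is the flattened (c, r, s) window of output channel k.
    """
    return [b - zp_a * sum(wk) + len(wk) * zp_a * zp_w for wk, b in zip(weights, bias)]

def qconv_point_with_folded_bias(wk, qa, folded_bias_k, zp_w):
    """Runtime: only Term 1 and Term 3 depend on the input."""
    acc = folded_bias_k
    sum_a = 0
    for w, a in zip(wk, qa):
        acc += w * a        # Term 1
        sum_a += a          # input sum for Term 3
    return acc - zp_w * sum_a
```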

One thing I should mention (I hope my comments won’t mislead in any direction): though I have kept saying that reducing INT32 intermediate memory access is important for performance, the memory access pattern (schedule) is very significant too.

Yes, I understand that portion. Loop optimizations and memory access patterns are extremely important for performance.

Thanks everybody for the great discussion. I will put up an RFC for reading unquantized models from MxNet and TFLite to Relay in next few days.

Thanks, everyone for insightful discussions. I also want to point everyone back to the original quantization workflow RFC: https://github.com/dmlc/tvm/issues/2259

There are a few facts that I want to highlight.

There is more than one quantization scheme, with different speed-accuracy tradeoffs

“Quantization” is a generic term that covers many methods; specifically, there are choices of:

  • Different bitwidths, signed/unsigned in different layers
  • Symmetric vs asymmetric
  • Floating-point multiplication vs. integer-only shifts

Most frameworks, like TFLite, take one opinionated view of the quantized model. There is a real need to support importing these models, but we also need to keep in mind that directly adopting the implied scheme may not give the best performance.

Sometimes we might even want to use a mixed scheme within a neural network, as many of the discussions here have mentioned. Most of the trouble comes from the asymmetric scheme :slight_smile: We should always keep this in mind. There is no one perfect solution and we want to build API to make sure we can cover more.

How many quantized ops we want to introduce

As a general rule of thumb, I would recommend that we minimize the number of new operators as well as their names. For example, both quantize and dequantize can likely be converted to subtraction, multiplication, and cast. This means that we do not want to introduce them as operators, at least not at the core level. The fewer new operators we introduce, the easier it is to reuse much of the existing infrastructure. We can likely do a similar thing for relu (which becomes clip) and maxpool (which stays the same, assuming a single zero point). As a matter of fact, giving up asymmetry would let most of the quantization pipeline fall into the current core operators.
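To illustrate the point, here is a Python sketch assuming a uint8 scheme with a single scale and zero point per tensor: dequantize is a subtract and a multiply; quantize is a divide, round, add, and clip (the clip standing in for the cast to uint8); ReLU becomes a clip at the zero point; and max pooling can run directly on quantized values because the affine map is monotonic when scale > 0. The rounding mode (round half up) is an assumption, since frameworks differ on tie-breaking, and `quantized_relu` is a hypothetical helper name.

```python
import math

def dequantize(q, scale, zp):
    """Real value = (q - zp) * scale: a subtract and a multiply."""
    return (q - zp) * scale

def quantize(x, scale, zp, qmin=0, qmax=255):
    """q = clip(round(x / scale) + zp): divide, round, add, then clip to the uint8 range."""
    q = math.floor(x / scale + 0.5) + zp   # round half up (assumed tie-breaking rule)
    return max(qmin, min(qmax, q))

def quantized_relu(q, zp, qmax=255):
    """ReLU in the quantized domain: real 0.0 maps to zp, so it is a clip to [zp, qmax]."""
    return max(zp, min(qmax, q))
```

Max pooling needs no new op: taking `max` over the quantized values and then dequantizing gives the same result as dequantizing first.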

We can, however, introduce dialects that facilitate the conversion, in which case relay.contrib.q_conv2d is an OK name, as long as we try to lower as many of them as possible to the low-level operators.


It’s a good idea to make it possible in TVM to use existing ops in a symmetric scheme. I’d like to share some knowledge regarding symmetric/asymmetric quantization; I hope it helps the design decision.

  • TensorFlow/TFLite use the asymmetric scheme by default (e.g. the pre-trained quantized MobileNetV1, which is built with quantization-aware training), though symmetric is also supported.
  • PyTorch/Caffe2/QNNPACK seem to follow the asymmetric approach. (By “seem”, I mean the zero point is essential in the code, but there is no detailed documentation stating that.)
  • TensorRT adopts a symmetric design.

I think introducing a namespace like relay.contrib.quantize is a good solution. We could introduce q_conv2d / q_fully_connected and so on in this namespace.

I completely agree with this. However, supporting asymmetric quantization might be necessary to keep the accuracy in check. So, I think we need to support asymmetric quantization.

I also like this idea.


I think there are three things we are trying to balance - performance, accuracy and reusing existing compiler infrastructure (or keeping code clean).

For example, an asymmetric quantized_convolution can be a dialect, which we can rewrite using low-level ops (Option 1). As discussed above, this might lead to bad performance but good accuracy and good reuse of existing compiler infrastructure. On the other hand, Option 2 leads to near-best performance and accuracy, but at the expense of a new operator.


As @tqchen aptly said - “There is no one perfect solution and we want to build API to make sure we can cover more”

With that in mind, I think going with contrib ops and a separate namespace might be a good mid-way solution. To be more precise, we won’t have a TVM compute for this contrib op. Instead, we will rewrite these ops to be a sequence of low-level ops. We will add new low-level ops if something is not directly implementable with what we have today. This might take a performance hit but should keep the design modular and simpler.

  • If I have not forgotten how to do maths, this should satisfy asymmetric quantization requirements :slight_smile:
  • If one does want the best performance, I can envision a Relay pass to fuse the concerned ops into one op and write the schedule for that. This does not sound easy at all. But, something has to take a hit.

@tqchen, please let me know if I understood your comment correctly.
@jackwish and @FrozenGene, given that this somewhat differs from what you had in mind, please comment on whether it makes sense.

I want to say that it is not necessarily true that asymmetric quantization is needed to keep accuracy in check. In fact, if we think about it, asymmetric quantization gives you at most one bit of additional precision, and usually that is not significant; symmetric quantization can give comparable accuracy. Because of the implementation-efficiency difference, we could even use 16-bit integer input instead (which gives much better accuracy), and that could perform as well as the 8-bit input version.
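A quick back-of-the-envelope check of the "at most one bit" claim, sketched in Python under assumed conventions: asymmetric uint8 spreads 256 levels over [lo, hi], while symmetric int8 uses levels -127..127 over [-m, m] with m = max(|lo|, |hi|). The worst case is a one-sided range such as ReLU6's [0, 6], where the symmetric step is about twice as coarse (about one bit); for a range already centered on zero the two are essentially equal.

```python
import math

def asymmetric_uint8_scale(lo, hi):
    """Step size when 256 levels cover [lo, hi] (assumed to contain 0)."""
    return (hi - lo) / 255

def symmetric_int8_scale(lo, hi):
    """Step size when levels -127..127 cover [-m, m], m = max(|lo|, |hi|)."""
    return max(abs(lo), abs(hi)) / 127

def bits_lost(lo, hi):
    """Precision given up by the symmetric scheme, in bits (log2 of the step ratio)."""
    return math.log2(symmetric_int8_scale(lo, hi) / asymmetric_uint8_scale(lo, hi))
```

For [0, 6] this gives log2(255/127), just over 1 bit; for [-3, 3] it gives log2(255/254), well under 0.01 bit.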

There are also other factors, such as per-channel vs. global scaling, that affect accuracy more than asymmetry does.

So the main reason to support asymmetry is not necessarily accuracy, but compatibility.

Note: asymmetric quantization can be represented directly by low-level operators like subtract, multiply, and cast.

I understand your point. At the same time, it would be bad if we stuck to symmetric quantization only and rejected models that have been quantized using the asymmetric method.

Let’s try to ensure that we can support both if need be. We can certainly start with symmetric quantization to set up and clean up the whole flow, while ensuring that we don’t close the path for asymmetric.


I’m interested in the story around quantized networks.

Regarding Term 1 in option 1: newer Arm cores (Armv8.2-A optionally, Armv8.4-A onwards mandatorily), including Neoverse N1, Cortex-A55 and others, implement a dot product instruction from uint8 × uint8 vectors to uint32 vectors. See the references below for more. This instruction exists in the AArch32 world as well as the AArch64 world, and there are also by-element versions (a vector with a single indexed element). Would these be useful here in accelerating conv2d or other schedules?
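To make the instruction's semantics concrete, here is a pure-Python emulation of the 128-bit vector form `UDOT Vd.4S, Vn.16B, Vm.16B`: each 32-bit lane accumulates the dot product of four uint8 pairs. For Term 1, summing the four lanes gives a 16-element uint8 dot product (e.g. over 16 input channels). This only models the arithmetic; in C/C++ the ACLE exposes the instruction as the `vdotq_u32` intrinsic when `__ARM_FEATURE_DOTPROD` is defined. The function name `udot_4s_16b` is illustrative.

```python
def udot_4s_16b(acc, vn, vm):
    """Emulate UDOT Vd.4S, Vn.16B, Vm.16B.

    acc: four uint32 lanes; vn, vm: sixteen uint8 values each.
    Lane i accumulates vn[4i..4i+3] . vm[4i..4i+3], wrapping modulo 2**32.
    """
    assert len(acc) == 4 and len(vn) == 16 and len(vm) == 16
    return [(acc[i] + sum(vn[4 * i + j] * vm[4 * i + j] for j in range(4))) & 0xFFFFFFFF
            for i in range(4)]

# Term 1 over 16 input channels: one UDOT, then a horizontal add of the lanes.
activations = list(range(16))   # uint8 input values
weights = [255] * 16            # uint8 weights
lanes = udot_4s_16b([0, 0, 0, 0], activations, weights)
term1 = sum(lanes)
```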

I don’t believe LLVM’s loop vectorizer vectorizes this today; that would probably take quite a bit of work, and even then it may not be able to get the right results.

From my understanding, this is something that could be experimented with using either:
a. The inline assembler form of the instructions
b. Lowering directly to an LLVM intrinsic for the vdot instruction. Arm also publishes the ACLE (Arm C Language Extensions) for use within C and C++ applications.

References

  1. https://community.arm.com/developer/tools-software/tools/b/tools-software-ides-blog/posts/exploring-the-arm-dot-product-instructions
  2. https://developer.arm.com/docs/ddi0596/latest/simd-and-floating-point-instructions-alphabetic-order/udot-vector-dot-product-unsigned-arithmetic-vector (UDOT, vector by vector)
  3. https://developer.arm.com/docs/ddi0596/latest/simd-and-floating-point-instructions-alphabetic-order/udot-by-element-dot-product-unsigned-arithmetic-vector-by-element (UDOT, vector by element)
  4. https://developer.arm.com/docs/ddi0596/latest/simd-and-floating-point-instructions-alphabetic-order/sdot-by-element-dot-product-signed-arithmetic-vector-by-element (SDOT, signed dot product by element)
  5. https://developer.arm.com/docs/ddi0596/latest/simd-and-floating-point-instructions-alphabetic-order/sdot-vector-dot-product-signed-arithmetic-vector (SDOT, signed dot product vector by vector)
  6. https://developer.arm.com/docs/ddi0597/latest/simd-and-floating-point-instructions-alphabetic-order/vsdot-by-element-dot-product-index-form-with-signed-integers (VSDOT for AArch32, i.e. the 32-bit Arm instruction set, in both the T32 and A32 instruction sets)
  7. https://developer.arm.com/docs/ddi0597/latest/simd-and-floating-point-instructions-alphabetic-order (AArch32; find the VSDOT and VUDOT instructions)

I hope this helps.

Regards,
Ramana


Of course. The dot product instruction can accelerate q_conv2d. Currently, we have to use the SMLAL instruction. If LLVM cannot handle it, we could use tensorize, as has been done for x86.

Thanks - My aim was to bring that into the design discussion now and draw attention to this feature.

Thanks @ramana-arm This is very helpful. TVM provides a feature to directly call the LLVM intrinsic as @FrozenGene mentioned.