TF Lite quantized conv2d operator conversion

How about we prototype, then? I remember seeing a VMLAL instruction a long time back generated by LLVM alone on int8 inputs (no tensorize call). But I might be wrong. For Intel, we do have to use tensorize, as LLVM is not strong enough there.

I can write Term 1 and Term 3 of Option 1 in Python and get some performance numbers. These will be unoptimized, but hopefully they give us more clarity. What do you say?

consider handling symmetric int8 kernel (fixed 0 offset) and uint8 with 0 offset as specializations. terms disappear with these. maybe your constant folding handles that already…

also, I believe tflite implementations perform quantized downscale to uint8 feature data between layers. In tflite these are implemented by integer mpy and downshift operations. Some target devices don’t have fp, and especially not dp, so maybe consider supporting the integer implementation as an option.

@jnorwood
Both very good points. Let me answer them one by one.

Let’s look at the following equation:

Σ_{c,r,s} QW(k, c, r, s) × QA(n, c, h + r, w + s)   // Term 1
  - Σ_{c,r,s} zp_a × QW(k, c, r, s)                 // Term 2
  - Σ_{c,r,s} zp_w × QA(n, c, h + r, w + s)         // Term 3
  + Σ_{c,r,s} zp_a × zp_w                           // Term 4

With both offsets set to 0, we are only left with Term 1. This is normal conv2d of the original quantized matrices.
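To make the four terms concrete, here is a minimal NumPy sketch (not TVM code) that checks the expansion numerically. The helper name conv2d_nchw, the shapes and the zero point values are made up for illustration; stride 1 and no padding are assumed.

```python
import numpy as np

def conv2d_nchw(data, kernel):
    """Direct NCHW conv2d (stride 1, no padding) with int32 accumulation."""
    data = data.astype(np.int32)
    kernel = kernel.astype(np.int32)
    n, c, h, w = data.shape
    k, _, r, s = kernel.shape
    out_h, out_w = h - r + 1, w - s + 1
    out = np.zeros((n, k, out_h, out_w), dtype=np.int32)
    for i in range(r):
        for j in range(s):
            window = data[:, :, i:i + out_h, j:j + out_w]
            # Contract over input channels for this (r, s) tap.
            out += np.einsum("nchw,kc->nkhw", window, kernel[:, :, i, j])
    return out

rng = np.random.default_rng(0)
qa = rng.integers(0, 256, size=(1, 3, 8, 8)).astype(np.int32)  # quantized input
qw = rng.integers(0, 256, size=(4, 3, 3, 3)).astype(np.int32)  # quantized weights
zp_a, zp_w = 128, 111                                          # illustrative zero points

# Original form: subtract the zero points first, then convolve.
reference = conv2d_nchw(qa - zp_a, qw - zp_w)

# Expanded form: the four terms from the equation.
ones_a = np.ones_like(qa)
ones_w = np.ones_like(qw)
term1 = conv2d_nchw(qa, qw)
term2 = zp_a * conv2d_nchw(ones_a, qw)
term3 = zp_w * conv2d_nchw(qa, ones_w)
term4 = zp_a * zp_w * conv2d_nchw(ones_a, ones_w)
expanded = term1 - term2 - term3 + term4

print(np.array_equal(reference, expanded))
```

With zp_a = zp_w = 0, terms 2 to 4 are all zero and only term1 is left, which is the plain conv2d on the quantized tensors.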

So, I believe that if we design the q_conv2d API correctly, setting the zero_point to zero from outside should cover this case. I am sure that even Option 2 can easily account for this in the schedule and skip the redundant SUB. But it is good to keep in mind.

Secondly, yes, I agree we need to take care of what is supported in HW. We currently tackle each target separately, so we should be able to keep that differentiation. I agree that rounding and shifts require much more care at the level of assembly instructions. There are lots of tricks there, and we can easily mess up.

I think you may be remembering @ziheng’s implementation. His original implementation also made the data / weight data type int16, not int8. I also confirmed this with @merrymercy before.

For option 1, I think there is one thing we must consider: padding. In TFLite’s quantization model, the pad value should be input_zero_point, not 0. However, the current normal conv2d doesn’t have a parameter to accept input_zero_point.
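A small NumPy sketch of why the pad value matters. The helper names pad_nchw and dequantize, and the scale / zero point values, are assumptions for illustration only: padding with input_zero_point represents a real value of 0, while padding with a raw 0 represents -scale * zero_point.

```python
import numpy as np

def pad_nchw(qdata, pad, pad_value):
    """Pad the H and W axes of an NCHW tensor with a constant quantized value."""
    return np.pad(
        qdata,
        ((0, 0), (0, 0), (pad, pad), (pad, pad)),
        mode="constant",
        constant_values=pad_value,
    )

def dequantize(q, scale, zero_point):
    """real = scale * (q - zero_point)"""
    return scale * (q.astype(np.int32) - zero_point)

scale, input_zero_point = 0.5, 128          # illustrative quantization parameters
qdata = np.full((1, 1, 2, 2), 130, dtype=np.uint8)

good = pad_nchw(qdata, 1, input_zero_point)  # border represents real 0.0
bad = pad_nchw(qdata, 1, 0)                  # border represents real -64.0

border_good = dequantize(good, scale, input_zero_point)[0, 0, 0, 0]
border_bad = dequantize(bad, scale, input_zero_point)[0, 0, 0, 0]
print(border_good, border_bad)
```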

@jackwish should also be interested.

There are lots of quantized operations in TFLite. Do you plan to implement each one needed as a new op in TVM, like q_conv2d?

I think it is not easy to break those tflite ops into smaller tvm ops.

Right, because quantized computation is not the same as normal computation, e.g. q_add / q_pool and so on.

So, currently for FP32, we spend a lot of time writing schedules for important operators like conv2d and fully connected. The point is that if an operator takes a lot of execution time, it is worth focusing efforts on it.

I am using the same principle here. So, we will have to spend time on q_conv2d. As far as writing schedules goes, both options might need tensorize work.

I agree that for operators like Relu and Pool, we might want to use Option 1 as the default, and it might give good enough performance.

@FrozenGene I don’t understand the int16 and int8 comment. I was looking at this https://github.com/dmlc/tvm/blob/master/tests/python/unittest/test_codegen_arm.py#L45

In this one the matrices are int8, but it does produce a VMLAL instruction without any tensorization. Could you provide more context here?

Padding with the zero point instead of 0 is a good point to remember. Can we do explicit padding beforehand? Maybe an explicit pad operator before the conv is the right way to go? If I understand correctly, this has to be handled in both options, so it is not a differentiating point in choosing one option.

Thanks @FrozenGene for pinging me. I’m very glad to see so many inspiring discussions on quantization. I have worked with @FrozenGene over the past several months on quantization in TVM, especially on precision/accuracy maintenance when quantizing FP32/INT8 and on performance design for INT8 quantized computing. I’d like to share some of my opinions.

Before going further, I’d like to say that the design decision depends on what you want most from quantization:

  1. Performance: squeeze every drop out of the device capability.
  2. Development efficiency: reuse/share the computing/scheduling code of operators that TVM already has.
  3. Quantization technology: try out techniques which can convert/compute INT8 without significant accuracy impact.

Performance

As we all know, quantization brings two advantages: less disk/memory usage and faster inference. As we are talking about INT8 here, the first advantage is a given. Internally, performance improvement is what we value most.

Initially, we observed a 1.5x - 2x performance gap between our convolution schedule and QNNPACK, which AFAIK shows the best performance (btw, QNNPACK shares an author with NNPACK).

But where does the INT8 performance come from? Memory bandwidth or computing effort? Our experience shows that memory bandwidth really matters (as we are not going to reduce the number of multiplications in the quantization context with algorithms like Winograd or Strassen).

Eventually, we chose to follow the QNNPACK approach, whose computing steps include:

  1. Accumulated multiplication with zero point and bias.
  2. Requantization of the accumulated result (INT32) to INT8.

QNNPACK fuses them into one, which reads INT8 and writes INT8.

If we break these two steps into two operators, the first reads INT8 and writes INT32 into memory, and the second reads INT32 and writes INT8. In our test, this approach showed a significant performance drop. (I’d like to make it clear that we have tensorized step 1, which may prevent step 2 from being fused.) As soon as we merged them into one tensorize micro kernel, we got basically the same performance as QNNPACK. The difference is whether there is INT32 intermediate memory access in the operator: if the computing is merged, the INT32 intermediate result (the accumulated result) can stay in registers.
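The difference between the two variants can be sketched in NumPy, using a GEMM view of the convolution for brevity. This is only an illustration of the data flow, not the actual schedule: the function names, the tile size and the float-based requantize are all assumptions (a real kernel would requantize with integer arithmetic). In the unfused version the whole INT32 result is materialized; in the fused version only one small tile of accumulators exists at a time, standing in for registers.

```python
import numpy as np

def requantize(acc_i32, scale, out_zp):
    """Simplified float requantization of INT32 accumulators to uint8."""
    q = np.rint(acc_i32 * scale) + out_zp
    return np.clip(q, 0, 255).astype(np.uint8)

def unfused(a_u8, w_u8, zp_a, zp_w, scale, out_zp):
    # Operator 1: reads INT8, writes the full INT32 result.
    acc = (a_u8.astype(np.int32) - zp_a) @ (w_u8.astype(np.int32) - zp_w).T
    # Operator 2: reads the INT32 result back, writes INT8.
    return requantize(acc, scale, out_zp), acc.nbytes

def fused(a_u8, w_u8, zp_a, zp_w, scale, out_zp, tile=4):
    m, n = a_u8.shape[0], w_u8.shape[0]
    out = np.empty((m, n), dtype=np.uint8)
    for i in range(0, m, tile):
        for j in range(0, n, tile):
            # Accumulators for one tile only (the "register" tile).
            acc = (a_u8[i:i + tile].astype(np.int32) - zp_a) @ \
                  (w_u8[j:j + tile].astype(np.int32) - zp_w).T
            # Requantize before anything leaves the tile: INT8 in, INT8 out.
            out[i:i + tile, j:j + tile] = requantize(acc, scale, out_zp)
    return out

rng = np.random.default_rng(1)
a = rng.integers(0, 256, size=(16, 32), dtype=np.uint8)
w = rng.integers(0, 256, size=(8, 32), dtype=np.uint8)

out_two_ops, int32_bytes = unfused(a, w, 128, 120, 0.001, 100)
out_one_op = fused(a, w, 128, 120, 0.001, 100)
print(np.array_equal(out_two_ops, out_one_op), int32_bytes)
```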

Regarding @janimesh 's proposals, there will be at least two INT32 intermediate result memory accesses (if we ignore the graph fusion issue), which may take effort to optimize.

So, according to our experience, option 2 is more likely to reach QNNPACK-level performance. (Note that a tensorize micro kernel looks to be a must for performance in such a scenario, due to the limitations of TVM codegen.) There could be other designs with even better performance, but we have not gone that far.

Development efficiency

As @FrozenGene has mentioned, we proposed new quantization-dedicated operators such as q_conv2d (NHWC layout), which has a very different schedule compared to the normal conv2d. We eventually chose to add these operators because we failed to get sound performance from a modified conv2d.

As we need a brand new schedule for (basically) every new operator, much work lies ahead. So, I think one potential advantage of option 1 is that it can reuse the existing schedule by simply modifying conv2d. This makes it easier to get the work done, though we may not get extremely good performance, which may be fine for most people.

Quantization technology

I think many people are interested in trying out quantization techniques without depending on TensorFlow/TFLite. This is off topic here, so let’s simply ignore it…

One more thing

Please always keep in mind that the Requantize step, which converts the INT32 accumulated result into INT8, is needed and should be part of the quantized version of conv2d.

So, to summarize, regarding @janimesh 's proposals, I think option 1 may get performance similar to TFLite, while option 2 is more capable of enabling powerful tensorize design.

Thanks.


Please refer here: http://infocenter.arm.com/help/index.jsp?topic=/com.arm.doc.dui0491f/BABDEAGJ.html

int32x4_t  vmlal_s16(int32x4_t a, int16x4_t b, int16x4_t c);    // VMLAL.S16 q0,d0,d0

We don’t have an int8 variant that accumulates into int32, i.e. something like:

int32x4_t  vmlal_sxx(int32x4_t a, int8x4_t b, int8x4_t c);    // hypothetical; no such instruction
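To illustrate what the widening multiply-accumulate does, here is a NumPy emulation of the lane-wise semantics of vmlal_s16 (the Python function is only a model of the intrinsic, not a binding to it). It shows why int8 data has to be widened to int16 first to use this form, and what goes wrong if the multiply is done directly in int8.

```python
import numpy as np

def vmlal_s16(acc, b, c):
    """Model of VMLAL.S16: acc (int32x4) + widen(b) * widen(c), lane by lane."""
    assert acc.dtype == np.int32 and b.dtype == np.int16 and c.dtype == np.int16
    assert acc.shape == b.shape == c.shape == (4,)
    return acc + b.astype(np.int32) * c.astype(np.int32)

a8 = np.array([-128, 127, 5, -7], dtype=np.int8)
w8 = np.array([-128, 127, -3, 2], dtype=np.int8)

# int8 data is first widened to int16, then fed to the int16 form.
acc = np.zeros(4, dtype=np.int32)
acc = vmlal_s16(acc, a8.astype(np.int16), w8.astype(np.int16))

# Multiplying in int8 directly wraps around instead of widening.
wrapped = a8 * w8

print(acc.tolist(), wrapped.tolist())
```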

Inserting a pad is one option. However, as described in this RFC: https://github.com/dmlc/tvm/issues/2682 the pad cannot be fused, which leads to worse performance. In option 2, we could make q_conv2d accept input_zero_point and use input_zero_point as the pad value. So, for option 2, if we have the q_conv2d API, the pad issue is not a problem.

Loved this detailed explanation. Thanks. Let me think about this in a little more detail. I need to understand the intermediate INT32 register argument.

One other axis that complicates this whole design space I guess is HW support. Intel VNNI performs 4 INT8 reductions in an INT32 register. Given that this makes compute very fast, it is unclear without experiments whether Intel machines have the same performance bottlenecks as ARM. I think it would require efforts like yours on the Intel side as well to find the tradeoffs.

I see. That makes sense.

One other axis that complicates this whole design space I guess is HW support

Yes. It’s very important to clarify what matters most. On one hand, according to our experiments, TVM can hardly generate instructions such as VNNI or VMLAL automatically for logic like sum(a[m][k].astype("int32") * b[n][k].astype("int32")) where a and b are INT8 data. (Actually we are very curious how to get this done.) Therefore tensorize seems to be a must for performance; otherwise quantization cannot outperform FP32 by much. On the other hand, tensorize needs to be rewritten for targets with different ISAs, including x86 with different SSE/AVX extensions, armv7a, armv8a, etc.

Our internal work has different schedules for different purposes, for example a spatial pack schedule as fallback and some tensorize-based schedules for performance in different scenarios (single core, multi-core). We’d like to contribute them back to the community, though it may not happen in the near future as far as I can see. But maybe we can share some design insights.
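For reference, the reduction being discussed looks like this in NumPy, together with a regrouping into 4-wide partial dot products, which is the shape of work a "4 INT8 reductions into one INT32 lane" instruction wants. The function names are made up, and both operands are treated as signed int8 here for simplicity; this only models the arithmetic, not any actual codegen.

```python
import numpy as np

def int8_gemm_reference(a, b):
    """out[m, n] = sum_k int32(a[m, k]) * int32(b[n, k])"""
    return a.astype(np.int32) @ b.astype(np.int32).T

def int8_gemm_groups_of_4(a, b):
    """Same result, computed as 4-element dot products accumulated in int32."""
    m, k = a.shape
    n = b.shape[0]
    assert k % 4 == 0, "sketch assumes the reduction axis is a multiple of 4"
    a4 = a.astype(np.int32).reshape(m, 1, k // 4, 4)
    b4 = b.astype(np.int32).reshape(1, n, k // 4, 4)
    # One 4-wide dot product per group: four int8 products land in one int32.
    partial = (a4 * b4).sum(axis=3, dtype=np.int32)
    # Accumulate the groups along the rest of the reduction axis.
    return partial.sum(axis=2, dtype=np.int32)

rng = np.random.default_rng(2)
a = rng.integers(-128, 128, size=(6, 16), dtype=np.int8)
b = rng.integers(-128, 128, size=(5, 16), dtype=np.int8)

ref = int8_gemm_reference(a, b)
grouped = int8_gemm_groups_of_4(a, b)
print(np.array_equal(ref, grouped))
```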

consider handling symmetric int8 kernel (fixed 0 offset) and uint8 with 0 offset as specializations.

This requires that the TFLite model provided by the user uses symmetric quantization, which is expected to be generated by TF/TFLite tools. AFAIK, the official pre-trained MobileNetV1 does not. Assuming 0 offset seems a bit aggressive?

also, I believe tflite implementations perform quantized downscale to uint8 feature data between layers. In tflite these are implemented by integer mpy and downshift operations. Some target devices don’t have fp, and especially not dp, so maybe consider supporting the integer implementation as an option.

Yes, enabling integer-only devices is one of the targets of Gemmlowp. You are the expert on this :) If we take TFLite’s approach, the computation inside q_conv2d is purely integer, including the requantization which converts the accumulated INT32 result back to INT8. So, no worries about an integer implementation :)

I wouldn’t say expert, but I ported the tflite quantized inception models to a risc-v architecture in my last job.

Note that the tflite implementation uses int64 operations in the integer downscale operations, while some targets do not support that. So your integer implementation may need to allow for that.

The downscale integer operations will also need to do rounding and saturation.

Note that the tflite implementation uses int64 operations in the integer downscale operations, while some targets do not support that.

That is interesting. So were you handling it at the ISA level or in C (with the compiler’s help to emulate int64)?

The downscale only stores 8-bit data from the upper word of the int64 result. The downscale constant mpy moves data into the upper word, then comes a rounding right shift, and then a saturate and store of the uint8 feature value.

The risc-v had a form of multiply instruction that just keeps the upper word of what would have been a 64 bit result… so that provided the needed part for the downscale.
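A rough Python sketch of the integer downscale described above: multiply by a fixed-point constant, keep the upper part of the wide product, do a rounding right shift, then saturate to uint8. The function names and the Q31 representation are my assumptions, and the rounding here is plain round-half-up, which is simpler than what gemmlowp / tflite do for negative values, so results may differ from the reference by one in tie cases.

```python
import math

def quantize_multiplier(real_multiplier):
    """Split a real multiplier in (0, 1) into a Q31 integer and a right shift."""
    assert 0.0 < real_multiplier < 1.0
    mantissa, exponent = math.frexp(real_multiplier)  # real = mantissa * 2**exponent
    multiplier_q31 = round(mantissa * (1 << 31))
    if multiplier_q31 == (1 << 31):  # mantissa rounded up to 1.0
        multiplier_q31 //= 2
        exponent += 1
    return multiplier_q31, -exponent

def requantize(acc_i32, multiplier_q31, right_shift, output_zero_point):
    """INT32 accumulator -> uint8 value using integer operations only."""
    product = acc_i32 * multiplier_q31        # a 64-bit product on real hardware
    high = (product + (1 << 30)) >> 31        # rounded upper part of the product
    if right_shift > 0:
        high = (high + (1 << (right_shift - 1))) >> right_shift  # rounding shift
    return max(0, min(255, high + output_zero_point))            # saturate

multiplier_q31, right_shift = quantize_multiplier(0.0005)
print(requantize(10000, multiplier_q31, right_shift, 0))
print(requantize(10**6, multiplier_q31, right_shift, 128))
print(requantize(-10**6, multiplier_q31, right_shift, 128))
```

On a target with a "multiply and keep the upper word" instruction, the first two lines of requantize map onto that instruction, so no full int64 value has to be formed.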

The tflite implementations include an option for selecting reference code. I found it very useful for debug to modify their reference code to do printf output before and after activation, rounding and saturation, and I prefixed the output with the c,h,w index values. The c,h,w prefix can be used to sort the lines so that you can match them up with your implementation’s output order (assuming it also tags and outputs its values). I converted six of the tflite quantized models to risc-v using this.

So, this is just a suggestion, from my own experience, that you provide a similar way to dump the data while trying to match the tflite reference per layer data. Then you can do whatever other optimizations and parallel processing and tune for performance.
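Something along these lines, sketched in Python rather than the C++ printf version (the function name dump_chw, the line format and the layer tag are made up for illustration): each value is printed with a zero-padded c,h,w prefix, so two dumps produced in different traversal orders can be sorted and compared line by line.

```python
import numpy as np

def dump_chw(layer_name, tensor, layout):
    """Return one tagged line per element, prefixed with zero-padded c,h,w."""
    lines = []
    for index in np.ndindex(tensor.shape):
        pos = dict(zip(layout, index))  # e.g. layout "HWC" -> {"H": h, "W": w, "C": c}
        lines.append(
            f"{pos['C']:05d},{pos['H']:05d},{pos['W']:05d} "
            f"{layer_name} {int(tensor[index])}"
        )
    return lines

# Reference output in CHW order, candidate output in HWC order.
reference_chw = np.arange(24, dtype=np.uint8).reshape(2, 3, 4)
candidate_hwc = np.transpose(reference_chw, (1, 2, 0)).copy()

ref_lines = sorted(dump_chw("conv1", reference_chw, "CHW"))
cand_lines = sorted(dump_chw("conv1", candidate_hwc, "HWC"))

# After sorting, the two dumps line up regardless of traversal order.
mismatches = [(r, c) for r, c in zip(ref_lines, cand_lines) if r != c]
print(len(ref_lines), len(mismatches))
```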

I modified tflite operations to support dumping for debug of six models … the four inceptions and two mobilenets. This was back in March, but might be some help. This version was used on ubuntu 18.04.
tflite c++ model printf debug


@jackwish @FrozenGene Thanks for explaining the tradeoffs. I got some time to think more deeply about this.

Σ_{c,r,s} QW(k, c, r, s) × QA(n, c, h + r, w + s)   // Term 1
  - Σ_{c,r,s} zp_a × QW(k, c, r, s)                 // Term 2
  - Σ_{c,r,s} zp_w × QA(n, c, h + r, w + s)         // Term 3
  + Σ_{c,r,s} zp_a × zp_w                           // Term 4

I came up with following conclusions (After reading FBGemm and QNNPACK)

  • Partial output reuse - Intermediate INT32 value reuse - We should strive to perform all the computation on the intermediate INT32 value before going to memory. @jackwish and @FrozenGene achieved this by fusing the core computation (Term 1 to Term 4 in the above equation) with requantize, so that the INT32 value always stays in the register file and never has to be spilled to memory and brought back.
  • Input reuse - Quantized input matrix reuse - This is also very necessary. If we look at the equation, Term 3 is basically a big offset matrix that has to be applied to Term 1. Both Term 1 and Term 3 share the quantized input matrix. If we perform the calculation in different operators and no fusion, we will not be able to reuse quantized matrix A.

So, in its current form, Option 1 satisfies neither of the above requirements, potentially leaving significant performance opportunity on the table. As performance is the primary goal (at least in my case), I am also leaning towards option 2 (with a dread of writing a tensorized schedule). However, I do have a couple of unresolved thoughts that are bothering me.

  1. For option 2, we are performing the computations for Term 2 and Term 4 at runtime. I don’t know how much penalty could be avoided by pre-computing them at compile time. This is also a limitation of the current TVM infrastructure, where pre-compute/fold-constant is limited to Relay.
  2. Making a final “weak” case for option 1 here. It might be possible to keep the 4 terms as separate Relay ops, precompute Term 2 and Term 4 using Relay passes (solving the problem in point 1), and then somehow perform both horizontal and vertical fusion to fuse everything into one giant op. Then we can use compute_inline() and compute_at() to get both input and output reuse. Again, this is a “weak” case, and it does not sound easy to do. (This is somewhat in line with FBGemm, which has more granular APIs, some of which are executed at compile time.)

If we merge all 4 terms back into the original form, which is something like sum((A - a_zp) * (W - w_zp)), the subtraction happens only when we load input or weights from memory, which may not be the bottleneck in practice if there is a good schedule.

Precomputing is good, and QNNPACK’s approach is to precompute Terms 2 & 4 and merge them into the bias before the operator runs (QNNPACK doesn’t have a compile time, but does have a prepare stage). However, it won’t be easy to do similar work in TVM…
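A NumPy sketch of that idea, using a dense (GEMM) view instead of conv2d to keep it short. The function names and values are hypothetical and this is not QNNPACK’s or TVM’s actual code: Term 2 and Term 4 depend only on the weights and the zero points, so they can be folded into the bias once, leaving only Term 1 and Term 3 for runtime.

```python
import numpy as np

def fold_bias(w_q, bias, zp_a, zp_w):
    """Prepare stage: merge Term 2 and Term 4 into the per-output bias."""
    w = w_q.astype(np.int32)
    k = w.shape[1]                                 # reduction length
    term2 = zp_a * w.sum(axis=1, dtype=np.int32)   # depends on weights only
    term4 = k * zp_a * zp_w                        # a constant
    return bias.astype(np.int32) - term2 + term4

def q_dense_runtime(a_q, w_q, folded_bias, zp_w):
    """Runtime: only Term 1 and Term 3 touch the input data."""
    a = a_q.astype(np.int32)
    w = w_q.astype(np.int32)
    term1 = a @ w.T
    term3 = zp_w * a.sum(axis=1, dtype=np.int32, keepdims=True)
    return term1 - term3 + folded_bias

rng = np.random.default_rng(3)
a_q = rng.integers(0, 256, size=(4, 64), dtype=np.uint8)
w_q = rng.integers(0, 256, size=(8, 64), dtype=np.uint8)
bias = rng.integers(-1000, 1000, size=8).astype(np.int32)
zp_a, zp_w = 128, 111

# Original form: subtract the zero points on load, then reduce.
reference = (a_q.astype(np.int32) - zp_a) @ (w_q.astype(np.int32) - zp_w).T + bias

folded = fold_bias(w_q, bias, zp_a, zp_w)
result = q_dense_runtime(a_q, w_q, folded, zp_w)
print(np.array_equal(reference, result))
```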

One thing I should mention (I hope my comments won’t mislead in any direction) is that, though I have kept saying that reducing INT32 intermediate memory access is important for performance, the fact is that the memory access pattern (schedule) is very significant too.