TF Lite quantized conv2d operator conversion

Thanks @FrozenGene. I have two follow-up questions:

  • What's the current plan for expressing the TVM compute for q_conv2d? Do you plan to write it from scratch, or can we break it down into simpler operators using a Relay pass like simplify_inference (somewhat along the lines of gemmlowp)?

  • I am also thinking of generalizing this to other frameworks. Maybe we should look at other frameworks and come up with one q_conv2d that works for all of them. For TF Lite we don't need dequantization, but for others we might. In that case, we might need to enhance or break up q_conv2d if it is not the right abstraction.

q_conv2d's compute is very similar to conv2d. The difference is in the convolution computation itself: data * weight becomes (data - input_zero_point).astype("int32") * (weight - kernel_zero_point).astype("int32"). Then we call q_conv2d_requantize to requantize the int32 result into int8. The other place we need to care about is padding: the pad value should be input_zero_point, not 0. I think adding one new op, q_conv2d, is reasonable, because we found that q_conv2d needs optimizations in different places than conv2d.
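To make the compute above concrete, here is a minimal pure-Python sketch. It assumes NCHW data, OIHW weights, stride 1 and no dilation. The name q_conv2d_naive is hypothetical, not a TVM API, and the q_conv2d_requantize step is left out, so the function returns the int32 accumulators.

```python
def q_conv2d_naive(qa, qw, zp_a, zp_w, pad):
    """Naive quantized conv2d: NCHW data, OIHW kernel, stride 1, no dilation.

    qa: nested lists [N][C][H][W] of (u)int8 values
    qw: nested lists [K][C][R][S] of (u)int8 values
    Returns the int32 accumulators [N][K][OH][OW], before requantization.
    """
    batch, in_ch = len(qa), len(qa[0])
    height, width = len(qa[0][0]), len(qa[0][0][0])
    out_ch, kh, kw = len(qw), len(qw[0][0]), len(qw[0][0][0])

    def load(n, c, y, x):
        # Out-of-bounds taps read input_zero_point (not 0), so after the
        # subtraction below they contribute exactly 0 to the sum.
        if 0 <= y < height and 0 <= x < width:
            return qa[n][c][y][x]
        return zp_a

    out_h = height + 2 * pad - kh + 1
    out_w = width + 2 * pad - kw + 1
    out = [[[[0] * out_w for _ in range(out_h)] for _ in range(out_ch)]
           for _ in range(batch)]
    for n in range(batch):
        for k in range(out_ch):
            for y in range(out_h):
                for x in range(out_w):
                    acc = 0  # would be an int32 register in a real kernel
                    for c in range(in_ch):
                        for r in range(kh):
                            for s in range(kw):
                                a = load(n, c, y + r - pad, x + s - pad) - zp_a
                                b = qw[k][c][r][s] - zp_w
                                acc += a * b
                    out[n][k][y][x] = acc
    return out
```

Subtracting the zero points inside the innermost loop is exactly the per-element SUB that the following posts debate.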

Yes, we also consider other frameworks and even other libraries such as QNNPACK. When designing the API, I kept in mind that it should be compatible with QNNPACK; I haven't tried integrating with QNNPACK, but it should work. Coming back to dequantization: TFLite has an op named dequantize, which does the dequantization work. So if we need dequantization, we need to implement dequantize. If other frameworks need dequantization, we could combine q_conv2d + dequantize.

Do you think converting every data * weight multiply into (data - input_zero_point).astype("int32") * (weight - kernel_zero_point).astype("int32") would result in bad performance? We have to subtract from each element first, and that gets in the way of using fused multiply-accumulate. For example, it seems very unlikely that I can use the Intel VNNI instruction with this type of compute.

That being said, I was able to work through the math and break the computation down into an actual conv2d between two int8 matrices plus three other operators, which seems to me potentially faster or friendlier to map to HW. I will write up the equations tomorrow in a pretty manner so that we can look at them together. We can have more discussion then.

Because we found the optimization of q_conv2d has different places compared with conv2d

Could you shed more light on this line? Which optimizations are these: Relay or TVM schedule? If TVM schedule, are you talking about a particular target?

Makes sense about the QNNPACK and dequantization comments. Thanks for the reminder about QNNPACK; I will look into that as well.

One other question: do you care about cases where output_min/max/scale etc. are not given (this does not happen in TF Lite)? In that case, we might need to calculate the output min/max from the input/weight min/max. I don't know if we should care about this scenario, but it seems like MXNet has this configuration as well.

We previously dug into this deeply when trying to improve performance, but found we can't omit the subtraction. We use tensorize to improve performance, i.e. first LOAD 8 elements, then subtract the zero point, then compute, then requantize. QNNPACK does it this way too.

Our main target is ARM CPU. On ARM CPU, if we want q_conv2d to perform better than the spatial pack schedule, we have to tensorize. The optimization is as described before: LOAD 8 elements first, then subtract, then compute, then requantize into int8. The spatial_pack schedule's LLVM codegen does not produce this; it emits LOAD, SUB, LOAD, SUB and so on. If you want to improve performance on Intel CPU, I think you should tensorize too.

Output min / max / scale are always given in a TFLite model. If MXNet doesn't provide them, we should calculate them before passing them into q_conv2d.

However, GEMMLOWP has a way to handle input_zero_point / kernel_zero_point: we could compute the plain data * weight convolution, then apply the offsets after the GEMM. See: https://github.com/google/gemmlowp/blob/master/doc/low-precision.md#efficient-handling-of-offsets I haven't analyzed it yet, but I think once we have q_conv2d it could be done easily, i.e. by writing a new schedule / tensorize for q_conv2d on the specific target.
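The offset idea from the gemmlowp document can be illustrated on a plain integer matrix multiply. This is a sketch of the algebra only, not gemmlowp's code; the helper names are hypothetical. The raw product is computed first, and the zero-point corrections are applied afterwards using row sums of A and column sums of B.

```python
def int_matmul(a, b):
    """Plain integer matrix multiply of nested lists (M x K) @ (K x N)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def offset_matmul_after_gemm(a, b, za, zb):
    """(A - za) @ (B - zb), computed as a raw GEMM plus offset corrections."""
    depth = len(b)
    core = int_matmul(a, b)  # raw (u)int8 x (u)int8 product, no subtraction
    row_sums = [sum(row) for row in a]
    col_sums = [sum(b[k][j] for k in range(depth)) for j in range(len(b[0]))]
    return [[core[i][j] - zb * row_sums[i] - za * col_sums[j] + depth * za * zb
             for j in range(len(b[0]))] for i in range(len(a))]


def offset_matmul_direct(a, b, za, zb):
    """Reference: subtract the zero points first, then multiply."""
    return int_matmul([[x - za for x in row] for row in a],
                      [[x - zb for x in row] for row in b])
```

Both functions return the same matrix; the first keeps the inner loop a pure multiply-accumulate on the original values.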

Thanks for the detailed explanation @FrozenGene. This is really helpful. Let me lay out both options so that others can also comment.
@eqy @ziheng @tqchen please let us know if you have any suggestions on this one. Please tag others who might be interested in this.

Also, let me know if I should convert the type to RFC.

Background

Suppose there are two matrices A and W, and we want to perform a convolution on them to get the output C.

If A and W are both FP32 matrices, the calculation is

$$C(n, k, h, w) = \sum_{c,r,s} W(k, c, r, s) \times A(n, c, h + r, w + s)$$

where n indexes the batch, k the output channels, and h, w the output height and width; c (input channels), r (kernel height) and s (kernel width) are the reduction axes.

Now, a quantized representation looks something like this

$$A = \mathit{scale}_a \times (\mathit{QA} - \mathit{zp}_a)$$

where A is an FP32 tensor, QA is a (u)int8 tensor, and the FP32 scale_a and the int32 zp_a (zero point) are the quantization parameters for that tensor. W is represented the same way, with QW, scale_w and zp_w. Please refer to this link to understand more about the quantization parameters.
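A minimal sketch of the affine mapping above, assuming a uint8 range [0, 255]. Python's round() rounds halves to even; a given framework may round halves away from zero instead, so treat the rounding mode as an assumption.

```python
def quantize(x, scale, zero_point, qmin=0, qmax=255):
    """Affine quantization: q = clamp(round(x / scale) + zero_point)."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize(q, scale, zero_point):
    """Inverse mapping: x = scale * (q - zero_point)."""
    return scale * (q - zero_point)
```

Note that the real value 0.0 maps exactly to zero_point, which is why padding later has to use the zero point rather than 0.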

So, the above conv computation becomes

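$$C(n, k, h, w) = \sum_{c,r,s} \left[\mathit{scale}_w \times (\mathit{QW}(k, c, r, s) - \mathit{zp}_w)\right] \times \left[\mathit{scale}_a \times (\mathit{QA}(n, c, h + r, w + s) - \mathit{zp}_a)\right]$$

$$C(n, k, h, w) = \mathit{scale}_w \times \mathit{scale}_a \times \sum_{c,r,s} (\mathit{QW}(k, c, r, s) - \mathit{zp}_w) \times (\mathit{QA}(n, c, h + r, w + s) - \mathit{zp}_a) \tag{1}$$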

There are two ways where we can go from here. First, break it down into simpler operators and write schedules for each one of them. Second, keep this form as it is and optimize the schedule in one go.

Option 1 - Simplify the compute into 4 operations

Let's look at the terms inside the Σ over c, r, s in equation (1), and call this operator q_conv2d; it takes the quantized matrices and their quantization parameters. (We would also need the output quantization parameters here, but I am skipping that detail to stay focused.)

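$$\sum_{c,r,s} (\mathit{QW}(k, c, r, s) - \mathit{zp}_w) \times (\mathit{QA}(n, c, h + r, w + s) - \mathit{zp}_a)$$

$$
\begin{aligned}
= {} & \sum_{c,r,s} \mathit{QW}(k, c, r, s) \times \mathit{QA}(n, c, h + r, w + s) && \text{(Term 1)} \\
& - \sum_{c,r,s} \mathit{zp}_a \times \mathit{QW}(k, c, r, s) && \text{(Term 2)} \\
& - \sum_{c,r,s} \mathit{zp}_w \times \mathit{QA}(n, c, h + r, w + s) && \text{(Term 3)} \\
& + \sum_{c,r,s} \mathit{zp}_a \times \mathit{zp}_w && \text{(Term 4)}
\end{aligned}
$$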
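A quick numeric sanity check of the four-term expansion for a single output point, with the (c, r, s) window flattened into a list. The helper names are hypothetical; Term 4 is the constant (number of taps) × zp_a × zp_w.

```python
def expanded_terms(qw, qa, zp_w, zp_a):
    """Four-term form; qw, qa are flattened windows of length C*R*S."""
    term1 = sum(w * a for w, a in zip(qw, qa))  # plain integer conv
    term2 = zp_a * sum(qw)                      # weights only
    term3 = zp_w * sum(qa)                      # input window sum
    term4 = len(qw) * zp_a * zp_w               # scalar constant
    return term1 - term2 - term3 + term4


def direct(qw, qa, zp_w, zp_a):
    """Subtract-then-multiply form from equation (1)."""
    return sum((w - zp_w) * (a - zp_a) for w, a in zip(qw, qa))
```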

Now, the whole computation is broken down into 4 terms

  • Term 1 - A normal conv2d between the two quantized (u)int8 matrices. This is great because it lets us reuse the conv2d schedules. For ARM CPU, we can use the VMLAL instruction (LLVM already does that). For Intel, we can use VNNI, which performs the compute on int8 and internally upcasts to int16/int32 whenever required.
  • Term 2 - Needs a new operator. It depends only on the weights and zp_a, so constant folding gets rid of it.
  • Term 3 - Needs a new operator: a sum reduction over each window of the input quantized matrix. This might bring some overhead, but it needs far fewer operations than Term 1, since it does not depend on the output channel k.
  • Term 4 - Constant folding. Just a few multiplications of scalar values.
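The cost argument for Terms 2 and 3 can be sketched as follows, assuming stride 1 and no padding; both function names are hypothetical. Term 2 yields one constant per output channel, while Term 3 yields one value per output pixel that every output channel reuses.

```python
def term2_per_channel(qw, zp_a):
    """Term 2 for each output channel k: zp_a * sum of QW[k] over (c, r, s).

    Depends only on the weights, so it can be precomputed (constant folded).
    qw: nested lists [K][C][R][S].
    """
    return [zp_a * sum(v for plane in kern for row in plane for v in row)
            for kern in qw]


def term3_map(qa_n, zp_w, kh, kw):
    """Term 3 for one batch item: zp_w * windowed sum of the input.

    qa_n: nested lists [C][H][W]. The result has no k axis, so it is
    computed once per output pixel and shared by all output channels.
    """
    in_ch, height, width = len(qa_n), len(qa_n[0]), len(qa_n[0][0])
    # Collapse the channel axis first, then take a kh x kw box sum.
    chan_sum = [[sum(qa_n[c][y][x] for c in range(in_ch)) for x in range(width)]
                for y in range(height)]
    out_h, out_w = height - kh + 1, width - kw + 1
    return [[zp_w * sum(chan_sum[y + r][x + s] for r in range(kh) for s in range(kw))
             for x in range(out_w)] for y in range(out_h)]
```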

The way I envision this, framework parsers will map the frameworks' quantized_conv2d to Relay q_conv2d (with quantize/dequantize wherever necessary). Then a simplify_inference-style pass can break q_conv2d into simpler operators. The unknown here is how expensive the 3rd term is.

Option 2 - Keep the compute intact

In this case, equation (1) goes directly into the TVM schedule. @FrozenGene mentions that QNNPACK takes this path, which lends this option some weight. In addition, @FrozenGene has tested this internally, potentially getting higher performance than TFLite itself. See this for a little more detail. Another good thing is that we can keep the q_conv2d API close to QNNPACK, so that we can fall back to QNNPACK codegen if necessary.

$$\sum_{c,r,s} (\mathit{QW}(k, c, r, s) - \mathit{zp}_w) \times (\mathit{QA}(n, c, h + r, w + s) - \mathit{zp}_a)$$

The problem I see with Option 2, though, is two-fold:

  • It has to upcast the quantized tensor elements from int8 to int32 to subtract the zero_point first. Once they have been upcast to int32, we can't use Intel VNNI or ARM VMLAL. It might be possible to undo this in the schedule and make it look like Option 1, but that seems hacky.
  • TF Lite uses GemmLowP to implement the conv. I have not delved into the details, but it performs an im2col transformation followed by a matrix multiplication. The im2col transformation is memory-bandwidth bound and has been shown not to be the best option in many cases. So my hunch is that, end to end, we might be able to perform better than GemmLowP with our layout transformation passes.
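The widening in the first point follows from simple range arithmetic: a difference of two 8-bit values does not fit in 8 bits, and the product of two such differences does not fit in 16 bits. The helper names below are hypothetical.

```python
def diff_range(qmin, qmax, zp_min, zp_max):
    """Range of q - zp for q in [qmin, qmax] and zp in [zp_min, zp_max]."""
    return qmin - zp_max, qmax - zp_min


def fits(lo, hi, bits):
    """True if [lo, hi] fits in a signed integer of the given width."""
    return -(1 << (bits - 1)) <= lo and hi <= (1 << (bits - 1)) - 1


lo, hi = diff_range(0, 255, 0, 255)  # uint8 data, uint8 zero point
print(lo, hi)                        # differences need at least int16
print(fits(lo * hi, hi * hi, 16))    # products of differences need int32
```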

Option 3 - Best of both worlds

We can keep Option 1 as a Relay optimization pass that a user can control through the Relay BuildConfig. This way, we keep both options.


Note - I have deliberately skipped the handling of scale_w, scale_a and the output quantization parameters, since they would make the discussion harder. I will tackle them in a separate thread if necessary.


I haven't looked at the details, but I think @ajtulloch would be interested.

Also @merrymercy.

I will read this post today.

In fact, on ARM CPU, the normal approach cannot produce VMLAL, i.e. neither Option 1 nor Option 2 can. ARM's VMLAL only widens by one step (e.g. int16 x int16 -> int32); there is no int8 -> int32 VMLAL. So we have to tensorize: LOAD 8 U8 elements and convert them to INT16, SUB the zero_point, then compute; only then can we leverage VMLAL. Option 1 may suit the VNNI instruction on Intel CPU.
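The tensorized step described above can be emulated in Python to show the data widths involved. This is a behavioral sketch, not NEON code; the function name is hypothetical. On ARM, the 8 lanes would map to two VMLAL.S16 instructions (low and high halves of the vector).

```python
def widen_sub_mla(acc, a_u8, w_u8, zp_a, zp_w):
    """Emulate one micro-kernel step on 8 lanes.

    acc: 8 int32 accumulators; a_u8, w_u8: 8 uint8 values each.
    1. Widen u8 -> int16 and subtract the zero point (result fits int16).
    2. Multiply int16 x int16 and accumulate into int32 (VMLAL.S16 style).
    """
    assert len(acc) == len(a_u8) == len(w_u8) == 8
    a16 = [x - zp_a for x in a_u8]  # range [-255, 255]
    w16 = [x - zp_w for x in w_u8]
    return [s + x * y for s, x, y in zip(acc, a16, w16)]
```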

Maybe I prefer Option 2, i.e. we provide standard spatial_pack schedules on ARM CPU / Intel CPU, and on the generic target (i.e. nn.py) we provide q_conv2d's compute in the naive way. In our tests the spatial pack schedule performs well, and it is very easy to implement. But if you want to improve further, you could write a tensorized version for your target, for example to produce ARM CPU's VMLAL / Intel's VNNI.

How about we prototype, then? I remember seeing the VMLAL instruction a long time back just by using LLVM on int8 instructions (no tensorize call), but I might be wrong. For Intel, we do have to use tensorize, as LLVM is not strong enough there.

I can write Term 1 and Term 3 of Option 1 in Python and get some performance numbers. Though they will be unoptimized, hopefully we can get some more clarity. What do you say?

Consider handling a symmetric int8 kernel (fixed 0 offset) and uint8 with 0 offset as specializations; terms disappear with these. Maybe your constant folding handles that already…

Also, I believe the TFLite implementations perform a quantized downscale to uint8 feature data between layers. In TFLite these are implemented with integer multiply and downshift operations. Some target devices don't have floating point, and especially not double precision, so maybe consider supporting the integer implementation as an option.
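An integer-only requantization in the spirit described above can be sketched as follows. This is a simplified variant, not TFLite's exact code: TFLite uses a saturating doubling high multiply followed by a rounding right shift, whereas this sketch does one 64-bit-style multiply and a single round-half-away-from-zero shift. It assumes the real multiplier scale_a * scale_w / scale_out lies in (0, 1); function names are hypothetical.

```python
import math


def quantize_multiplier(real_multiplier):
    """Split a real multiplier in (0, 1) into a Q31 integer and a right shift,
    so that real_multiplier ~= q31 * 2**-31 * 2**-shift."""
    assert 0.0 < real_multiplier < 1.0
    mantissa, exponent = math.frexp(real_multiplier)  # mantissa in [0.5, 1)
    q31 = round(mantissa * (1 << 31))
    shift = -exponent
    if q31 == (1 << 31):  # rounding pushed the mantissa up to 1.0
        q31 //= 2
        shift -= 1
    return q31, shift


def requantize(acc, q31, shift, zp_out, qmin=0, qmax=255):
    """Integer-only: multiply, rounding right shift, add zero point, clamp.

    acc * q31 fits in a signed 64-bit product when acc is int32.
    """
    total_shift = 31 + shift
    prod = acc * q31
    rounding = 1 << (total_shift - 1)
    # Round half away from zero using only integer operations.
    if prod >= 0:
        scaled = (prod + rounding) >> total_shift
    else:
        scaled = -((-prod + rounding) >> total_shift)
    return max(qmin, min(qmax, scaled + zp_out))
```

No floating point is needed at inference time; quantize_multiplier runs once per layer, offline.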

@jnorwood
Both very good points. Let me answer them one by one

Look at the following equation:

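$$
\begin{aligned}
& \sum_{c,r,s} \mathit{QW}(k, c, r, s) \times \mathit{QA}(n, c, h + r, w + s) && \text{(Term 1)} \\
& - \sum_{c,r,s} \mathit{zp}_a \times \mathit{QW}(k, c, r, s) && \text{(Term 2)} \\
& - \sum_{c,r,s} \mathit{zp}_w \times \mathit{QA}(n, c, h + r, w + s) && \text{(Term 3)} \\
& + \sum_{c,r,s} \mathit{zp}_a \times \mathit{zp}_w && \text{(Term 4)}
\end{aligned}
$$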

With both offsets set to 0, we are only left with Term 1. This is normal conv2d of the original quantized matrices.

So I believe that if we design the q_conv2d API correctly, setting the zero_point to zero from outside should cover this case. I am sure that even Option 2 can easily account for this in the schedule and skip the redundant SUB. But it's good to keep in mind.
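The specializations can be sketched for a single output point (flattened window): with symmetric int8 weights (zp_w == 0), Terms 3 and 4 vanish, leaving Term 1 plus the foldable Term 2, so no runtime reduction over the input is needed. The function name is hypothetical.

```python
def qconv_point(qw, qa, zp_w, zp_a):
    """One output point via the 4-term expansion, skipping zero terms."""
    acc = sum(w * a for w, a in zip(qw, qa))  # Term 1: always needed
    if zp_a != 0:
        acc -= zp_a * sum(qw)                 # Term 2: weights only, foldable
    if zp_w != 0:
        acc -= zp_w * sum(qa)                 # Term 3: runtime input reduction
        if zp_a != 0:
            acc += len(qw) * zp_a * zp_w      # Term 4: scalar constant
    return acc
```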

Secondly, yes, I agree we need to take care of what the HW supports. We currently tackle each target separately, so we should be able to keep that differentiation. I agree that rounding and shifts require much more care at the assembly instruction level. There are lots of tricks there, and we can mess up easily.

I think you may remember @ziheng's implementation. His original implementation also made the data / weight data type int16, not int8. I also confirmed this with @merrymercy before.

For Option 1, I think there is one thing we must consider: padding. In TFLite's quantized model, the pad value should be input_zero_point, not 0. However, the current normal conv2d doesn't have a parameter to accept input_zero_point.
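A 1-D sketch of why the pad value matters (the function name is hypothetical). With zp_a = 128, the quantized value 128 represents real 0.0. Padding with 128 keeps the borders at real zero, while padding with 0 injects real values of -128 × scale_a at the borders.

```python
def conv1d_q(qa, qw, zp_a, zp_w, pad, pad_value):
    """1-D quantized convolution (stride 1) with an explicit pad value."""
    padded = [pad_value] * pad + list(qa) + [pad_value] * pad
    n_out = len(padded) - len(qw) + 1
    return [sum((padded[i + j] - zp_a) * (qw[j] - zp_w) for j in range(len(qw)))
            for i in range(n_out)]


# Input is all real zeros (q == zp_a == 128); the output should be all zeros.
print(conv1d_q([128, 128, 128], [1, 1, 1], 128, 0, 1, pad_value=128))
print(conv1d_q([128, 128, 128], [1, 1, 1], 128, 0, 1, pad_value=0))
```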

@jackwish should also be interested.

There are lots of quantized operations in TFLite; do you plan to implement each one needed as a new op in TVM, like q_conv2d?

I think it is not easy to break those TFLite ops into smaller TVM ops.

Right, because quantized computation is not the same as normal computation, e.g. q_add / q_pool and so on.

Currently, for FP32, we spend a lot of time writing schedules for important operators like conv2d and fully connected. The point is that if an operator takes a lot of execution time, it is worth focusing efforts on it.

I am using the same principle here, so we will have to spend time on q_conv2d. As far as writing schedules goes, both options might need tensorize work.

I agree that for operators like ReLU and pooling, we might want to use Option 1 as the default, and it might give good enough performance.

@FrozenGene I don't understand the int16 vs int8 comment. I was looking at this https://github.com/dmlc/tvm/blob/master/tests/python/unittest/test_codegen_arm.py#L45

In this test the matrices are int8, but it does produce the VMLAL instruction without any tensorization. Could you provide more context here?

Padding with the zero point instead of 0 is a good point to remember. Can we do explicit padding beforehand? Maybe an explicit pad operator before the conv is the right way to go? If I understand correctly, this has to be handled in both options, so it is not a differentiating point in choosing one.

Thanks @FrozenGene for pinging me. I'm very glad to see so many inspiring discussions on quantization. I have worked with @FrozenGene over the past several months on quantization in TVM, especially on maintaining precision/accuracy when quantizing FP32 to INT8 and on performance design for INT8 computation, so I'd like to share some of my opinions.

Before going further, I'd like to say that the design decision depends on what you want most from quantization:

  1. Performance: to squeeze out every drop of the device's capability.
  2. Development efficiency: to reuse/share the computing/scheduling of code/operator that TVM already have.
  3. Quantizing technology: try out some techniques which can convert/compute INT8 without significant accuracy impact.

Performance

As we all know, quantization brings two advantages: less disk/memory usage and faster inference. Since we are talking about INT8 here, the first advantage comes for free. Internally, what we want most is performance improvement; that is where the value is for us.

Initially, we observed a 1.5x - 2x performance gap between our convolution schedule and QNNPACK, which AFAIK shows the best performance (btw, QNNPACK shares an author with NNPACK).

But where does INT8 performance come from: memory bandwidth or compute? Our experience shows that memory bandwidth really matters (as we are not going to reduce the number of multiplications in the quantization context, unlike the Winograd or Strassen algorithms).

Eventually, we chose to follow the QNNPACK approach, whose computing steps are:

  1. Accumulated multiplication with zero point and bias.
  2. Requantization of the accumulated result (INT32) to INT8.

QNNPACK fuses them into one step, which reads INT8 and writes INT8.

If we break these two steps into two operators, the first reads INT8 and writes INT32 into memory, and the second reads INT32 and writes INT8. In our tests, this approach showed a significant performance drop. (To be clear, we had tensorized step 1, which may have prevented step 2 from being fused.) As soon as we merged them into one tensorized micro-kernel, we got basically the same performance as QNNPACK. The difference is whether there is INT32 intermediate memory access in the operator: if the computation is merged, the INT32 intermediate result (the accumulated result) can stay in registers.
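A back-of-the-envelope count of the output-side memory traffic, under a simplifying assumption: only the accumulator path is counted, ignoring input/weight reads and cache effects. The function name is hypothetical.

```python
def output_path_bytes(num_outputs, fused):
    """Bytes moved on the output path of a quantized conv.

    Unfused: write INT32 accumulators (4 B), read them back (4 B),
             then write INT8 results (1 B).
    Fused:   accumulators stay in registers; only INT8 results are written.
    """
    return num_outputs * (1 if fused else 4 + 4 + 1)


outputs = 56 * 56 * 64  # e.g. one NHWC feature map
print(output_path_bytes(outputs, fused=False) / output_path_bytes(outputs, fused=True))
```

Under these assumptions the unfused pipeline moves 9x more bytes on the output path, which is consistent with the bandwidth argument above.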

Regarding @janimesh's proposals, there will be at least two INT32 intermediate memory accesses (if we ignore graph fusion), which may take extra effort to optimize away.

So, according to our experience, Option 2 is more likely to reach QNNPACK-level performance. (Note that a tensorized micro-kernel looks to be a must for performance in this scenario, due to the limitations of TVM codegen.) There could be other designs with even better performance, but we have not gone that far.

Development efficiency

As @FrozenGene has mentioned, we proposed new quantization-dedicated operators such as q_conv2d (NHWC layout), which has a very different schedule from normal conv2d. We eventually chose to add these operators because we failed to get good performance from a modified conv2d.

As we need brand-new schedules for (basically) all the new operators, much work lies ahead. So I think one potential advantage of Option 1 is that it can reuse the existing schedules by simply modifying conv2d. This makes it easier to get the work done, though we may not get extremely good performance, which may be fine for most people.

Quantization technology

I think many people are interested in trying out quantization techniques without depending on TensorFlow/TFLite. This is off topic, so let's simply ignore it…

One more thing

Please always keep in mind that Requantize, which converts the INT32 accumulated result into INT8, is needed and should be part of the quantized version of conv2d.

So, to summarize, regarding @janimesh's proposals, I think Option 1 may get performance similar to TFLite, while Option 2 is more capable of enabling a powerful tensorize design.

Thanks.


Please refer here: http://infocenter.arm.com/help/index.jsp?topic=/com.arm.doc.dui0491f/BABDEAGJ.html

int32x4_t  vmlal_s16(int32x4_t a, int16x4_t b, int16x4_t c);    // VMLAL.S16 q0,d0,d0

We don't have something like this (int8 inputs accumulating directly into int32):

int32x4_t  vmlal_sxx(int32x4_t a, int8x4_t b, int8x4_t c);    // hypothetical, no such VMLAL form exists

Inserting an explicit pad is one option. However, as described in this RFC: https://github.com/dmlc/tvm/issues/2682, the pad cannot be fused, which leads to worse performance. In Option 2, we could make q_conv2d accept input_zero_point and use input_zero_point as the pad value. So with a q_conv2d API, padding is not a problem for Option 2.

Loved this detailed explanation, thanks. Let me think about this in a little more detail. I need to understand the intermediate INT32 register argument.

One other axis that complicates this whole design space, I guess, is HW support. Intel VNNI sums 4 INT8 products into each INT32 register lane. Given that this makes the compute very fast, it is unclear without experiments whether Intel machines have the same performance bottlenecks as ARM. I think it would require efforts like yours on the Intel side as well to find the tradeoffs.
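For reference, here is a behavioral sketch of one 32-bit lane of the AVX-512 VNNI VPDPBUSD instruction, based on our reading of Intel's documentation: four unsigned 8-bit × signed 8-bit products are summed and added to the int32 accumulator, wrapping rather than saturating (the VPDPBUSDS variant saturates). The function name is hypothetical. Note the mixed signedness: the data side must be u8 and the weight side s8.

```python
def vpdpbusd_lane(acc, a_u8x4, b_s8x4):
    """One int32 lane: acc + sum of four u8 x s8 products, wrapped to int32."""
    assert all(0 <= a <= 255 for a in a_u8x4)
    assert all(-128 <= b <= 127 for b in b_s8x4)
    total = acc + sum(a * b for a, b in zip(a_u8x4, b_s8x4))
    total &= 0xFFFFFFFF  # keep the low 32 bits, like the hardware
    return total - (1 << 32) if total >= (1 << 31) else total
```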