Thanks for the detailed explanation @FrozenGene. This is really helpful. Let me lay out both options so that others can also comment.
@eqy @ziheng @tqchen please let us know if you have any suggestions on this one. Please tag others who might be interested in this.
Also, let me know if I should convert this topic to an RFC.
Background
Suppose there are two tensors A (data) and W (weights), and we want to perform a convolution on them to get output C.
If A and W are both FP32 tensors, the calculation is
C(n, k, h, w) = Σ_{c,r,s} W(k, c, r, s) × A(n, c, h + r, w + s)
where n indexes the batch, k the output channels, h and w the output height and width,
and c the input channels, r the kernel height, and s the kernel width.
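For concreteness, here is a minimal NumPy sketch of that direct convolution (my own illustration; stride 1 and no padding assumed):

```python
import numpy as np

def conv2d_fp32(A, Wt):
    """Direct FP32 conv: C(n,k,h,w) = sum over c,r,s of Wt(k,c,r,s) * A(n,c,h+r,w+s)."""
    N, C_in, H, W = A.shape        # NCHW data
    K, _, R, S = Wt.shape          # KCRS weights
    HO, WO = H - R + 1, W - S + 1  # stride 1, no padding
    C = np.zeros((N, K, HO, WO), dtype=np.float32)
    for n in range(N):
        for k in range(K):
            for h in range(HO):
                for w in range(WO):
                    # Reduce over input channels and the kernel window
                    C[n, k, h, w] = np.sum(Wt[k] * A[n, :, h:h+R, w:w+S])
    return C
```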
Now, a quantized representation looks something like this
A = scale_a × (QA - zp_a)
where A is the FP32 tensor, QA is the (u)int8 tensor, and FP32 scale_a and int32 zp_a (zero point) are the quantization parameters for that tensor. Please refer to this link to understand more about the quantization parameters.
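As a quick illustration (my own sketch, assuming asymmetric uint8 quantization), the quantize/dequantize pair looks like:

```python
import numpy as np

def quantize(A, scale_a, zp_a):
    # QA = clamp(round(A / scale_a) + zp_a) into the uint8 range
    q = np.round(A / scale_a).astype(np.int32) + zp_a
    return np.clip(q, 0, 255).astype(np.uint8)

def dequantize(QA, scale_a, zp_a):
    # Recovers A approximately: A ≈ scale_a * (QA - zp_a)
    return scale_a * (QA.astype(np.int32) - zp_a)
```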
So, the above conv computation becomes
C(n, k, h, w) = Σ_{c,r,s} [scale_w × (QW(k, c, r, s) - zp_w)] × [scale_a × (QA(n, c, h + r, w + s) - zp_a)]
C(n, k, h, w) = scale_w × scale_a × Σ_{c,r,s} (QW(k, c, r, s) - zp_w) × (QA(n, c, h + r, w + s) - zp_a)    (1)
There are two ways to go from here. First, break the computation down into simpler operators and write schedules for each of them. Second, keep this form as it is and optimize the schedule in one go.
Option 1 - Simplify the compute into 4 operations
Let's look at the terms inside the Σ_{c,r,s} in equation (1), and let's call this operator q_conv2d. It takes the quantized tensors and their quantization parameters. (We would also need output quantization parameters here, but I am skipping that detail to stay focused.)
Σ_{c,r,s} (QW(k, c, r, s) - zp_w) × (QA(n, c, h + r, w + s) - zp_a)
=   Σ_{c,r,s} QW(k, c, r, s) × QA(n, c, h + r, w + s)      // Term 1
  - Σ_{c,r,s} zp_a × QW(k, c, r, s)                        // Term 2
  - Σ_{c,r,s} zp_w × QA(n, c, h + r, w + s)                // Term 3
  + Σ_{c,r,s} zp_a × zp_w                                  // Term 4
Now, the whole computation is broken down into 4 terms (a numerical check of this decomposition follows the list):
- Term 1 - A normal conv2d between the two quantized (u)int8 tensors. This is awesome because it allows us to reuse the existing conv2d schedules. For ARM CPU, we can use the VMLAL instruction (LLVM already emits it). For Intel, we can use VNNI, which performs the multiply-accumulate on int8 inputs and internally upcasts to int16/int32 as required.
- Term 2 - Needs a new operator: a sum reduction over the weight tensor, scaled by zp_a. Since the weights are constant, constant folding gets rid of this at compile time.
- Term 3 - Needs a new operator: a sum reduction over the quantized input tensor, scaled by zp_w. This brings some runtime overhead, but the operation count is much lower than Term 1's.
- Term 4 - Handled by constant folding: just a few multiplications of scalar values.
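To make the algebra concrete, here is a small NumPy check (my own sketch; a single output position, stride 1) that the four terms add back up to the direct zero-point computation:

```python
import numpy as np

# One receptive-field patch per output pixel: QW is (k, c, r, s),
# QA is the matching (c, r, s) input window.
rng = np.random.default_rng(0)
QW = rng.integers(0, 255, size=(8, 3, 3, 3)).astype(np.int32)
QA = rng.integers(0, 255, size=(3, 3, 3)).astype(np.int32)
zp_w, zp_a = 3, 7

# Direct form: sum over c,r,s of (QW - zp_w) * (QA - zp_a), per output channel k
direct = ((QW - zp_w) * (QA - zp_a)).sum(axis=(1, 2, 3))

term1 = (QW * QA).sum(axis=(1, 2, 3))
term2 = zp_a * QW.sum(axis=(1, 2, 3))
term3 = zp_w * QA.sum()          # independent of k
term4 = zp_a * zp_w * QA.size    # zp_a * zp_w summed over c*r*s elements

assert np.array_equal(direct, term1 - term2 - term3 + term4)
```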
The way I envision this is that framework parsers will map each framework's quantized_conv2d to Relay's q_conv2d (with quantize/dequantize wherever necessary). Then, a simplify_inference-style pass can break the q_conv2d into the simpler operators above. The unknown here is how expensive Term 3 turns out to be.
Option 2 - Keep the compute intact
In this case, equation (1) goes directly into the TVM schedule. @FrozenGene mentions that QNNPACK takes this path, which lends this option some weight. In addition, @FrozenGene has tested this internally, potentially achieving higher performance than TFLite itself. See this for a little more detail. Another good thing is that we can keep the q_conv2d API close to QNNPACK's, so that we can fall back to QNNPACK codegen if necessary.
Σ_{c,r,s} (QW(k, c, r, s) - zp_w) × (QA(n, c, h + r, w + s) - zp_a)
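For illustration, here is a sketch of this fused compute in TVM's tensor expression language (my own sketch; NCHW layout, stride 1, no padding, and fixed shapes assumed):

```python
import tvm
from tvm import te

N, CI, H, W = 1, 32, 56, 56   # assumed shapes, for illustration only
CO, R, S = 32, 3, 3
zp_a, zp_w = 7, 3

QA = te.placeholder((N, CI, H, W), dtype="uint8", name="QA")
QW = te.placeholder((CO, CI, R, S), dtype="uint8", name="QW")
c = te.reduce_axis((0, CI), name="c")
r = te.reduce_axis((0, R), name="r")
s = te.reduce_axis((0, S), name="s")

# Note the int32 upcast before the zero-point subtraction -- this is
# exactly the step that gets in the way of VNNI/VMLAL (see below).
C = te.compute(
    (N, CO, H - R + 1, W - S + 1),
    lambda n, k, h, w: te.sum(
        (QW[k, c, r, s].astype("int32") - zp_w)
        * (QA[n, c, h + r, w + s].astype("int32") - zp_a),
        axis=[c, r, s],
    ),
    name="q_conv2d",
)
```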
The problem that I see with Option 2 is two-fold:
- It has to upcast the quantized tensor elements from int8 to int32 to perform the zero-point subtraction first. Once the data has been upcast to int32, we can't use Intel VNNI or ARM VMLAL. It might be possible to somehow undo this in the schedule and make it look like Option 1, but that seems hacky.
- TFLite uses gemmlowp for implementing the conv. I have not delved into the details, but it performs an im2col transformation followed by a matrix multiplication. The im2col transformation, being memory-bandwidth bound, has been shown to be suboptimal in many cases. So my hunch is that, end-to-end, we might be able to perform better than gemmlowp with our layout transformation passes.
Option 3 - Best of both worlds
We can implement the Option 1 lowering as a Relay optimization pass that a user can control through the Relay BuildConfig. In this manner, we can keep both options.
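A hypothetical sketch of what that could look like ("QnnCanonicalize" is an invented pass name for the q_conv2d -> conv2d + reductions rewrite, not an existing pass; `mod` and `params` are assumed to come from a frontend import):

```python
from tvm import relay

# Disabling the hypothetical Option-1 rewrite pass would leave q_conv2d
# in its fused Option-2 form for the schedule to handle directly.
with relay.build_config(opt_level=3, disabled_pass=["QnnCanonicalize"]):
    graph, lib, params = relay.build(mod, target="llvm", params=params)
```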
Note - I have deliberately skipped the handling of scale_w, scale_a, and the output quantization parameters, as it would make the discussion harder to follow. I will tackle them in a separate thread if necessary.