Introduction and motivation
Mathematically, the fixed point multiplication (FPM) can be described as:
fpm(x,m,s) = round(x*m*2^(s-31))
In this expression:
- `x` is the quantized value to multiply,
- `m` and `s` are an integer multiplier and a shift, and
- `round` can be any of the rounding rules described here.
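As a concrete reference, the expression above can be modelled in plain Python (a sketch following the formula; it assumes the net shift `31 - s` is positive, which is the common requantization case):

```python
def fpm(x, m, s):
    """Reference model of round(x * m * 2^(s - 31)).

    Rounds to nearest, breaking ties upward. Assumes the net shift
    31 - s is positive (the common requantization case).
    """
    total_right_shift = 31 - s
    round_const = 1 << (total_right_shift - 1)  # adds 0.5 ULP before truncating
    return (x * m + round_const) >> total_right_shift

# Example: with m = 2^30 and s = 0, fpm multiplies by 0.5:
# fpm(5, 1 << 30, 0) -> round(2.5) = 3 (ties broken upward)
```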
FPM is at the heart of the requantization process in quantized neural networks (QNNs), where the 32-bit integers resulting from a convolution or GEMM need to be requantized to a narrower data type (usually int8 or uint8).
Our analysis shows that we can achieve up to a 3% improvement on Arm targets by speeding up FPM. Even though it might not seem like a lot, in a previous RFC we showed that we are now 4% away from frameworks like TFLite, so even a tiny 3% improvement is appealing to us.
In its current state, TVM implements FPM as a sequence of Relay operators. The pseudo-code is shown below:

```
def fixed_point_multiply(x, fixed_point_multiplier, right_shift):
    x = cast(x, int64) * fixed_point_multiplier
    total_right_shift = right_shift + 31
    pos_rounding_value = 1 << (total_right_shift - 1)
    x = x + pos_rounding_value
    x = x >> total_right_shift
    return cast(x, int32)
```
- All the operators (shift, sum, multiplication) are Relay operators.
- The computation is mostly carried out in 64 bits, converting to 32 bits only at the very end of the FPM operator, and closely follows the mathematical expression described above.
- TVM picks a to-nearest rounding rule and breaks ties upward (i.e., halfway values are rounded toward positive infinity).
- The Relay implementation also handles the case of a negative right shift (not shown in the pseudo-code).
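For completeness, the negative-right-shift case can be folded into the pseudo-code above along these lines (a sketch; the placement of the left shift is illustrative, not the exact Relay lowering):

```python
def fixed_point_multiply_full(x, fixed_point_multiplier, right_shift):
    # 64-bit product, as in the pseudo-code above
    x = x * fixed_point_multiplier
    if right_shift < 0:
        # A negative right shift is a left shift; apply it up front
        # (illustrative handling, not the actual Relay lowering)
        x = x << (-right_shift)
        right_shift = 0
    total_right_shift = right_shift + 31
    pos_rounding_value = 1 << (total_right_shift - 1)
    return (x + pos_rounding_value) >> total_right_shift
```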
However, architectures like Armv8-A provide instructions to execute this operation directly in 32 bits. In particular, it can be shown that FPM can be achieved (on Armv8-A targets) as a combination of the `sqrdmulh` and `srshl` instructions (which indeed operate on 32-bit quads):
- `sqrdmulh(a, b)` computes `((a*b*2) + round_const) * 2^(-31)`, where the `round_const` is used to round to nearest, breaking ties upward.
- `srshl(a, n)` computes `a*2^(-n)`, rounding always upward (this means we need to nudge the result to round to-nearest).
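To make the instruction semantics concrete, the two operations can be modelled on scalars as follows (a sketch based on the descriptions above; the real instructions operate on 32-bit lanes and saturate):

```python
INT32_MAX = 2**31 - 1
INT32_MIN = -(2**31)

def sqrdmulh(a, b):
    # Scalar model of SQRDMULH: ((a*b*2) + round_const) >> 31 with
    # round_const = 2^30, saturated to int32 (rounding doubling
    # multiply returning the high half).
    res = (2 * a * b + (1 << 30)) >> 31
    return max(min(res, INT32_MAX), INT32_MIN)

def rounding_right_shift(a, n):
    # Scalar model of SRSHL used with a negative shift operand, i.e. a
    # rounding right shift by n bits: a * 2^(-n), rounded.
    return (a + (1 << (n - 1))) >> n if n > 0 else a
```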
Design and implementation
We propose to create a TVM intrinsic `qmuls`, written in TVM IR (TIR), that will execute a Q-multiplication followed by a right shift.
The intrinsic signature is as follows:
qmuls(x, y, q, s)
Here `x` and `y` are two Q-numbers, and the number of fractional bits `q` is passed as the third argument. The right shift `s` is passed as the last argument to the intrinsic.
There are multiple reasons to introduce this:
- It is general enough, so it can be reused whenever we need to multiply Q-numbers (shift can be set to zero if we want to achieve only a Q-multiplication)
- Each hardware vendor can provide a hardware-specific implementation of the operation
- The intrinsic can be overloaded by different targets using `tvm.target.intrin.register_intrin_rule`. This is a simpler approach than overloading through compute strategy or tensorization.
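As a scalar reference for the intended semantics (a sketch; the actual intrinsic operates on TIR expressions, and rounding follows the to-nearest, ties-upward rule described earlier):

```python
def qmuls_ref(x, y, q, s):
    # Reference model of qmuls(x, y, q, s): multiply two Q-numbers with
    # q fractional bits, then apply a rounding right shift by s.
    prod = (x * y + (1 << (q - 1))) >> q      # Q-multiplication, rounded
    if s > 0:
        prod = (prod + (1 << (s - 1))) >> s   # rounding right shift
    return prod

# With s = 0 this is a pure Q-multiplication, e.g. 0.5 * 0.5 in Q31:
# qmuls_ref(1 << 30, 1 << 30, 31, 0) -> 1 << 29  (i.e. 0.25 in Q31)
```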
In the sections below, we describe the main code changes of this RFC.
We created a new Relay operator `fixed_point_multiplication` and registered a compute and an `injective_schedule` for it.
- The Relay operator has two attributes: the multiplier (`m`) and the right shift (`s`).
- The compute is a simple loop over the array (i.e., mostly like a unary operation)
- The injective schedule is responsible for vectorizing the loop.
The main TIR changes are the following:
- We registered a `tvm.intrin.rule.default.qmuls` TVM intrinsic that executes the same operations as the Relay implementation (but using TIR operators).
- We created a TIR operator `qmuls(x, y, q, s)` which executes the call:
call_intrin(x.dtype, "qmuls", x, y, q, s)
In order to overload the intrinsic for Armv8-A we need to make use of `tvm.target.intrin.register_intrin_rule`. However, the intrinsics are overloaded by `target_name`, which in the case of Armv8-A is only `llvm`. This means that, in order to specialize for `llvm.aarch64`, we had to hack into lower_intrin.cc and register a new `llvm.aarch64` lowering rule.
Given the above tweak, we could easily exploit the `tvm.target.intrin.register_intrin_rule` method to register a version of `qmuls` tailored for the Armv8-A ISA. The result is similar to the following:
```python
def _qmuls_arm(op):
    x = op.args[0]
    multiplier = op.args[1]
    # op.args[2] is Q, assumed to be 31 here (see the notes below)
    shift = op.args[3]
    sqrdmulh = tvm.tir.call_llvm_intrin(
        op.dtype, 'llvm.aarch64.neon.sqrdmulh',
        tvm.tir.const(2, 'uint32'), x, multiplier)
    # Nudge the result so that the final rounding is to-nearest
    fixup = (sqrdmulh & (-shift)) >> 31
    fixed_up_x = (sqrdmulh + fixup)
    out = tvm.tir.call_llvm_intrin(
        op.dtype, 'llvm.aarch64.neon.srshl',
        tvm.tir.const(2, 'uint32'), fixed_up_x, shift)
    return out

tvm.target.intrin.register_intrin_rule("llvm.aarch64", "qmuls",
                                       _qmuls_arm, override=True)
```
A few notes on the above implementation:
- Please note that we also handle the case of a negative right shift (not shown in the code).
- The fixup is needed to round to nearest (instead of rounding upward, as `srshl` would otherwise do).
- We use the default implementation when the data (`x`) is not a vector or when Q is not 31.
Final notes on performance and precision
As previously mentioned, the best performance gain from using these intrinsics appears to be around 3%, but the improvement we measured is only around 1.5%:
- The 3% improvement is for a case in which the requantization operation is fused within the main computation loop (e.g., GEMM or spatial convolution).
- In TVM, a quantized convolution is lowered as a sequence of a qnn.conv2d followed by a requantize operator. This makes it impossible to fuse requantization into the compute, which explains why we cannot fully achieve the 3% improvement.
There are corner cases in which the intrinsic implementation will have a +1/-1 error compared to the default TVM implementation. This is because we are rounding twice instead of once. In other words:
- default behavior: `out = round(x*y*2^-s)`
- arm behavior: `out = round(round(x*y)*2^-s)`
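The double-rounding discrepancy is easy to reproduce with plain floating-point arithmetic (an illustrative toy example, not the exact integer sequence used by the intrinsics):

```python
import math

def round_half_up(v):
    # Round to nearest, ties toward positive infinity
    return math.floor(v + 0.5)

v = 0.5  # stand-in for an intermediate product x*y
single = round_half_up(v / 2)                 # round(0.25) = 0
double = round_half_up(round_half_up(v) / 2)  # round(round(0.5)/2) = 1
# single and double differ by 1, matching the +1/-1 corner cases above
```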
The PR for this RFC is here: https://github.com/apache/incubator-tvm/pull/5980