[RFC] Using arm intrinsics to implement fixed point multiplication in TVM

Introduction and motivation

Mathematically, the fixed point multiplication (FPM) can be described as:

fpm(x,m,s) = round(x*m*2^(s-31))

In this expression:

  • x is the input value, a 32-bit integer (e.g., the accumulator of a convolution or GEMM)
  • m is the fixed point multiplier, a 32-bit integer
  • s is the shift, a 32-bit integer
  • round indicates to-nearest rounding (the tie-breaking rule is discussed below)

FPM is at the heart of the requantization process in quantized neural networks (QNNs), where the 32-bit integers resulting from a convolution or GEMM need to be requantized to a narrower data type (usually int8 or uint8).

Our analysis shows that we can achieve up to a 3% improvement on Arm targets by speeding up FPM. Even though this might not seem like a lot, in a previous RFC we showed that we are now within 4% of frameworks like TFLite, so even a 3% improvement is appealing to us.

Background

In its current state, TVM implements FPM as a sequence of Relay operators. The pseudo-code is shown below:

def fixed_point_multiply(x, fixed_point_multiplier, right_shift):
    x = cast(x, int64) * fixed_point_multiplier
    total_right_shift = right_shift + 31
    pos_rounding_value = 1 << (total_right_shift - 1)
    x = x + pos_rounding_value
    x = x >> total_right_shift
    return cast(x, int32)
  • All the operators (shift, add, multiply) are Relay operators
  • The computation is mostly carried out in 64 bits, converting back to 32 bits only at the very end of the FPM operator, and closely follows the mathematical expression described above
  • TVM picks a to-nearest rounding rule and breaks ties upward (i.e., x.5 becomes x+1)
  • The Relay implementation also handles the case of a negative right shift (not shown in the pseudo-code)
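As a quick sanity check of the rounding rule (values chosen for illustration; this reads the pseudo-code above as plain Python, where cast is a no-op on Python ints): with zero right shift, a multiplier of 1 << 30 encodes 0.5 (since m * 2^-31 = 0.5), so the tie 99 * 0.5 = 49.5 must round upward:

# illustrative check of the tie-breaking rule described above
assert fixed_point_multiply(99, 1 << 30, 0) == 50   # round(49.5) -> 50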

However, architectures like Armv8-A provide interesting instructions to execute this operation directly in 32 bits. It can be shown that, on Armv8-A targets, the operation can be achieved as a combination of the sqrdmulh and srshl instructions (which operate on four 32-bit lanes of a vector register). In particular:

  • sqrdmulh(a,b) : executes ((a*b*2)+round_const) * 2^(-32), saturating, which is equivalent to round(a*b*2^(-31)). Note that round_const is used to round to nearest, breaking ties upward
  • srshl(a,n) : executes a*2^(-n), always rounding upward (this means we need to nudge the result to round to nearest).
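For concreteness, here is a scalar Python model of sqrdmulh matching the description above (a sketch of the semantics, not the actual hardware definition):

INT32_MAX = (1 << 31) - 1

def sqrdmulh_model(a, b):
    # rounding doubling multiply-high: round(a*b * 2^-31), ties upward;
    # the only overflowing case, a = b = -2^31, saturates to INT32_MAX
    res = (2 * a * b + (1 << 31)) >> 32
    return min(res, INT32_MAX)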

Design and implementation

We propose to create a TVM intrinsic qmuls written in TVM IR (TIR) that will execute a Q-multiplication followed by a right shift.

The intrinsic signature is as follows:

qmuls(x, y, q, s)

Where x and y are two Q-numbers, the Q-ness (i.e., the number of fractional bits) is passed as the third argument q, and the right shift s is passed as the last argument.
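For instance (an illustrative usage sketch, assuming the TIR operator introduced below), the requantization described earlier becomes:

# x is the int32 input, m the Q31 fixed point multiplier, s the right shift
out = qmuls(x, m, 31, s)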

There are multiple reasons to introduce this:

  • It is general enough to be reused whenever we need to multiply Q-numbers (the shift can be set to zero if we only want a Q-multiplication)
  • Each hardware vendor can provide a hardware-specific implementation of the operation
  • The intrinsic can be overloaded by different targets through tvm.target.intrin.register_intrin_rule. This is a simpler approach than overloading through the compute strategy or through tensorization.

In the sections below, we describe the main code changes of this RFC.

Relay changes

We created a new Relay operator fixed_point_multiply and registered a compute and an injective schedule for it.

  • The Relay operator has two attributes, the multiplier (m) and the right shift (s)
  • The compute is a simple loop over the array (i.e., mostly like a unary operation); a sketch is shown below
  • The injective schedule has the task of vectorizing the loop.
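A minimal sketch of what the compute could look like, assuming the qmuls TIR operator from this RFC lives in tvm.tir (names and details are illustrative, not necessarily the exact PR code):

from tvm import te, tir

def fixed_point_multiply_compute(x, multiplier, shift):
    # apply the qmuls TIR operator (described below) elementwise,
    # with the Q-ness fixed at 31 as in the requantization use case
    return te.compute(x.shape,
                      lambda *i: tir.qmuls(x(*i),
                                           tir.const(multiplier, 'int32'),
                                           tir.const(31, 'int32'),
                                           tir.const(shift, 'int32')))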

TIR changes

The main TIR changes are the following:

  • We registered a tvm.intrin.rule.default.qmuls TVM intrinsic that executes the same operations as the Relay implementation (but using TIR operators).
  • We created a TIR operator qmuls(x,y,q,s) which executes the call:

call_intrin(x.dtype, "qmuls", x, y, q, s)

Intrinsic overload

In order to overload the intrinsic for Armv8-A we need to make use of tvm.target.intrin.register_intrin_rule. However, intrinsics are overloaded by target_name, which in the case of Armv8-A is only llvm.

This means that, in order to specialize for llvm.aarch64, we had to hack into lower_intrin.cc and register a new llvm.intrin.rule.aarch64. pattern.

Given the above tweak, we could easily exploit the tvm.target.intrin.register_intrin_rule method to register a version of qmuls tailored to the Armv8-A ISA. The result is similar to the following:

def _qmuls_arm(op):
    x = op.args[0]
    multiplier = op.args[1]
    shift = op.args[2]
    # Q-multiplication, rounding to nearest (ties upward)
    sqrdmulh = tvm.tir.call_llvm_intrin(op.dtype, 'llvm.aarch64.neon.sqrdmulh',
                                        tvm.tir.const(2, 'uint32'), x, multiplier)
    # nudge the result so that the final rounding shift rounds to nearest
    fixup = (sqrdmulh & (-shift)) >> 31
    fixed_up_x = sqrdmulh + fixup
    out = tvm.tir.call_llvm_intrin(op.dtype, 'llvm.aarch64.neon.srshl',
                                   tvm.tir.const(2, 'uint32'), fixed_up_x, shift)
    return out

tvm.target.intrin.register_intrin_rule("llvm.aarch64", "qmuls", _qmuls_arm, override=True)

A few notes on the above implementation:

  • We also handle the case of a negative right shift (not shown in the code)
  • The fixup is needed to round to nearest (instead of rounding upward, as srshl does)
  • We fall back to the default implementation when the data (x) is not a vector or when Q is not 31

Final notes on performance and precision

Performance

As previously mentioned, the best performance gain from using these intrinsics seems to sit around 3%, but the improvement we measured is only around 1.5%.

Precision

There are corner cases in which the intrinsic implementation will have a +1/-1 error compared to the default TVM implementation. This is because we are rounding twice instead of once. In other words:

  • default behavior: out = round(x*y*2^-s)
  • Arm behavior: out = round(round(x*y)*2^-s)

For example, with s = 1 and an intermediate value x*y = 2.5, rounding once gives round(2.5/2) = round(1.25) = 1, while rounding twice gives round(round(2.5)/2) = round(1.5) = 2 (ties round upward).

PR

The PR for this RFC is here: https://github.com/apache/incubator-tvm/pull/5980


Thanks for the nice RFC.

Trying to understand if I missed anything. What will happen for non-ARM machines? Are we going to use fixed_point_multiply relay operator for non-ARM machines and then use injective schedule?

Hi @anijain2305,

Both Arm and non-Arm machines will use the same fixed_point_multiply Relay operator, which will have an injective schedule associated with it, calling into tvm.tir.fixed_point_multiply().

The only difference is how tvm.tir.fixed_point_multiply() is implemented. On non-Arm machines it will follow the same logic previously used, while on Arm machines (specifically AArch64 machines) it will be implemented through Arm intrinsics.

Introducing fixed point multiply in the TIR seems to be quite overkill, given that most of the operator itself can be expressed with basic integer arithmetic. Would it be easier to detect the pattern (of multiply, shift and round) and rewrite it into the fixed point multiply?

Notably, we can also directly add the 0.5 (factor) to the bias, so that we can directly use the round-down behavior of the right shift.

I wonder if we can apply better legalization in QNN to get around the issue (e.g., use int32 when possible) without having to bring the primitive to the TIR level. cc @anijain2305

Hi @tqchen, thanks a lot for your comments.

Actually, I understand the first part of your comment, but I am afraid I don’t follow the rest :slight_smile:

Just to fully understand:

  • About adding the 0.5 (factor) to the bias, what do you mean? The bias is added before the requantization (as an int32), right? Do you mean to incorporate the bias addition within the fixed_point_multiply()?
  • About the comment on legalization: do you mean trying to intercept the fixed point multiplication during the legalization pass?

A different implementation would be to have fixed_point_multiply() as a topi operator (instead of an intrinsic), and then invoke add/multiply/shift there (i.e., inside the compute). That operator could be overridden for a specific target (i.e., Arm) to use LLVM intrinsics.

What do you think?

@tqchen The problem arises because LLVM codegen is not able to use suitable instructions. A fixed point multiply at Relay level will have to upcast the input tensors to int64. ARM instructions that @giuseros shared take int32 tensors and perform the upcasting internally in the HW (please correct me if I am wrong - @giuseros). Therefore, today QNN/Relay graphs do not use the best possible ARM instructions.

At the same time, I have similar concerns about overkill. I had missed this earlier, but having a new op disallows operator fusion, leading to a 1.5% speedup instead of a 3% speedup.

I’m in favor of the intrinsic. Pattern matching of code (idiom recognition) is generally a pain.

If we go that route, we should define it in such a way that the two values to multiply are interchangeable, i.e., fpm(x, y, s) is the same as fpm(y, x, s): x and y are the values to multiply and s is the shift amount. What I mean specifically is that the original post uses different language to describe x and y (or m in that case), but it should not make any distinction between them.

Hi @anijain2305, all correct, except that the problem with fusion is more related to the fact that qnn.conv2d is lowered as an nn.conv2d followed by a requantize.

The best would be to fuse the requantization before the unpacking of the output tensor (i.e., after the main compute node), but I cannot do that, because the requantization happens later (hence going after the unpacking). I think this problem is common to most conv2d implementations in the arm_cpu path.

Hi @kparzysz, yes, pattern matching seems hard: we would need to mark the given set of operations in Relay (and use the group later).

That is why a middle-layer solution, i.e., implementing the FPM in topi rather than in TIR, might be the right approach.

Hi @giuseros

You are correct that qnn.conv2d and qnn.requantize are different operators, and both of them are lowered to a sequence of Relay operators. But here the strength of Relay comes in: Relay fuses nn.conv2d followed by a large number of elementwise ops into one operator. This can be seen by printing the Relay IR after the fusion pass.

It might be the case that your operator is also fusion-friendly and if yes, then this point is not important.

Hi @anijain2305,

Yes, they are fused together, but at the end.

nn.conv2d is usually implemented as three compute nodes: pack+core+unpack.

The requantization operator is fused after the unpack, while the best would be to fuse it after core (unpack can be hard to vectorize).

However, this is a topic for another discussion :slight_smile:

The Relay operator I wrote should be fusion-friendly, so it should not introduce any slowdown. We still need to decide how to implement the FPM though :sweat_smile:


If it’s only available as a topi operator then you couldn’t use it in a compute. This is a scalar function, an arithmetic operation of a certain kind. We have plenty of math intrinsics in TIR already, so this isn’t a precedent. You can have a topi operator for it, but it should be usable in TIR as well.

While I can certainly see the value of fixed point multiply, there are a few other alternatives (simpler than FPM), which I list below:

  • When the scale itself is a power of two, it is possible to directly turn things into a right shift, without having to invoke any multiplication. However, given that a right shift rounds down by default, we will need to add a 0.5 compensation so that it rounds to the nearest (see the sketch after this list).
  • When the scale itself is not a power of two, it is still interesting to ask whether we need to upcast to i64 at all, and how to do things in the i32 domain. For example, an alternative would be to shift both a and b first, to make sure the result does not overflow, and then do the multiplication in i32. Given that the final result will be in i8, i32 should be more than sufficient for this kind of scaling.
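A scalar sketch of the first alternative (illustrative only; the added rounding constant is the 0.5 compensation):

def requantize_pot(x, k):
    # scale = 2^-k: adding half of the divisor before shifting turns the
    # round-down behavior of >> into round-to-nearest (ties upward)
    return (x + (1 << (k - 1))) >> k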

If we only use the intrinsic in limited cases, I can see us adding the support as an intrinsic.

Let me thank you all for the interesting discussion.

The main reason I implemented FPM as an intrinsic is that I thought it similar to other operations like mul, add, sub, etc. (and I thought that different vendors might plug in their own intrinsics to implement FPM).

But @kparzysz, you are absolutely right: if I implement this as a topi operator, it won’t be usable within a compute node. I didn’t think about that.

About writing this as fpm(x,y,s), I am not sure I can do that. The aim of the fixed point multiply is to multiply an int32 number x by a floating point number expressed as int32(round(2^30*M))*2^s, where M and s are the output of [M,s] = frexp(f), f being a float32 value. In fpm(x,m,s) I expect m = round(2^30*M) and s to be the shift that comes from frexp. In other words, I expect m and s to represent a floating point number in a given fixed point representation.

@tqchen, thanks for the explanation; now everything is clear. However, we can implement the first optimization directly in Relay (if is_power_2(scale): shift, else: fpm), and I guess we might implement the second optimization directly in TIR or in TOPI.

PS I just uploaded the PR, please have a look

What you’re describing is a multiplication of Q numbers.

If we introduce a new intrinsic, we should express it in terms of relatively well-known concepts, and Q numbers are a standard concept in fixed-point arithmetic. Even your mathematical expression shows that x and m are interchangeable, so the description should not suggest any asymmetry.

Thanks for the link @kparzysz, I didn’t know about Q numbers.

I agree on expressing the intrinsic in terms of Q numbers, but I don’t follow when you say that x and m are interchangeable.

The point is that fpm(x,m,s) multiplies x (which is not a Q number) by a Q1.30 number (I think) described by m and s.

How can I exchange m with x?

Let’s say that fmpq(x,y,n) is defined as the fixed point multiplication of two Qk.n numbers x and y. Assume that k+n is 32, so that both x and y can be represented as int32 values.

Then, the fpm(x,m,s) that you want to implement is fmpq(2*x,m,31) * 2^s. As a matter of fact, the first operand in this multiplication is exactly sqrdmulh(x,m).

You can then invent a new topi operator, call it “fixed_point_multiply_and_scale” (or some better name), and implement it using the fmpq intrinsic followed by the scaling by 2^s.

Finally, you can “realize” that the original goal of multiplying an integer by a normalized floating point value is equivalent to the “fixed_point_multiply_and_scale”. This solves the original problem, and introduces a general TIR intrinsic for Q-number multiplication.

It makes sense now, thanks a lot @kparzysz !

fmpq would be my intrinsic, and it does the fixed point multiplication:

def fmpq(x, y, n):
    x = cast(x, int64) * y
    pos_rounding_value = 1 << (n - 1)
    x = x + pos_rounding_value
    x = x >> n
    return cast(x, int32)
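As an illustrative check (reading the pseudo-code as plain Python), multiplying two Q1.30 numbers, where 0.5 is represented as 1 << 29:

# 0.5 * 0.5 in Q1.30: the result is 0.25 * 2^30 = 1 << 28
assert fmpq(1 << 29, 1 << 29, 30) == 1 << 28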

I would call this from the topi operator, which I can overload for the Arm target to use Arm intrinsics.

However, I am slightly worried about performance: in the default non-Arm case, I would do two shifts (by n and by s) instead of combining everything into a single shift (by n+s), which is the total_right_shift in the original code.

What do you think?

The two shifts most likely get folded into one in the final code. LLVM will do that for sure. Also, keeping these shifts separate initially could help other targets that have instructions for fixed point multiplication.

Right, because those are all constants. Thanks @kparzysz , I like this design!

@tqchen, @anijain2305, what do you think? I would like us to agree on a direction, and then I can go ahead and apply those changes.