INT8 Quantization - Code generation for backends

For background on quantization, please read this link (INT8 quantization proposal).

This thread only focuses on implementation of quantized layers in TVM.

High-level overview

Hardware vendors are adding support for optimized INT8 operations in hardware (Intel (https://software.intel.com/en-us/articles/lower-numerical-precision-deep-learning-inference-and-training), Nvidia (https://devblogs.nvidia.com/mixed-precision-programming-cuda-8/)). To take full advantage of the hardware, we need to generate code that uses these new instructions. In addition, since time-consuming layers like convolution have high data reuse, we also have to find new schedules that utilize the hardware efficiently.

Proposal

My current proposal is to focus on Intel Skylake and resnet-18 for now and complete an end-to-end implementation. We can start with the current TVM convolution layer optimized schedules and explore how new instructions change that schedule. Similarly, we can generate the quantized implementations for other layers in resnet-18.

When the end-to-end implementation is fleshed out, we can add more backends (Nvidia, ARM).

Action Items

There will likely be many design decisions within each step, but this list is only covering the high level action items.

  1. TOPI - Generate the optimized quantized convolution schedule with optimized hardware instructions.

    1. Understand how it affects data layout within and across kernels.
    2. Intermediate outputs need higher precision (INT32) to avoid overflow. This will require adding support for mixed-precision arithmetic in TVM (a small sketch of this appears after the action items).
    3. The code generation will rely on LLVM to pattern match to INT8 operations. Intel LLVM team is currently working on that. We can also look at inline assembly if need be (https://github.com/dmlc/tvm/pull/1486).
  2. TOPI - Generate the optimized quantized schedules for fully connected, pooling, relu layers. The goal is to enable quantization on resnet-18

  3. NNVM - Modify the input graph to support quantization - like add input/output quantization layers, using the quantized models instead of precise ones.

def deploy_quantized_model(sym, quantized_params):
    # Runs the quantized model
    #
    # Inputs
    # sym - input network; NNVM modifies the network to support quantized inference
    # quantized_params - input params that will be quantized
    pass  # stub - the implementation is the subject of this proposal
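
As a concrete illustration of the mixed-precision point in item 1.2, here is a minimal sketch using the same tvm API style as the snippets later in this thread; the shapes and tensor names are made up purely for illustration:

import tvm

# Hypothetical shapes, for illustration only.
n, k = 16, 64
A = tvm.placeholder((n, k), dtype='int8', name='A')
W = tvm.placeholder((n, k), dtype='int8', name='W')
rk = tvm.reduce_axis((0, k), name='rk')
# Cast the operands up before multiplying so the reduction runs in int32
# and the accumulated sum cannot overflow the 8-bit range.
C = tvm.compute((n,),
                lambda i: tvm.sum(A[i, rk].astype('int32') * W[i, rk].astype('int32'),
                                  axis=rk),
                name='C')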

Comments/suggestions are welcome.

Some notes here:

  • Most TOPI conv functions already support mixed precision (via the out_dtype argument)
  • @ziheng did some preliminary exploration of ARM mixed-precision code
  • the existing NNVM operators are sufficient to support most quantized operators.
    • Note that there are no quantize/dequantize layers, but we can compose them from round, multiply and cast (see the sketch below)
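
For instance, a minimal numpy sketch of how quantize/dequantize can be composed from round, multiply and cast (a single symmetric scale is assumed here purely for illustration; the same composition can be written with the existing elementwise NNVM operators):

import numpy as np

def quantize(x, scale):
    # round(x / scale), clip to the int8 range, then cast
    q = np.round(x / scale)
    return np.clip(q, -128, 127).astype('int8')

def dequantize(q, scale):
    # cast back to float and rescale
    return q.astype('float32') * scale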

So the most interesting item to act on is trying to get the best performance on a target platform.

Similar observation as in your other post: it might be good to think about supporting arbitrary-precision operators at the graph level, since there is work on supporting bit-serial convolutions on the RPi, as well as quantized operators in FPGAs, accelerators, etc.

The challenge with arbitrary precision is data layout, and specifically how to pack data into standard 8bit/32bit words.

I’m happy to discuss some ideas, and present some concrete scenarios down the road as we continue our work on quantization with VTA.

@thierry That is a good suggestion. I would suggest starting with a well-understood backend for now. Intel INT8 operations like vpmaddubsw and the upcoming VNNI instructions, as well as the Nvidia DP4A instructions, will already need non-trivial support.

Focusing on Intel Skylake for now will help flesh out the end-to-end pipeline. But yes, as you said, we need to ensure that any changes we make are extensible to accelerators/FPGAs as well.

@janimesh I totally agree. I think that nailing it on a single platform is a great start. With that, if we can think about future compatibility when designing the APIs, then it will help us down the road.

Thanks!

I am working with the Intel LLVM team to support VNNI instruction code generation. A midpoint toward this goal is to support the vpmaddwd instruction.

The motivating code is:

#include <cstdint>

static const int N = 128;

int16_t A[2*N];
int16_t B[2*N];
int C[N];

// Wrapped in a function so the snippet compiles standalone.
void madd() {
    for (int i = 0; i != N; ++i)
        C[i] = A[2*i]*B[2*i] + A[2*i+1]*B[2*i+1];
}
// Each iteration translates to a vpmaddwd instruction.
// It takes two sets of 2 16-bit values - |a0|a1| and |b0|b1| - and computes |a0*b0 + a1*b1|,
// while ensuring that the computation happens in 32 bits in HW.
// Command -> clang++ exp.cpp -mavx512bw -O3 -S (trunk LLVM)

The Intel LLVM team supports code generation via IR pattern matching (https://reviews.llvm.org/D49636).

However, the IR generated by TVM+LLVM is structured quite differently, though semantically equivalent, causing the pattern matching to fail.

The relevant TVM code is

import numpy as np
import tvm

N = 128

A = tvm.placeholder((N,), name='A', dtype='int16')
B = tvm.placeholder((N,), name='B', dtype='int16')
C = tvm.compute((N//2,), lambda i: (A[2*i].astype('int32') * B[2*i].astype('int32')) +
                                   (A[2*i + 1].astype('int32') * B[2*i + 1].astype('int32')), name='C')

s = tvm.create_schedule(C.op)
oi, ii = s[C].split(s[C].op.axis[0], factor=16)
s[C].vectorize(ii)
print(tvm.lower(s, [A, B, C], simple_mode=True))

target = 'llvm -mcpu=skylake-avx512'
ctx = tvm.context(target, 0)
a = tvm.nd.array(np.ones((N,), dtype='int16'), ctx)
b = tvm.nd.array(np.ones((N,), dtype='int16'), ctx)
c = tvm.nd.array(np.zeros((N//2,), dtype='int32'), ctx)  # output is int32, matching C's dtype
func = tvm.build(s, [A, B, C], target, name='mmult')
func.save("baseline.s")
func.save("baseline.ll")

The key difference between the clang-generated IR and the TVM-generated IR is that the clang IR is more optimized, leading to vector loads + shuffle instructions, while the TVM IR loads scalars one by one. The Intel LLVM pattern matcher does not recognize the scalar-load form.

So, the question is - Should the IR be optimized to perform vector loads + shuffle in TVM?
Or should LLVM backend support pattern matching for all possible combinations?
Or more broadly, where do we draw the line?

There are two ways we can do this; @cowanmeg @vinx13 might have something to add:

  • Craft the micro-kernel in ASM and use tensorization (this has nothing to do with the codegen backend); see the sketch at the end of this post.
  • Add a pattern matcher to CodeGenX86 (like CodeGenARM) to support such patterns in the TVM code generator, when possible.

More generally speaking, most of these intrinsics have something to do with a dot operator that takes in vectors and produces their dot product. Maybe we can think about how to enhance TVM to support dot operators natively in its optimizations.
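
A rough sketch of the first option (tensorization around a hand-written micro-kernel), following the pattern of the existing tensorize tutorial; the extern function dot16_update is hypothetical and would be backed by inline asm or intrinsics built around vpmaddwd:

import tvm

def intrin_dot16(l=32):
    # Tensor intrinsic for a 32-element int16 dot product with an int32
    # accumulator; the body is delegated to a hand-written micro-kernel.
    a = tvm.placeholder((l,), dtype='int16', name='a')
    b = tvm.placeholder((l,), dtype='int16', name='b')
    k = tvm.reduce_axis((0, l), name='k')
    c = tvm.compute((1,), lambda _: tvm.sum(
        a[k].astype('int32') * b[k].astype('int32'), axis=k), name='c')

    Ab = tvm.decl_buffer(a.shape, a.dtype, name='A', offset_factor=1, strides=[1])
    Bb = tvm.decl_buffer(b.shape, b.dtype, name='B', offset_factor=1, strides=[1])
    Cb = tvm.decl_buffer(c.shape, c.dtype, name='C', offset_factor=1, strides=[1])

    def intrin_func(ins, outs):
        aa, bb = ins
        cc = outs[0]
        ib = tvm.ir_builder.create()
        # Hypothetical extern micro-kernel, e.g. inline asm using vpmaddwd plus a
        # horizontal add, updating the int32 accumulator in place.
        ib.emit(tvm.call_extern('int32', 'dot16_update',
                                cc.access_ptr('w'),
                                aa.access_ptr('r'),
                                bb.access_ptr('r')))
        return ib.get()

    return tvm.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb})

The intrinsic would then be applied to a matching 32-element reduction axis via s[C].tensorize(axis, intrin_dot16()).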

It sounds like you need vectorized loads on A and B’s reduction axis in order for Intel’s LLVM to match the pattern, which TVM doesn’t support, so you’re seeing IR that loads scalars from A and B and probably performs vectorized writes to C.

We used a custom matrix-vector multiply microkernel for low bit operators that took advantage of ARM intrinsics that also required vectorization along a reduction axis here: https://github.com/dmlc/tvm/blob/master/topi/python/topi/rasp/bitserial_conv2d.py#L160

It looks like many of the useful new intrinsics express some type of pairwise reduction within a vector. Maybe we can enhance vectorize to work on reduction axes to give LLVM a better chance of pattern matching these intrinsics?

@cowanmeg That’s exactly what is happening. Thanks for the pointer. This should be extremely useful.

Regarding a generic solution, from what I see/anticipate for Intel, this reduction operation is “Multiply-accumulate”

For Intel Skylake
Inputs - A is Int16*32, B is Int16*32
Output - C is Int32*16
Operation -
A - |  a0   |    a1   |   a2   |    a3   | .......|  a30   |    a31   |
B - |  b0   |    b1   |   b2   |    b3   | .......|  b30   |    b31   |
C - |  a0*b0+a1*b1    |   a2*b2+a3*b3    | .......| a30*b30+a31*b31   |

For Intel Ice Lake (VNNI)
Inputs - A is Int8*64, B is Int8*64
Output - C is Int32*16
Operation -
A - |  a0   |    a1   |   a2   |    a3   | .......|  a62   |    a63   |
B - |  b0   |    b1   |   b2   |    b3   | .......|  b62   |    b63   |
C - |  a0*b0+a1*b1 + a2*b2+a3*b3         | .......| a60*b60+a61*b61 + a62*b62+a63*b63 |
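
For reference, the VNNI pattern above is a 4-element inner reduction per output lane; a minimal sketch in the same tvm API style as the earlier snippet (M is just an illustrative size):

import tvm

M = 16  # number of int32 output lanes, for illustration
A = tvm.placeholder((4 * M,), dtype='int8', name='A')
B = tvm.placeholder((4 * M,), dtype='int8', name='B')
r = tvm.reduce_axis((0, 4), name='r')
# Each output lane accumulates four int8*int8 products in int32,
# matching the VNNI operation described above.
C = tvm.compute((M,), lambda i: tvm.sum(
    A[4 * i + r].astype('int32') * B[4 * i + r].astype('int32'), axis=r), name='C')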

I have limited experience tbh, but an easier way for LLVM pattern matching to work is to perform vectorized loads for A and B, and then do vector shuffles to generate the even-index vector [0, 2, 4, 6, …] and the odd-index vector [1, 3, 5, 7, …] (in the case of VNNI, there will be 4 shuffles).

Do you see a similar type of pattern on ARM as well?

We used a slightly similar pattern; since it was low-bit we were using popcount-AND instead of multiplication, but we used a pairwise add-accumulate (vpadal).

A is int8*16
B is int16*8

A          - |  a0   |    a1   |   a2   |    a3   | .......|  a14   |    a15   |
B (before) - |  b0            |   b1            | .......|  b7              |

B (after)  - |  a0+a1+b0   |   a2+a3+b1   | .......|  a14+a15+b7 |

The shuffles you mention are for accumulating the final elements of C into one int32, I assume?
We did something similar for accumulating results that required operating on half of a vector and combining; in the LLVM codegen there are some convenience functions that call LLVM’s VectorShuffle.

But this example is for manually emitting LLVM IR; I think you’re trying to reach a state where LLVM can output these types of instructions?

Thanks @cowanmeg for the pointer. Let me take a look.

Hi guys, I have a question about this topic. Is TVM going to enable quantization for a network as a whole, or simply quantize Conv/FC layers? If the whole network is quantized, what’s our plan for operators like softmax, which contain floating-point computation (the exp in softmax)? My concern is that, since TensorFlow Lite and PyTorch/Caffe2 use gemmlowp for quantized softmax, it could be a bit tricky to have similar functionality in TVM IR.

Thanks.

There is an ongoing effort to support full network quantization when possible.