[RFC][Tensorcore] INT4 end-to-end inference

Introduction

NVIDIA’s Turing tensor cores have been enhanced for deep learning inference. Turing adds new INT8, INT4, and INT1 precision modes for inference workloads that can tolerate quantization and don’t require FP16 precision, whereas Volta tensor cores only support FP16/FP32.

Cutlass only supports INT4 matrix multiplication using tensor cores; no existing library fully supports INT4 conv2d or INT4 end-to-end inference. In this RFC, we add new features to Relay and TOPI to achieve this goal in TVM. We briefly describe the changes below. We will continue to post more performance numbers, and the associated PRs will be linked here to track progress. All experiments use an NVIDIA T4 GPU.

This is joint work between our capstone project and the Amazon AWS TVM team. The overall pipeline trains the parameters in PyTorch and loads them into TVM for inference. Thanks @Laurawly, @GaryYuyjl, and @yidawang for the collaboration.

Topi

  1. One new direct tensorcore conv2d schedule (PR sent: 6121) and one im2col tensorcore conv2d schedule. We found that for most workloads, direct conv2d performs better than im2col conv2d; performance numbers are shown below:
Workload format: (batch_size, in_channels, in_size, out_channels, kernel_size, stride, padding)

 #   workload                      im2col (ms)   direct (ms)
 0   (8, 64, 56, 64, 3, 1, 1)      0.21777       0.19015
 1   (8, 64, 56, 128, 3, 2, 1)     0.15          0.12979
 2   (8, 64, 56, 128, 1, 2, 0)     0.04909       0.04359
 3   (8, 128, 28, 128, 3, 1, 1)    0.14178       0.15725
 4   (8, 128, 28, 256, 3, 2, 1)    0.10795       0.0994
 5   (8, 128, 28, 256, 1, 2, 0)    0.02941       0.04659
 6   (8, 256, 14, 256, 3, 1, 1)    0.1376        0.12328
 7   (8, 256, 14, 512, 3, 2, 1)    0.12329       0.11763
 8   (8, 256, 14, 512, 1, 2, 0)    0.02827       0.04626
 9   (8, 512, 7, 512, 3, 1, 1)     0.20317       0.11436
  • Tensor cores have tight shape constraints that vary with the precision of the matrix multiplication, and direct convolution inherits these constraints on the workload’s batch, in_channel, and out_channel axes. For im2col the constraints are looser, since multiple axes can be fused into one to meet the divisibility requirement. Hence, in our implementation, most workloads that fit direct convolution’s shape constraints use the direct convolution strategy, while the others (e.g., the first convolution layer in ResNet-18/ResNet-50) use im2col (a rough sketch of this dispatch appears at the end of this section).

  • In terms of scheduling, we found that input pre-packing, vectorization, and the HWNC data layout give a performance gain. We also need to write the result to shared memory before writing it to global memory in order to fuse correctly with other operations.

  • For 4-bit inference, there is a huge impact if we can utilize UINT4 instead of INT4. To achieve this, we made changes in Relay to support UINT4 and INT4 convolution. Strangely, the CUDA wmma APIs don’t support “UINT4 x INT4” but PTX does, so we hard-coded a function in the CUDA codegen to call PTX assembly whenever a “UINT4 x INT4” product appears.

  2. Cast from INT32 to INT4
  • We need to requantize from INT32 to INT4 as the input of the convolution in each layer, since the result of the previous layer’s convolution is in full precision. We can’t cast directly to INT4 with a “(INT4)” cast (as can be done for INT8); instead, several 4-bit numbers have to be packed into an existing data type such as int8 or int32. We chose int32, storing eight 4-bit numbers per word, to align with Alibaba’s sub-byte PR. In the casting process, every 8 consecutive 32-bit numbers are packed into one int32: each value’s least significant four bits are kept, shifted to the proper position, and elementwise-ORed with the others in the same group to form the new number (a sketch follows below).
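In lieu of a diagram, here is a minimal NumPy reference of the packing step (the helper name is ours; the actual implementation is a TOPI compute, but the bit manipulation is the same):

```python
import numpy as np

def pack_int4_to_int32(x):
    """Pack every 8 consecutive 4-bit values (held in 32-bit ints) into one int32."""
    assert x.size % 8 == 0
    x = x.astype(np.uint32).reshape(-1, 8)  # uint32 avoids sign-extension issues
    out = np.zeros(x.shape[0], dtype=np.uint32)
    for i in range(8):
        # Keep the least significant nibble, shift it to slot i, then OR it in.
        out |= (x[:, i] & 0xF) << (4 * i)
    return out.view(np.int32)

vals = np.array([1, -2, 3, -4, 5, -6, 7, 0])
print(hex(pack_int4_to_int32(vals).view(np.uint32)[0]))  # -> 0x7a5c3e1
```

Referring back to the direct vs. im2col choice earlier in this section, the dispatch boils down to a divisibility check like the following sketch (the constants are our assumption based on the int4 wmma fragment shape m8n8k32; the exact conditions live in the TOPI strategy):

```python
def use_direct_conv2d(batch, in_channels, out_channels):
    # Hypothetical shape check for the direct HWNC tensorcore schedule;
    # im2col can fuse axes and is therefore more permissive.
    return batch % 8 == 0 and in_channels % 32 == 0 and out_channels % 8 == 0

# ResNet-18/50's first layer (in_channels=3) fails the check and falls back
# to im2col; the workloads in the table above all pass.
```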

Relay

  1. Support reading INT4 NumPy arrays
  • NumPy doesn’t natively support INT4 arrays, so we store 4-bit weights in the int32 NumPy data type. When the model sees an int32 NumPy array assigned to an INT4 variable, it does not trigger a data type conversion, and a data shape mismatch results. Some changes are also needed to constant-fold the 4-bit parameters.
  2. QNN wrappers and quantized ResNet
  • We also created QNN wrappers to build the neural networks; the quantized ResNet is written as an example. We also wrote ease-of-use interfaces to specify each layer’s scaling and data types (see the sketch below).
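The wrapper interface itself isn’t spelled out in this post, so purely as an illustration, a per-layer wrapper could look like the sketch below (the function name and signature are ours, not the actual API; we assume the branch’s “int4” dtype string and the HWNC/HWOI layouts used by the schedules above, and omit the per-layer scale handling for brevity):

```python
from tvm import relay

def qconv2d(data, name, in_ch, out_ch, kernel=3, stride=1, pad=1, wdtype="int4"):
    # Hypothetical ease-of-use wrapper: declare a 4-bit weight in HWOI
    # layout and emit an HWNC conv2d that accumulates into int32.
    w = relay.var(name + "_weight",
                  shape=(kernel, kernel, out_ch, in_ch), dtype=wdtype)
    return relay.nn.conv2d(data, w,
                           channels=out_ch,
                           kernel_size=(kernel, kernel),
                           strides=(stride, stride),
                           padding=(pad, pad),
                           data_layout="HWNC",
                           kernel_layout="HWOI",
                           out_dtype="int32")

x = relay.var("x", shape=(56, 56, 8, 64), dtype="int4")  # H, W, N, C
y = qconv2d(x, "layer1_0_conv1", in_ch=64, out_ch=64)
```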

End-to-end performance preview

More results will be posted here as they become available; here’s an initial overview of the performance on an NVIDIA T4 GPU.

Resnet 50    int8 tensorcore (ms)   int4 tensorcore (ms)   int8 dp4a (ms)
Batch=8      8.12                   6.7                    9.72
Batch=16     12.54                  11.7                   15.65

Summary

In this RFC we show the performance results and the changes to TVM. There is still room to improve the speed, and we are actively working on it.


CC @janimesh for his interest in the Relay support for INT4/INT8 quantized models.

I’m interested: is this result based on HWNC or NCHW?

The results are based on HWNC. HWNC gives us the best performance for conv2d tensorcore INT4/INT8. But not all operations support HWNC yet, for example pooling, so the current workaround is to transpose before the pooling layer. This brings a small overhead in end-to-end performance.
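For reference, a minimal sketch of that workaround with stock Relay ops (shapes are illustrative, and we assume the pooling here only has an NCHW schedule):

```python
from tvm import relay

# Conv2d output arrives in HWNC; pooling lacks an HWNC schedule, so we
# transpose to NCHW around it and transpose back afterwards.
x = relay.var("x", shape=(56, 56, 8, 64), dtype="int32")  # H, W, N, C
nchw = relay.transpose(x, axes=(2, 3, 0, 1))              # HWNC -> NCHW
pooled = relay.nn.max_pool2d(nchw, pool_size=(2, 2), strides=(2, 2))
hwnc = relay.transpose(pooled, axes=(2, 3, 0, 1))         # NCHW -> HWNC
```

The same permutation works in both directions because (2, 3, 0, 1) is its own inverse.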

Cutlass has open-sourced INT4 2D conv in NHWC and NC64HW64 for both Turing and Ampere. You are more than welcome to try them out.


Hello, I find this work really fun.

I really want to run the models in my environment, but there seem to be several issues with the code on the related Git page (GitHub - Zhen-Dong/HAWQ: Quantization library for PyTorch. Support low-precision and mixed-precision quantization, with hardware implementation through TVM.). I handled the minor issues that I encountered, but from instruction number 3 onwards I can’t figure out how to go any further. If possible, any kind of help or advice would be really amazing.

I followed all the instructions on the ‘instructions on hardware deployment’ readme page, and when trying to run on the TVM interface, I couldn’t get past number 3. Here are some questions I want to ask.

  1. I downloaded several models from your model zoo page and realized some of them have a ‘quantized_checkpoint.pth.tar’ file while some of them don’t. All the files have different formats inside, and I don’t know how to handle them.
  2. The Python code needs a file named “checkpoint.pth.tar”, so I just renamed quantized_checkpoint.pth.tar to that; is that correct?
  3. After step 2, the code still doesn’t run successfully. The error is ‘KeyError: weight_integer’, which means the dict_keys don’t contain that key. Which key should I change it to? I think there are 5 options (epoch, arch, state_dict, best_acc, optimizer).
  4. I don’t know if I have to get past number 3 to run everything else, but I at least wanted to measure inference time (with uniform int4), so I ran the code you suggested in instruction number 6-1, and that also doesn’t run well.

Any kind of help would be really appreciated. Thank you.

I don’t know if I have to get past number 3 to run everything else, but I at least wanted to measure inference time (with uniform int4), so I ran the code you suggested in instruction number 6-1, and that also doesn’t run well.

To measure the inference time, step 6 doesn’t require real data; it uses randomly generated data. Can you elaborate more on the problems?

I followed all the instructions on the ‘instructions on hardware deployment’ readme page, and when trying to run on the TVM interface, I couldn’t get past number 3. Here are some questions I want to ask.

For the model zoo problem, can you file an issue on the HAWQ GitHub page? We will keep track of it there.


Thank you for the reply; I will file an issue on the GitHub page about the other problems. This is the error I got from running step 6:

$ python3 test_resnet_inference_time.py --num-layers 50 --batch-size 8 --data-layout "HWNC" --model-type "int4"
Apply tuning log ./mixed_precision_models/tuning_logs/resnet50_HWNC_mixed_batch_8.log
Traceback (most recent call last):

File "test_resnet_inference_time.py", line 232, in <module> graph, lib, params = relay.build(func, target=TARGET_NAME, params=params)

File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/relay/build_module.py", line 251, in build graph_json, mod, params = bld_mod.build(mod, target, target_host, params)

File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/relay/build_module.py", line 120, in build self._build(mod, target, target_host)

File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 219, in call raise get_last_ffi_error()

KeyError: 'Traceback (most recent call last):\n
[bt] (8) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)+0x8e) [0x7fbb0e61569e]\n
[bt] (7) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)+0x91) [0x7fbb0e61a651]\n
[bt] (6) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>)#6}::_FUN(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>)+0x27) [0x7fbb0e617627]\n
[bt] (5) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::relay::MixedModeMutator::VisitExpr_(tvm::relay::CallNode const*)+0x43) [0x7fbb0e4bcd73]\n
[bt] (4) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::relay::ForwardRewriter::Rewrite_(tvm::relay::CallNode const*, tvm::RelayExpr const&)+0x745) [0x7fbb0e4c0215]\n
[bt] (3) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(void tvm::runtime::detail::unpack_call<tvm::RelayExpr, 3, tvm::RelayExpr ()(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>(tvm::RelayExpr ( const&)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)+0x210) [0x7fbb0e448bf0]\n
[bt] (2) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::RelayExpr tvm::relay::LayoutRewriter<tvm::relay::alter_op_layout::AlterTransformMemorizer>(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)+0xa45) [0x7fbb0e4466f5]\n
[bt] (1) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::relay::alter_op_layout::AlterTransformMemorizer::CallWithNewLayouts(tvm::relay::Call const&, std::vector<tvm::RelayExpr, std::allocator<tvm::RelayExpr> > const&)+0x773) [0x7fbb0e4441f3]\n
[bt] (0) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(+0x13434fb) [0x7fbb0e6ea4fb]\n File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 78, in cfun\n rv = local_pyfunc(*pyargs)\n File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/relay/op/nn/_nn.py", line 98, in alter_op_layout_conv2d\n return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)\n File "<decorator-gen-39>", line 2, in conv2d_alter_layout\n File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/target/generic_func.py", line 267, in dispatch_func\n return dispatch_dict[k](*args, **kwargs)\n File "/home/kjk2020/tvm-newHAWQ/tvm/topi/python/topi/cuda/conv2d_alter_op.py", line 39, in _alter_conv2d_layout\n relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)\n File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/relay/backend/compile_engine.py", line 212, in select_implementation\n return best_plevel_impl, outputs[best_plevel_impl]\nKeyError: None'

  1. Is the branch int4_direct_HWNC?
  2. Can you try temporarily moving ./mixed_precision_models/tuning_logs/resnet50_HWNC_mixed_batch_8.log somewhere else? I am guessing the new tuning log may cause a problem.

Thank you for reading 🙂

  1. When first trying the instructions, I used that branch. But there were some parameter errors, so I downloaded the whole repository from GitHub - Zhen-Dong/HAWQ: Quantization library for PyTorch. Support low-precision and mixed-precision quantization, with hardware implementation through TVM., which contains the TVM branch int4_direct_HWNC. For now I’m working on that version, and I see no difference from the branch int4_direct_HWNC.

in short …

first trial : git clone --branch int4_direct_HWNC github.com/apache/incubator-tvm

for now : git clone github.com/zhen-dong/hawq

  2. I removed 2 log files from those directories, including the one you specified, but it seems the error comes from the same cause. Here’s the error message:

$ python3 test_resnet_inference_time.py --num-layers 50 --batch-size 8 --data-layout "HWNC" --model-type "int4"
Traceback (most recent call last):

File "test_resnet_inference_time.py", line 232, in <module> graph, lib, params = relay.build(func, target=TARGET_NAME, params=params)

File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/relay/build_module.py", line 251, in build graph_json, mod, params = bld_mod.build(mod, target, target_host, params)

File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/relay/build_module.py", line 120, in build self._build(mod, target, target_host)

File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 219, in call raise get_last_ffi_error()

KeyError: 'Traceback (most recent call last):\n
[bt] (8) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)+0x8e) [0x7fa5cc85769e]\n
[bt] (7) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)+0x91) [0x7fa5cc85c651]\n
[bt] (6) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>)#6}::_FUN(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>)+0x27) [0x7fa5cc859627]\n
[bt] (5) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::relay::MixedModeMutator::VisitExpr_(tvm::relay::CallNode const*)+0x43) [0x7fa5cc6fed73]\n
[bt] (4) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::relay::ForwardRewriter::Rewrite_(tvm::relay::CallNode const*, tvm::RelayExpr const&)+0x745) [0x7fa5cc702215]\n
[bt] (3) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(void tvm::runtime::detail::unpack_call<tvm::RelayExpr, 3, tvm::RelayExpr ()(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>(tvm::RelayExpr ( const&)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)+0x210) [0x7fa5cc68abf0]\n
[bt] (2) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::RelayExpr tvm::relay::LayoutRewriter<tvm::relay::alter_op_layout::AlterTransformMemorizer>(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)+0xa45) [0x7fa5cc6886f5]\n
[bt] (1) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::relay::alter_op_layout::AlterTransformMemorizer::CallWithNewLayouts(tvm::relay::Call const&, std::vector<tvm::RelayExpr, std::allocator<tvm::RelayExpr> > const&)+0x773) [0x7fa5cc6861f3]\n
[bt] (0) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(+0x13434fb) [0x7fa5cc92c4fb]\n File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 78, in cfun\n rv = local_pyfunc(*pyargs)\n File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/relay/op/nn/_nn.py", line 98, in alter_op_layout_conv2d\n return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)\n File "<decorator-gen-39>", line 2, in conv2d_alter_layout\n File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/target/generic_func.py", line 267, in dispatch_func\n return dispatch_dict[k](*args, **kwargs)\n File "/home/kjk2020/tvm-newHAWQ/tvm/topi/python/topi/cuda/conv2d_alter_op.py", line 39, in _alter_conv2d_layout\n relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)\n File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/relay/backend/compile_engine.py", line 212, in select_implementation\n return best_plevel_impl, outputs[best_plevel_impl]\nKeyError: None'

first trial : git clone --branch int4_direct_HWNC github.com/apache/incubator-tvm

That branch is not in the main TVM repo, so this command would not check out the branch we used.

Did you finish the section Install zachzzc's TVM and see no errors?


I’m sorry, I made it a bit confusing.

I meant, for the first trial: git clone --branch int4_direct_HWNC http://github.com/zachzzc/incubator-tvm.git /tvm

Did you finish the section Install zachzzc's TVM and see no errors?

Yes, I did finish all of that section, and no, I didn’t see any errors there.

I also checked all the versions, such as CUDA, and upgraded them where needed, as the page says.

I tried installing again from scratch but didn’t see the errors. Did you set the path to my TVM repo? It may point to one of your other versions of TVM.

My tvm_home path is correctly linked to the HAWQ TVM.

For number 6 (you might have already realized this from the error message above), the issue is in /tvmhome/python/tvm/relay/backend/compile_engine.py.

In the “def select_implementation” function, there is a line all_impls = get_valid_implementations(...) which returns nothing, so nothing runs inside select_implementation, and finally outputs[best_plevel_impl] is flagged as the error since best_plevel_impl is None. get_valid_implementations is like that too, and there is a similar function fstrategy(...), which is an API function, so I did not trace any further.

I thought this problem was related to number 3 not functioning well, so I started working on that again.

In the file “hawq_utils_resnet50.py”, lines 483-485, my machine can’t find any key with those names. I looked into the PyTorch model and had some idea of what the code intended, but apparently the machine doesn’t know what those lines mean. The dictionary keys that ‘model’ has are just the keys of the checkpoint (epoch, arch, state_dict, best_acc1, optimizer).

There might be something wrong with my PC, so I am working on running this on one of our lab servers. Anyway, can you let me know the versions of your Python and LLVM? That might be the problem (I’m not sure, though). Also, some of the issues above have been filed on the GitHub page.

Thanks a lot for your care 🙂

Anyway, can you let me know the versions of your Python and LLVM? That might be the problem (I’m not sure, though).

Python 3.7.4 and LLVM 10.0.1.

Hello zachzzc, I’m having a hard time trying to run this implementation.

Can you print out strategy.specializations at op/relay/compile_engine.py:120 and see whether it is non-empty when running your code?

Now I’ve fixed the hard-coded “hawq_utils_resnet50.py” file to fit ResNet-18, and I get the exact same error from ‘test_resnet_inference_time.py’ that I mentioned above.

I’ve been working on this for days but couldn’t figure out what strategy.specializations is and what it should contain.

Thank you in advance,

It is not empty in my run. If I print out impl.name, it shows conv2d_hwnc_tensorcore_direct.cuda for the convs, plus other implementations like injective.cuda and pool.cuda. It should not be completely empty.

Yes, it should not be empty. And the reason it is empty is that strategy.specializations is empty.

If I print out the strategy, it says “relay.OpStrategy(0x6b23fb0)”. Also, fstrategy says “GenericFunc(0x34h8af4)”.

But strategy.specializations is totally empty. Do you have any idea what this specializations field is, and can you let me know what is in yours?

Thank you.