Float16 for CUDA - Performance


#1

Currently, float16 support for CUDA is incomplete, both functionally and performance-wise. There are a few posts that suggest ways to deal with the functional aspect, but they have not been merged yet. This post deals with the second aspect: performance.

I was reading this paper: https://www.comp.nus.edu.sg/~wongwf/papers/hpec17.pdf

It compares the half2 and half data types. half2 is essentially float16x2: two float16 values packed into one 32-bit register. It seems that FP16 on CUDA only yields a speedup when we use the half2 data type, which lets the hardware perform two float16 operations simultaneously.
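To make the half vs half2 distinction concrete, here is a CPU-side emulation in Python using only the standard `struct` module, whose `e` format code is IEEE 754 binary16. It packs two float16 values into one 32-bit word the way a CUDA half2 holds them (low lane first, i.e. the `.x` component), and adds lane-wise the way `__hadd2` would. It only models the data layout and rounding and says nothing about speed; on the GPU the gain comes from one instruction handling both lanes. The helper names are my own, not from CUDA or TVM.

```python
import struct
from typing import Tuple


def pack_half2(lo: float, hi: float) -> int:
    """Pack two values as binary16 into one 32-bit word (low lane first),
    mirroring how a CUDA half2 stores two halves in a single 32-bit register."""
    raw = struct.pack("<ee", lo, hi)  # each 'e' is a 2-byte half, rounded to nearest even
    return struct.unpack("<I", raw)[0]


def unpack_half2(word: int) -> Tuple[float, float]:
    """Split a 32-bit word back into its two float16 lanes."""
    return struct.unpack("<ee", struct.pack("<I", word))


def hadd2(a: int, b: int) -> int:
    """Lane-wise add of two packed words, emulating what __hadd2 computes.
    The sum of two halves is exact in a Python float, so repacking rounds once."""
    a_lo, a_hi = unpack_half2(a)
    b_lo, b_hi = unpack_half2(b)
    return pack_half2(a_lo + b_lo, a_hi + b_hi)


x = pack_half2(1.5, -2.25)
y = pack_half2(0.25, 4.0)
print(unpack_half2(hadd2(x, y)))  # (1.75, 1.75)
```

Note that the rounding is per lane: for example 2048 + 1 rounds back to 2048 in float16, independently of the other lane.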

Has anybody prototyped this before, or does anyone have an idea how to make this happen?

@vinx13 @ibeltagy @hhhh @xyzhou @ydy @tqchen @comaniac


#2

Generally there are two things:

  • Override codegen for the half and half2 types for arithmetic operators (*, +, …)
  • Support the vectorized type half2: in the CUDA codegen, map float16x2 to half2
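The two items above could look roughly like this in a string-emitting codegen. This is a minimal sketch: the tables and the `emit_binary`/`emit_decl` helpers are hypothetical and are not TVM's actual CodeGenCUDA. Only the CUDA names are real: `half`/`half2` and the `__hadd`, `__hsub`, `__hmul` intrinsics (plus their `*2` paired forms) come from `cuda_fp16.h`. Recent CUDA toolkits also define operator overloads for half and half2 in that header; emitting the intrinsics explicitly just makes the half2 path obvious.

```python
from typing import Dict, Tuple

# Hypothetical table: TVM-style dtype string -> CUDA type name from cuda_fp16.h.
CUDA_TYPE: Dict[str, str] = {"float16": "half", "float16x2": "half2"}

# Hypothetical table: (dtype, operator) -> cuda_fp16.h intrinsic.
# The *2 variants operate on both lanes of a half2 at once.
CUDA_INTRINSIC: Dict[Tuple[str, str], str] = {
    ("float16", "+"): "__hadd",
    ("float16", "-"): "__hsub",
    ("float16", "*"): "__hmul",
    ("float16x2", "+"): "__hadd2",
    ("float16x2", "-"): "__hsub2",
    ("float16x2", "*"): "__hmul2",
}


def emit_binary(dtype: str, op: str, lhs: str, rhs: str) -> str:
    """Emit a CUDA expression computing `lhs op rhs` for an fp16 dtype."""
    intrinsic = CUDA_INTRINSIC.get((dtype, op))
    if intrinsic is None:
        raise NotImplementedError(f"no fp16 intrinsic for {op!r} on {dtype}")
    return f"{intrinsic}({lhs}, {rhs})"


def emit_decl(dtype: str, name: str, value: str) -> str:
    """Emit a CUDA declaration such as `half2 c = ...;`."""
    if dtype not in CUDA_TYPE:
        raise NotImplementedError(f"no CUDA type for {dtype}")
    return f"{CUDA_TYPE[dtype]} {name} = {value};"


print(emit_decl("float16x2", "c", emit_binary("float16x2", "+", "a", "b")))
# prints: half2 c = __hadd2(a, b);
```

Unsupported widths raise instead of silently falling back, so a real implementation would have to decide elsewhere how to handle them.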

#3

Support vectorized type half2: in codegen for CUDA, map float16x2 to half2

So the schedules have to be changed, right?
Or can we reuse the same schedule and somehow, for example, convert float16x8 into, say, (float16x2)x4?
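One way the float16x8 to (float16x2)x4 conversion could work without touching the schedule is in codegen: view the vector's storage as an array of half2 and emit one paired operation per two lanes. The sketch below emits CUDA statements for a lane-wise add. The function name is hypothetical, and `a2`/`b2`/`c2` (half2 views) and `a1`/`b1`/`c1` (half views of the same storage) are hypothetical names assumed to be declared by the surrounding generated code, for example via `reinterpret_cast<half2*>`, with the storage assumed 4-byte aligned as half2 requires.

```python
from typing import List


def lower_fp16_vector_add(lanes: int) -> List[str]:
    """Lower a float16xN lane-wise add into N // 2 half2 adds, plus one scalar
    half add for the leftover lane when N is odd.

    a2/b2/c2 (half2 views) and a1/b1/c1 (half views of the same storage) are
    hypothetical names the surrounding generated code is assumed to declare.
    """
    if lanes < 1:
        raise ValueError("lanes must be positive")
    # One paired instruction per two lanes: lanes 2*i and 2*i + 1 live in half2 element i.
    stmts = [f"c2[{i}] = __hadd2(a2[{i}], b2[{i}]);" for i in range(lanes // 2)]
    if lanes % 2:
        # An odd lane count leaves one lane that can only use the scalar intrinsic.
        last = lanes - 1
        stmts.append(f"c1[{last}] = __hadd(a1[{last}], b1[{last}]);")
    return stmts


for line in lower_fp16_vector_add(8):
    print(line)
```

With an odd lane count such as float16x3, the last lane falls back to a scalar `__hadd`, which hints at why the chosen vectorization length matters for getting the half2 benefit.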


#4

Not necessarily. It depends on how the vectorization length is set (hard-coded or tuned with AutoTVM).


#5

Just to echo @janimesh's note about performance: I ran some PyTorch code and the equivalent TVM-generated code and compared their float32 vs. float16 performance.
Switching to float16 gave the following speedups:
  • PyTorch: 1.88x
  • TVM-generated code: 1.17x (significantly lower than PyTorch)


#6

Thanks @ibeltagy for sharing the observation. We need to dive deeper into the TVM CUDA schedules to understand this. Currently, I am not aware of what needs to go in to get a speedup.

Unfortunately, I am busy with other parts of the TVM project. Also, I am not familiar enough with the CUDA schedules and codegen to quickly come up with a list of tasks.

It would be very helpful if someone who has worked on CUDA schedules could come up with a rough plan; then we can parallelize the efforts.