[RFC] Add bfloat16 data type

Hi all, we have completed a workable draft of bfloat16 (bf16) support in TVM.

We add bfloat16 as a new type named “bf16” in the frontend.

  • Use int16 as the storage type
  • Add legalization to enable computations on bf16
  • Add runtime frontend support (e.g. allowing conversion of a numpy uint16 array to a bf16 NDArray)

Motivation

bfloat16 is a 16-bit floating-point data type. A bfloat16 value can be obtained by truncating an fp32 number and keeping the upper 16 bits. bfloat16 has lower memory consumption and is friendlier to memory-bound applications. It also requires no special hardware instructions, since computation on bf16 can be lowered to casting to fp32 and doing the computation in fp32. Thus, we bring the bfloat16 datatype into TVM.

Details on legalization

Since most hardware has no native support for bf16 computation, we add a BF16Legalization pass that uses fp32 to carry out computations on bf16 data. It inserts cast_to_fp32() before each Op involving bf16 operands and performs the computation in fp32. Then it adds a cast_to_bf16() after each Op that was altered, e.g.

add(a,b) => cast16(add(cast32(a), cast32(b)))

We call this phase “BF16Promotion”. It is a sub-pass of the BF16Legalization pass.
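For illustration, here is a minimal Python sketch of the promotion rule over a toy expression tree; the Var/Cast/BinaryOp classes are hypothetical stand-ins, not TVM’s actual TIR nodes or the real pass:

```python
# Toy expression tree; illustrative only, not TVM's TIR node classes.
class Var:
    def __init__(self, name, dtype):
        self.name, self.dtype = name, dtype

class Cast:
    def __init__(self, dtype, value):
        self.dtype, self.value = dtype, value

class BinaryOp:
    def __init__(self, op, a, b, dtype):
        self.op, self.a, self.b, self.dtype = op, a, b, dtype

def promote(e):
    """Rewrite bf16 arithmetic: cast operands to fp32, compute in fp32, cast back."""
    if isinstance(e, BinaryOp):
        a, b = promote(e.a), promote(e.b)
        if e.dtype == "bfloat16":
            wrapped = BinaryOp(e.op, Cast("float32", a), Cast("float32", b), "float32")
            return Cast("bfloat16", wrapped)
        return BinaryOp(e.op, a, b, e.dtype)
    return e
```

Applied to add(a, b) with bf16 operands, this yields cast16(add(cast32(a), cast32(b))), matching the example above.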

We note that this will add redundant casts, e.g.

add(a, neg(b)) => cast16(add(cast32(a), cast32(cast16(neg(cast32(b))))))

The pattern cast32(cast16(some_fp32_value)) can be simplified to some_fp32_value.

Thus, we add an optimization pass after “BF16Promotion” within the BF16Legalization pass that eliminates such redundant casts.
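A sketch of the elimination rule, reusing the toy classes above (a real pass would also have to cover unary ops, loads/stores, etc.):

```python
def eliminate_casts(e):
    """Rewrite cast_fp32(cast_bf16(x)) -> x when x is already an fp32 value."""
    if isinstance(e, Cast):
        inner = eliminate_casts(e.value)
        if (e.dtype == "float32" and isinstance(inner, Cast)
                and inner.dtype == "bfloat16" and inner.value.dtype == "float32"):
            return inner.value
        return Cast(e.dtype, inner)
    if isinstance(e, BinaryOp):
        return BinaryOp(e.op, eliminate_casts(e.a), eliminate_casts(e.b), e.dtype)
    return e
```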

After the BF16Legalization pass, there will be no bf16-related computation left in the AST, except casts between fp32 and bf16, bf16 value comparisons, and assignments.

Casting between fp32 and bf16

We follow PyTorch’s bf16 casting implementation.
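For reference, a minimal numpy sketch of the fp32→bf16 conversion in the round-to-nearest-even style that PyTorch uses (NaN handling is omitted here; bf16 values are held in uint16 storage):

```python
import numpy as np

def fp32_to_bf16(x):
    """Round an fp32 array to bf16 (stored as uint16) with round-to-nearest-even."""
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    rounding_bias = ((bits >> 16) & 1) + np.uint32(0x7FFF)   # round-to-nearest-even bias
    return ((bits + rounding_bias) >> 16).astype(np.uint16)  # keep the upper 16 bits
```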

Pull request

Design choices in legalization

Please see @tqchen’s post below.


@tqchen As you have proposed eliminating the “bfloat16” dtype after the legalization pass, I have two concerns.

  • The legalization pass will be more complicated: it has to check every TIR node to replace the bf16 dtype.
  • Since casting between bf16 and fp32 will be lowered to function calls in your proposal, I am not sure whether TVM can correctly vectorize the casting function.

I think we had better treat bf16 as a “first-class” type like int16/fp32, rather than handling it the way custom data types are handled.

Thanks @Menooker for the RFC. It would be great if a motivation section could be added (why bf16 support) for others who do not have a background on this. In terms of technical choices, it would be great to list the design choices, discuss their pros and cons, and then talk about concerns.

Design Choices and Pros and Cons

  • B0: Do all legalization (cast and compute) in TIR
  • B1: Do all legalization in the target codegen (LLVM, CUDA, etc.)
  • B2: Do compute legalization in TIR and cast legalization in the target codegen

Discussion

Given the above choices, B0 allows us to use the same legalization process for all target backends (e.g. CUDA, LLVM, C if necessary); this is the main reason why doing it in TIR is more desirable. The implementation complexity is not that different, since the main difference is about moving the related implementation from the target to the common TIR.

Notably, we don’t have to lower the cast into a specific function. While an external function was used for custom data types as an example, in this case we can directly lower to a sequence of expressions (reinterpret then shift), which will be properly vectorized.
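For example, the bf16→fp32 direction becomes a zero-extend, a shift, and a reinterpret; a numpy model of that sequence (illustrative only; the real pass would emit the equivalent TIR nodes):

```python
import numpy as np

def bf16_to_fp32(bf16_bits):
    """bf16 (stored as uint16) -> fp32: move the 16 bits to the high half and reinterpret."""
    widened = bf16_bits.astype(np.uint32) << 16   # zero-extend, then shift into the high half
    return widened.view(np.float32)               # reinterpret the bits, no value conversion
```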

bf16 is already a “first-class type” from the moment we introduce DataType::kBfloat16. The legalization is necessary for backends that do not support the type; as more backends move to support it, we can optionally skip the legalization step for those backends. This is another reason why it is important to have the bf16 support either in TIR or in the backend itself, instead of splitting the support into two parts.

OK, working on implementing the new legalization pass.

Updated design details

Details on legalization

Since most hardware has no native support for bf16 computation, we add a BF16Legalization pass that uses fp32 to carry out computations on bf16 data. It has three sub-passes: Promotion, Elimination, and Lowering.

BF16Promotion

It inserts cast_to_fp32() before each Op involving bf16 operands and performs the computation in fp32. Then it adds a cast_to_bf16() after each Op that was altered, e.g.

add(a,b) => cast16(add(cast32(a), cast32(b)))

We call this phase “BF16Promotion”. It is a sub-pass of the BF16Legalization pass.

BF16CastElimination

We note that this will add redundant casts, e.g.

add(a, neg(b)) => cast16(add(cast32(a), cast32(cast16(neg(cast32(b))))))

The pattern cast32(cast16(some_fp32_value)) can be simplified to some_fp32_value.

Thus, we add an optimization pass after “BF16Promotion” within the BF16Legalization pass that eliminates such redundant casts.

BF16Lowering

This pass replaces all bf16 dtypes with uint16. It also lowers casts between bf16 and fp32 into shifts and other TIR nodes.

After the BF16Legalization pass, there will be no bf16-related nodes left in the IR.
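To illustrate the overall effect, here is a numpy model of a fully legalized bf16 add: the bf16 operands are plain uint16 buffers, the casts are integer shifts, and the arithmetic runs in fp32 (names are illustrative only):

```python
import numpy as np

def legalized_bf16_add(a_u16, b_u16):
    """Model of add(a, b) on bf16 after BF16Legalization: no bf16 type remains."""
    a_f32 = (a_u16.astype(np.uint32) << 16).view(np.float32)   # cast bf16 -> fp32
    b_f32 = (b_u16.astype(np.uint32) << 16).view(np.float32)
    out = a_f32 + b_f32                                         # compute in fp32
    bits = out.view(np.uint32)
    bias = ((bits >> 16) & 1) + np.uint32(0x7FFF)               # round to nearest even
    return ((bits + bias) >> 16).astype(np.uint16)              # cast fp32 -> bf16 (uint16)
```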