BF16 vs. FP16 vs. FP32 for Gemma 3 Inference — Mind Your Data Type
Mitigating Numerical Issues When Converting a Model from BF16 to FP16
The earliest large language models (LLMs) were trained and deployed using the float32 data type. Since each float32 parameter consumes 4 bytes, a 70-billion-parameter model like Llama 3 requires nearly 280 GB of memory, making large-scale deployment expensive and inefficient.
To mitigate this, LLM developers gradually transitioned to 16-bit precision to cut memory consumption in half. Initially, they adopted mixed-precision training and inference, keeping critical modules and operations in float32 while downcasting less sensitive components to float16, a 2-byte data type. However, float16's narrow dynamic range (it has only a 5-bit exponent) makes gradients and activations prone to overflow and underflow, so training and inference often become unstable and require loss scaling and other mitigation strategies.
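For context, this is roughly what mixed-precision training with loss scaling looks like in PyTorch. The model and data below are placeholders; the point is the autocast/GradScaler pattern, not a faithful training loop:

```python
import torch
import torch.nn as nn

# Toy model and optimizer, purely illustrative.
model = nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()  # rescales the loss to avoid float16 gradient underflow

for _ in range(10):
    x = torch.randn(8, 1024, device="cuda")
    optimizer.zero_grad(set_to_none=True)
    # Forward pass runs eligible ops in float16; numerically sensitive ops stay in float32.
    with torch.cuda.amp.autocast(dtype=torch.float16):
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()   # backpropagate through the scaled loss
    scaler.step(optimizer)          # unscale gradients; skip the step if inf/NaN is detected
    scaler.update()                 # adapt the scale factor for the next iteration
```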
As an alternative, bfloat16 has become widely adopted due to its superior dynamic range: it keeps the same 8-bit exponent as float32, so its dynamic range matches float32's while halving memory requirements, at the cost of fewer mantissa bits and thus lower precision. Most modern GPUs support bfloat16 natively, but many older GPUs still in deployment do not. Additionally, some CUDA kernels are optimized specifically for float16 inference, which can make float16 faster in certain scenarios.
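The trade-off is easy to see by printing the numeric limits of each data type with PyTorch (a quick inspection, separate from the evaluation that follows):

```python
import torch

# Compare the numeric limits of the three data types discussed above.
for dtype in (torch.float32, torch.bfloat16, torch.float16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):16s} bits={info.bits:2d}  max={info.max:.3e}  "
          f"smallest normal={info.tiny:.3e}  eps={info.eps:.3e}")

# float32:  max ~3.4e38, eps ~1.2e-07
# bfloat16: max ~3.4e38 (same exponent range as float32), eps ~7.8e-03
# float16:  max 65504, eps ~9.8e-04 -> far narrower range, hence the overflow risk
```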
If a model has been trained and released in bfloat16—like Gemma 3—can we simply cast it to float16 without degrading its performance? In most cases, naive casting to float16 has no visible impact on model quality. However, for models with certain architectural properties, particularly models trained on TPUs such as Gemma 3, converting to float16 can completely break the model and cause it to generate nonsensical output, typically because some weights or activations fall outside float16's representable range.
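A naive conversion is a one-parameter change when loading the checkpoint with Hugging Face Transformers. The sketch below assumes a Gemma 3 checkpoint identified as google/gemma-3-4b-it and loads it once in its native bfloat16 and once cast to float16; the exact model class Transformers resolves to can vary with the checkpoint variant:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-3-4b-it"  # assumed checkpoint id; other Gemma 3 sizes work the same way
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = tokenizer("The capital of France is", return_tensors="pt")

for dtype in (torch.bfloat16, torch.float16):
    # The checkpoint is stored in bfloat16; requesting float16 casts every weight at load time.
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=dtype, device_map="auto"
    )
    output = model.generate(**prompt.to(model.device), max_new_tokens=20)
    print(dtype, tokenizer.decode(output[0], skip_special_tokens=True))
```

Comparing the two generations side by side is the simplest way to spot whether the float16 cast has already broken the model before running any formal benchmark.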
In this article, we will evaluate the impact of using bfloat16, float16, and float32 with Gemma 3 in its 4B, 12B, and 27B variants. We will also explore different strategies to improve the float16 version of the models and assess whether accuracy loss can be mitigated.
The following notebook implements the evaluations and strategies discussed in this article: