The Unreasonable Impact of Gradient Checkpointing for Fine-tuning LLMs
How does gradient checkpointing reduce memory consumption?
Fine-tuning large language models (LLMs) demands a significant amount of GPU memory, as we need to store the model’s weights, optimizer states, gradients, and activations.
Activations are all the tensors produced during the forward pass that must be kept in memory because the backward pass needs them to compute the gradients used to update the weights. They can be extremely large, and their size depends primarily on the model architecture, the training batch size, and the sequence length used for fine-tuning. For large batches, activations can consume over 90% of the total memory required for fine-tuning.
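To get a feel for these numbers, here is a rough back-of-the-envelope sketch, not the measurement code used later in this article. It assumes mixed-precision AdamW training, where each parameter costs roughly 16 bytes across weights, gradients, and optimizer states, and it approximates per-layer activation memory with the common transformer estimate of about s·b·h·(34 + 5·a·s/h) bytes (fp16, no recomputation, standard attention). The model configuration in the example is hypothetical.

```python
def training_memory_gb(n_params, n_layers, hidden, n_heads, batch, seq_len):
    # Weights + gradients + optimizer states: ~16 bytes per parameter
    # (fp16 weights, fp16 gradients, fp32 master weights, two fp32 Adam moments).
    states = 16 * n_params
    # Rough activation memory per transformer layer, in bytes:
    # seq_len * batch * hidden * (34 + 5 * n_heads * seq_len / hidden)
    act_per_layer = seq_len * batch * hidden * (34 + 5 * n_heads * seq_len / hidden)
    activations = n_layers * act_per_layer
    return states / 1e9, activations / 1e9

# Hypothetical 1.5B-parameter configuration, batch size 8, sequence length 2048.
states_gb, activations_gb = training_memory_gb(
    n_params=1.5e9, n_layers=28, hidden=1536, n_heads=12, batch=8, seq_len=2048
)
print(f"weights/gradients/optimizer: ~{states_gb:.0f} GB, activations: ~{activations_gb:.0f} GB")
```

Even with these crude assumptions, the activations dwarf the parameter-related memory once the batch size and sequence length grow, which is exactly the regime where gradient checkpointing pays off.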
For deep LLMs, activations quickly become unmanageable, even with small batch sizes, often making it impossible to keep them all in memory. This is where gradient checkpointing becomes essential.
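The core idea is to avoid storing the intermediate activations of a block during the forward pass and instead re-run that block's forward computation when the backward pass needs them. The snippet below is a minimal illustration of the mechanism using PyTorch's `torch.utils.checkpoint`, with a toy block standing in for a transformer layer; it is not the code used in the experiments that follow.

```python
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    """A toy feed-forward block standing in for a transformer layer."""
    def __init__(self, dim):
        super().__init__()
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(dim, 4 * dim),
            torch.nn.GELU(),
            torch.nn.Linear(4 * dim, dim),
        )

    def forward(self, x):
        return x + self.ff(x)

block = Block(256)
x = torch.randn(8, 128, 256, requires_grad=True)

# Regular forward pass: the intermediates inside `block` stay in memory
# until the backward pass consumes them.
y = block(x)

# Checkpointed forward pass: only the block's input is saved; the
# intermediates are recomputed by re-running `block` during backward.
y_ckpt = checkpoint(block, x, use_reentrant=False)
y_ckpt.sum().backward()
```

Applied layer by layer across a deep model, this trades one extra forward computation per layer for a large reduction in stored activations.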
In this article, we'll explore how gradient checkpointing significantly reduces memory usage, cutting activation memory by up to 70%. We’ll demonstrate this by fine-tuning two models, SmolLM (130M parameters) and Qwen2 (1.5B parameters), both with and without gradient checkpointing, to measure the actual memory savings. Since gradient checkpointing involves recomputing some activations, we'll also evaluate its impact on training time.
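For reference, here is a minimal sketch of how gradient checkpointing is typically switched on with Hugging Face Transformers. The model name, output directory, and batch size below are placeholders, not the exact setup evaluated in this article.

```python
from transformers import AutoModelForCausalLM, TrainingArguments

# Placeholder model; the experiments in this article use SmolLM and Qwen2-1.5B.
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M")

# Option 1: enable gradient checkpointing directly on the model.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

# Option 2: let the Trainer enable it through TrainingArguments.
training_args = TrainingArguments(
    output_dir="smollm-gc-test",          # placeholder output directory
    per_device_train_batch_size=8,        # placeholder batch size
    gradient_checkpointing=True,          # store fewer activations, recompute them in backward
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```

Both options apply the same mechanism; the `TrainingArguments` flag is simply more convenient when the fine-tuning run already goes through the `Trainer`.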
You can find the complete code and training logs in this notebook: