Multi-GPU DPO Training with FSDP: Full Training, LoRA, and QLoRA
DPO training for Qwen2.5 72B and Llama 3.1 70B on consumer hardware
Direct Preference Optimization (DPO) is one of the most popular methods for aligning large language models (LLMs) with human preferences. With parameter-efficient fine-tuning techniques like LoRA and QLoRA, applying DPO to models with up to 8 billion parameters, such as Llama 3.1 8B and Qwen2.5 7B, becomes feasible on consumer hardware, albeit with relatively short training sequences.
However, scaling DPO to much larger models, such as those with 70 billion or more parameters, is much more challenging. Training such models without relying on techniques like LoRA typically requires multiple GPUs or even distributed nodes. Let’s break this down with an example: suppose we aim to train a 70B-parameter model using DPO on a node with 8 H100 GPUs, providing a total of 640 GB of GPU memory. We have:
Policy Model: The parameters of the actively trained policy model occupy 140 GB of GPU memory in 16-bit precision (70 billion parameters × 2 bytes).
Reference Model: DPO requires a reference model, usually of the same architecture, adding another 140 GB to the memory footprint.
At this point, the parameters alone consume 43.75% of the GPU memory.
Next, consider AdamW’s optimizer states. For each parameter in the policy model, the optimizer maintains two additional states. Assuming these states use a 16-bit data type (half-precision), they require an additional 280 GB.
This leaves only 80 GB of GPU memory, which must still hold the gradients and the activations. The gradients of the policy model alone would take another 140 GB in 16-bit precision, so not everything can stay on the GPUs at once. Full DPO training on a single node is still feasible, but it requires offloading some components, such as part of the optimizer states or the reference model, to CPU RAM. Offloading significantly slows down training: the GPUs spend much of their time waiting for data to move between GPU and CPU memory instead of computing.
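The arithmetic above can be reproduced with a short Python sketch. Like the text, it assumes 2 bytes per value for weights, gradients, and both AdamW states, and uses decimal gigabytes; activations and framework overhead are ignored, and the function name `dpo_memory_budget` is purely illustrative.

```python
GB = 1e9  # decimal gigabytes, as in the figures above


def dpo_memory_budget(n_params: float, n_gpus: int, gb_per_gpu: float,
                      bytes_per_value: int = 2) -> dict:
    """Memory (GB) for each component of full DPO training, activations excluded."""
    weights = n_params * bytes_per_value / GB
    budget = {
        "total_gpu_memory": n_gpus * gb_per_gpu,
        "policy_weights": weights,
        "reference_weights": weights,   # frozen model with the same architecture
        "adamw_states": 2 * weights,    # first and second moment per parameter
        "policy_gradients": weights,    # one gradient per trainable parameter
    }
    params_and_states = (budget["policy_weights"] + budget["reference_weights"]
                         + budget["adamw_states"])
    # Share of total GPU memory taken by the two models' parameters
    budget["param_share_pct"] = 100 * (budget["policy_weights"]
                                       + budget["reference_weights"]) / budget["total_gpu_memory"]
    # What is left after parameters and optimizer states
    budget["remaining"] = budget["total_gpu_memory"] - params_and_states
    return budget


budget = dpo_memory_budget(n_params=70e9, n_gpus=8, gb_per_gpu=80)
for name, value in budget.items():
    print(f"{name:>18}: {value:8.2f}")
print("gradients fit in the remaining memory:",
      budget["policy_gradients"] <= budget["remaining"])
```

Changing `n_gpus` or `gb_per_gpu` shows how quickly the margin disappears on smaller nodes.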
In this article, we will explore how to use PyTorch’s Fully Sharded Data Parallel (FSDP) together with parameter-efficient techniques like LoRA and QLoRA to run DPO training efficiently for large models on multiple GPUs. The method is largely similar to the approach used for multi-GPU supervised fine-tuning, the main difference being the need to handle DPO’s reference model.
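The reference model enters DPO only through the log-probabilities it assigns to the chosen and rejected responses. The minimal sketch below computes the standard (sigmoid) DPO loss for a single preference pair from summed sequence log-probabilities; the name `dpo_loss` and the default `beta=0.1` are illustrative choices, not taken from the notebook. With LoRA or QLoRA, a separate reference copy is often unnecessary: the reference log-probabilities can come from the same base model with the adapters disabled, which is what Hugging Face TRL’s `DPOTrainer` does when it receives a PEFT configuration and no explicit reference model.

```python
import math


def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    """Sigmoid DPO loss for one preference pair.

    Each argument is the summed log-probability of a full response
    (chosen or rejected) under the policy or the frozen reference model.
    """
    # Implicit rewards: how far the policy has moved away from the reference
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    margin = chosen_reward - rejected_reward
    # -log(sigmoid(margin)), written in a numerically stable form
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


# Before any update the policy equals the reference, so the loss is log(2)
print(dpo_loss(-42.0, -47.5, -42.0, -47.5))
# The policy now prefers the chosen response more than the reference does
print(dpo_loss(-40.0, -49.0, -42.0, -47.5))
```

Because the reference terms are fixed, they can be computed under `torch.no_grad()` (or precomputed), which is why a frozen or adapter-free reference costs memory for its weights but no gradients or optimizer states.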
For the experiments in this article, I performed DPO training with QLoRA using just four RTX 4090 GPUs from RunPod (referral link).
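A rough calculation suggests why 4-bit quantization matters on this hardware. The sketch below assumes about 72 billion parameters for Qwen2.5 72B and that FSDP splits the frozen base weights evenly across the four 24 GB GPUs; it ignores quantization constants, the LoRA adapters and their optimizer states, activations, and the temporary all-gather buffers FSDP uses during the forward and backward passes. The helper name `per_gpu_weight_gb` is hypothetical.

```python
GB = 1e9


def per_gpu_weight_gb(n_params: float, bits_per_param: int, n_gpus: int) -> float:
    """GB of frozen base weights held by each GPU when FSDP shards them evenly."""
    return n_params * bits_per_param / 8 / GB / n_gpus


N_PARAMS = 72e9     # approximate parameter count of Qwen2.5 72B (assumption)
N_GPUS = 4
GPU_MEMORY_GB = 24  # RTX 4090

for label, bits in [("16-bit (LoRA)", 16), ("4-bit (QLoRA)", 4)]:
    shard = per_gpu_weight_gb(N_PARAMS, bits, N_GPUS)
    print(f"{label}: {shard:.1f} GB of weights per GPU, "
          f"{GPU_MEMORY_GB - shard:.1f} GB left for everything else")
```

In 16-bit precision the sharded base weights alone (36 GB per GPU) would exceed the 24 GB of each RTX 4090, whereas in 4-bit they take about 9 GB per GPU, leaving headroom for activations and the adapters.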
The code explained in this article is provided in the following notebook, which demonstrates DPO training (LoRA and QLoRA) with FSDP for Qwen2.5 72B. The same approach applies to Llama 3.1 70B, with minor adjustments to the tokenizer’s configuration.
Setting Up FSDP for DPO Training
First, make sure you have a recent version of the Transformers library. FSDP support in Transformers was temporarily broken by an attempt to fix the gradient accumulation issue we identified in October, but this has been resolved since version 4.46.3, which also handles gradient accumulation correctly during FSDP training.
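A quick way to check the installed version before launching a long FSDP job is sketched below. It relies only on the standard library’s `importlib.metadata`; the helper names are hypothetical, and pre-release or local suffixes such as `rc1` or `dev0` are simply ignored by the comparison. If the check fails, `pip install --upgrade transformers` brings in a newer release.

```python
from importlib.metadata import PackageNotFoundError, version

MIN_TRANSFORMERS = "4.46.3"


def release_tuple(v: str) -> tuple:
    """Leading numeric components of a version string: '4.47.0.dev0' -> (4, 47, 0)."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) < len(piece):  # e.g. '0rc1': keep 0, drop the suffix
            break
    return tuple(parts)


def is_at_least(installed: str, minimum: str) -> bool:
    """Compare release numbers, padding the shorter one with zeros."""
    a, b = release_tuple(installed), release_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def transformers_is_recent_enough(minimum: str = MIN_TRANSFORMERS) -> bool:
    """True if Transformers is installed and at least `minimum`."""
    try:
        return is_at_least(version("transformers"), minimum)
    except PackageNotFoundError:
        return False


print("Transformers >=", MIN_TRANSFORMERS, "installed:", transformers_is_recent_enough())
```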