LoRA Adapters: When a Naive Merge Leads to Poor Performance
The case of LoRA adapters fine-tuned with QLoRA
QLoRA is a memory-efficient way to fine-tune LLMs. It quantizes the LLM and then fine-tunes a LoRA adapter on top of it. I have used this method many times in my previous articles to fine-tune GPT-NeoX, Falcon, and Llama 2 models.
QLoRA saves only the fine-tuned adapter, not the entire model, since the base model’s parameters remain frozen.
But then, what do we do with this adapter?
We have two solutions to use it:
Load it on top of the base model every time we need it
Merge it with the base model to get a new model
For both solutions, we have to be careful. We can’t simply load the base model and then load the adapter on top of it. We have to load and preprocess the base model exactly as it was loaded and preprocessed during QLoRA fine-tuning; otherwise, we may see a significant performance drop. The same applies if we want to merge the adapter.
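To make “naively” concrete, here is a sketch of the straightforward merge that most tutorials show: load the base model in 16-bit, load the adapter with PEFT, and call merge_and_unload. The repository names are the ones used later in this article; the output directory is arbitrary. As we will see, this is not the recommended procedure.
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base_model_name = "meta-llama/Llama-2-7b-hf"
adapter = "kaitchup/Llama-2-7B-oasstguanaco-adapter-1e"

# Naive merge: the base model is loaded in float16, not quantized as it was
# during QLoRA fine-tuning, so the merged weights are not the ones the
# adapter was trained against.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name, torch_dtype=torch.float16, device_map={"": 0}
)
model = PeftModel.from_pretrained(base_model, adapter)
model = model.merge_and_unload()
model.save_pretrained("./llama-2-7b-oasstguanaco-naive-merge")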
In this article, I show you how to use the fine-tuned adapter. We will see that merging an adapter fine-tuned with QLoRA is not trivial. There is a method to avoid the performance drop after merging. I will explain and benchmark it. All the code to reproduce my experiments and the optimal merging strategy is available on the notebook page:
Last update: April 18th, 2024
QLoRA: Fine-tuning LoRA Adapters on Top of Quantized LLMs
When we load the model for QLoRA fine-tuning, we pass a BitsAndBytesConfig argument that contains the main hyperparameters for quantization. It usually looks like this for 4-bit quantization:
model_name = "meta-llama/Llama-2-7b-hf"
compute_dtype = getattr(torch, "float16")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
model_name, quantization_config=bnb_config, device_map={"": 0}
)
The bnb_4bit_compute_dtype is the hyperparameter that we have to remember for the merge. The QLoRA paper explains it here:
We dequantize the storage data type to the computation data type to perform the forward and backward pass, but we only compute weight gradients for the LoRA parameters which use 16-bit BrainFloat.
The LoRA adapter is thus fine-tuned against dequantized data. In the code sample above, the weights are stored in 4-bit (bnb_4bit_quant_type="nf4"), but the weight gradients are computed on data dequantized on the fly to float16 (bnb_4bit_compute_dtype=compute_dtype). The LoRA adapter is therefore fine-tuned for the float16 data type.
Note: If your GPU is compatible with bfloat16, I recommend using bfloat16 instead of float16. It improves the training stability.
Intuitively, if you load the adapter on top of a model using a precision for which it hasn’t been fine-tuned, the results will be unpredictable.
For the following experiments, I fine-tuned an adapter for Llama 2 7B on the openassistant-guanaco dataset:
I pushed the adapter to the Hugging Face Hub: kaitchup/Llama-2-7B-oasstguanaco-adapter-1e
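For reference, the adapter was trained with the usual QLoRA recipe: the 4-bit model loaded as shown above is prepared for k-bit training, and a LoRA configuration is attached with PEFT before training. The sketch below illustrates this step; the LoRA hyperparameters (rank, alpha, dropout, target modules) are illustrative assumptions, not necessarily the exact values used for this adapter.
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# "model" is the 4-bit quantized base model loaded with the BitsAndBytesConfig above
model = prepare_model_for_kbit_training(model)

# Illustrative LoRA hyperparameters; the values used for the published adapter may differ
peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()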
To use this adapter without merging, we simply load and activate it on top of the base Llama 2 7B, with the same loading hyperparameters as during fine-tuning, as follows:
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# Use bfloat16 if the GPU supports it; otherwise fall back to float16
if torch.cuda.is_bf16_supported():
    compute_dtype = torch.bfloat16
else:
    compute_dtype = torch.float16

model_name = "meta-llama/Llama-2-7b-hf"
adapter = "kaitchup/Llama-2-7B-oasstguanaco-adapter-1e"

# Same quantization configuration as during QLoRA fine-tuning
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=bnb_config, torch_dtype=compute_dtype, device_map={"": 0}
)

# Load and activate the LoRA adapter on top of the quantized base model
model = PeftModel.from_pretrained(model, adapter)
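The model is then ready for inference. As a quick sanity check, we can generate with the standard Transformers API. In the sketch below, the prompt template follows the “### Human:/### Assistant:” format of openassistant-guanaco, and the generation settings are illustrative assumptions.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Illustrative prompt in the openassistant-guanaco conversation format
prompt = "### Human: What is QLoRA fine-tuning?### Assistant:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

# Illustrative generation settings
outputs = model.generate(**inputs, max_new_tokens=200, do_sample=True, temperature=0.7)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))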
I evaluated the perplexity of this model on the “test” split of openassistant-guanaco. It yielded: