LoRA Adapters: When a Naive Merge Leads to Poor Performance
The case of LoRA adapters fine-tuned with QLoRA
QLoRA is a memory-efficient way to fine-tune LLMs. It quantizes the LLM and then fine-tunes a LoRA adapter on top of it. I have used this method many times in my previous articles to fine-tune GPT-NeoX, Falcon, and Llama 2 models.
QLoRA only saves the fine-tuned adapter, not the entire model, since the base model’s parameters remain frozen.
But then, what do we do with this adapter?
We have two solutions to use it:
Load it on top of the base model every time we need it
Merge it with the base model to get a new model
For both of these solutions, we have to be careful. We can’t just naively load the base model and then load the adapter on top of it. We have to load the base model and preprocess it the same way it was preprocessed during QLoRA fine-tuning; otherwise, we may get a significant performance drop. The same applies if we want to merge the adapter.
In this article, I show you how to use the fine-tuned adapter. We will see that merging an adapter fine-tuned with QLoRA is not trivial. There is a method to avoid the performance drop after merging. I will explain and benchmark it. All the code to reproduce my experiments and the optimal merging strategy is available on the notebook page:
Last update: April 18th, 2024
QLoRA: Fine-tuning LoRA Adapters on Top of Quantized LLMs
When we load the model for QLoRA fine-tuning, we pass a BitsAndBytesConfig argument that contains the main hyperparameters for quantization. It usually looks like this for 4-bit quantization:
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_name = "meta-llama/Llama-2-7b-hf"
compute_dtype = getattr(torch, "float16")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     #store the base model's weights in 4-bit
    bnb_4bit_quant_type="nf4",             #NF4 quantization data type
    bnb_4bit_compute_dtype=compute_dtype,  #data type for the forward and backward passes
    bnb_4bit_use_double_quant=True,        #also quantize the quantization constants
)
model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=bnb_config, device_map={"": 0}
)
The bnb_4bit_compute_dtype is the hyperparameter that we have to remember for the merging. The QLoRA paper explains it here:
We dequantize the storage data type to the computation data type to perform the forward and backward pass, but we only compute weight gradients for the LoRA parameters which use 16-bit BrainFloat.
The LoRA adapter is fine-tuned for the dequantized data. In the code sample above, I store the data in 4-bit (bnb_4bit_quant_type="nf4") but compute the weight gradients on data dequantized on the fly to float16 (bnb_4bit_compute_dtype=compute_dtype). The LoRA adapter is thus fine-tuned for the float16 data type.
Note: If your GPU is compatible with bfloat16, I recommend using bfloat16 instead of float16. It improves the training stability.
Intuitively, if you load the adapter on top of a model using another precision, one for which it hasn’t been fine-tuned, the results can be unexpected.
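To make this more concrete, here is a minimal sketch of what a QLoRA linear layer computes during the forward pass. It is purely illustrative: fake_dequantize is a hypothetical stand-in for the bitsandbytes NF4 dequantization kernel, not the real implementation.
import torch

def fake_dequantize(w_stored, dtype):
    #Hypothetical stand-in for the NF4 dequantization kernel (illustration only)
    return w_stored.to(dtype)

def qlora_linear_forward(x, w_stored, lora_A, lora_B, scaling, compute_dtype):
    #The frozen base weights are stored in 4-bit and dequantized on the fly
    w = fake_dequantize(w_stored, compute_dtype)
    #The forward pass runs entirely in compute_dtype
    h = x.to(compute_dtype) @ w.t()
    #The LoRA update is the only part that receives gradients during fine-tuning
    return h + scaling * (x.to(compute_dtype) @ lora_A.t()) @ lora_B.t()
The adapter weights lora_A and lora_B are optimized against activations produced in compute_dtype, which is why the adapter is tied to that precision.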
For the following experiments, I fine-tuned an adapter for Llama 2 7B on the openassistant-guanaco dataset:
I pushed the adapter to the Hugging Face Hub:
To use this adapter without merging, we simply need to load and activate it on top of the base Llama 2 7B, using the same loading hyperparameters used for fine-tuning, as follows:
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

#Use bfloat16 if the GPU supports it, float16 otherwise
if torch.cuda.is_bf16_supported():
    compute_dtype = torch.bfloat16
else:
    compute_dtype = torch.float16

model_name = "meta-llama/Llama-2-7b-hf"
adapter = "kaitchup/Llama-2-7B-oasstguanaco-adapter-1e"

#Same quantization configuration as the one used for QLoRA fine-tuning
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=bnb_config, torch_dtype=compute_dtype, device_map={"": 0}
)
#Load and activate the adapter on top of the quantized base model
model = PeftModel.from_pretrained(model, adapter)
The model is ready for inference. I evaluated the perplexity of this model on the “test” split of openassistant-guanaco. It yielded:
4.02
At that point, we don’t know whether this perplexity is good or bad. We can use this score as a baseline. We want to get the same perplexity after merging.
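The evaluation code is not shown in this article; the following is a minimal sketch of how this perplexity could be computed with a standard chunked negative log-likelihood loop. The dataset identifier timdettmers/openassistant-guanaco, the tokenizer choice, and the 2,048-token chunk size are assumptions made for this sketch.
import torch
from datasets import load_dataset
from transformers import AutoTokenizer

#Assumed identifiers for this sketch
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
dataset = load_dataset("timdettmers/openassistant-guanaco", split="test")
encodings = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt")

max_length = 2048  #assumed evaluation chunk size
nlls = []
for begin in range(0, encodings.input_ids.size(1), max_length):
    end = min(begin + max_length, encodings.input_ids.size(1))
    input_ids = encodings.input_ids[:, begin:end].to(model.device)
    with torch.no_grad():
        #The loss is the average negative log-likelihood over the chunk
        loss = model(input_ids, labels=input_ids).loss
    nlls.append(loss * (end - begin))
#Perplexity is the exponential of the average negative log-likelihood per token
print(torch.exp(torch.stack(nlls).sum() / end).item())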
In this example, we loaded the same base model that we fine-tuned with QLoRA, with the same quantization_config. This setup seems optimal.
Now, let’s say that you don’t know which quantization_config was used, or that you don’t need quantization because you have enough VRAM. In that case, you could drop the quantization_config argument:
#Load the base model without quantization
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=compute_dtype, device_map={"": 0})
model = PeftModel.from_pretrained(model, adapter)
This model yields a perplexity of 4.07. That’s significantly worse, but expected: the adapter is naively loaded on top of a base model prepared without the bitsandbytes configuration for which it was fine-tuned. The adapter wasn’t optimized for this configuration, hence the worse perplexity. Note: Lower perplexity is better.
Merging a LoRA Adapter
Let’s check what could be the impact of a naive merge, and see how we can do better.
The Naive Merge: Adapter + Base LLM
A merge consists of 4 steps:
Load the base model
Load and activate the adapter on top of the base model
Merge the adapter with the base model
Save the merged model
#Load the base model (not quantized) in compute_dtype
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=compute_dtype)
#Load and activate the adapter on top of the base model
model = PeftModel.from_pretrained(model, adapter)
#Merge the adapter into the base model's weights
model = model.merge_and_unload()
#Save the merged model in the directory "./naive_merge/" in the safetensors format
model.save_pretrained("./naive_merge/", safe_serialization=True)
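A practical note not covered above: if you plan to use the merged directory as a standalone checkpoint, it is common to also save the tokenizer alongside the weights, for example:
from transformers import AutoTokenizer

#Save the tokenizer so "./naive_merge/" is a self-contained checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.save_pretrained("./naive_merge/")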
Then, the model can be loaded for inference as any other model:
model = AutoModelForCausalLM.from_pretrained("./naive_merge/", torch_dtype=compute_dtype, device_map="auto")
I evaluated the model perplexity on the same dataset. It yielded:
4.16
It is significantly worse than our baseline (4.02). Even though the difference looks small, for some models and training configurations, bringing the perplexity from 4.16 down to 4.02 can cost several training epochs. If your model is very large, that would be a huge waste of time and money, and we don’t want such a waste just because of a naive merge.
Remember, the adapter is fine-tuned with QLoRA on top of a quantized model. So maybe we just need to quantize the merged model with the same QLoRA configuration before using it for inference.
#Quantize the merged model with the same configuration used for QLoRA fine-tuning
compute_dtype = getattr(torch, "float16")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "./naive_merge/", quantization_config=bnb_config, torch_dtype=compute_dtype, device_map={"": 0}
)
The model’s perplexity decreases to:
4.09
That’s much better. Still, it’s higher than our baseline. It doesn’t work perfectly because the merge was performed on the unquantized base model, which we then quantized afterward. During QLoRA fine-tuning, however, the adapter was trained on top of an already quantized model. These two configurations are slightly different, hence the higher perplexity.
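As a toy illustration of why the two setups differ: quantization is lossy and non-linear, so quantizing W + BA is not the same as adding the LoRA update BA to an already quantized W. The sketch below uses a crude round-to-grid quantizer, purely for illustration (real NF4 quantization is more sophisticated):
import torch

def crude_quantize(w, step=0.1):
    #Illustrative round-to-grid quantizer (not NF4)
    return torch.round(w / step) * step

torch.manual_seed(0)
W = torch.randn(4, 4)          #frozen base weights
BA = 0.05 * torch.randn(4, 4)  #LoRA update (B @ A, already scaled)

#What the adapter saw during QLoRA fine-tuning: quantized base + update
fine_tuning_view = crude_quantize(W) + BA
#What the naive merge followed by re-quantization produces
merged_then_quantized = crude_quantize(W + BA)

#The gap is non-zero: the two configurations are not equivalent
print((fine_tuning_view - merged_then_quantized).abs().max())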