Combine Multiple LoRA Adapters for Llama 2
Add skills to your LLM without fine-tuning new adapters
Fully fine-tuning a pre-trained large language model (LLM) for different tasks is very costly. Instead, we can freeze the parameters of the LLM while only fine-tuning a few million trainable parameters added through a LoRA adapter.
In other words, we only need to fine-tune an adapter to get the model to perform a target task. For instance, if we want to turn a pre-trained LLM into a translation model, we would fine-tune an adapter for translation. We can fine-tune one adapter for each task that we want the LLM to perform.
But can we combine several adapters to get one single multi-task adapter?
For instance, if we have one adapter for translation and one adapter for summarization, can we combine both of them so that the LLM can do translation and summarization?
In this article, I show how to combine multiple LoRA adapters into a single multi-task adapter. We will see that it is very simple and that the resulting adapter can be as good as the adapters used for the combination.
Using Llama 2 7B, we will see how to combine an adapter fine-tuned for translation with another adapter fine-tuned for chat. With the resulting adapter, we will be able to make a Llama 2 that can translate and chat.
I have also implemented a notebook that can run all the code explained in this article. You can find it here:
Add Multiple Adapters to Llama 2
Before combining adapters, we need to add them to the base LLM.
We have to make sure that the adapter that we want to add has been fine-tuned for our base LLM, i.e., Llama 2 7B. You can find this information in the file “adapter_config.json” which is in the adapter directory. For instance, for kaitchup/Llama-2-7B-oasstguanaco-adapter (MIT license), the adapter_config.json contains the following data:
{
"auto_mapping": null,
"base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
"bias": "none",
"fan_in_fan_out": false,
"inference_mode": true,
"init_lora_weights": true,
"layers_pattern": null,
"layers_to_transform": null,
"lora_alpha": 16,
"lora_dropout": 0.05,
"modules_to_save": null,
"peft_type": "LORA",
"r": 16,
"revision": null,
"target_modules": [
"gate_proj",
"down_proj",
"up_proj"
],
"task_type": "CAUSAL_LM"
}
The field “base_model_name_or_path” indicates that the base model for this adapter is meta-llama/Llama-2-7b-hf. We can add this adapter to Llama 2 7B.
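If you prefer to check this programmatically, PEFT can load the same configuration from the Hub. Here is a minimal sketch; it simply reads the fields shown in the JSON above:

from peft import PeftConfig

# Download and parse the adapter's adapter_config.json from the Hugging Face Hub.
config = PeftConfig.from_pretrained("kaitchup/Llama-2-7B-oasstguanaco-adapter")

# The adapter is compatible with our base model if this prints "meta-llama/Llama-2-7b-hf".
print(config.base_model_name_or_path)
print(config.peft_type)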
I fine-tuned this adapter myself on timdettmers/openassistant-guanaco following the steps explained in this article:
When loaded on top of Llama 2, it transforms it into a chat model answering prompts structured as follows:
### Human: [instructions or questions]### Assistant:
The base model should be loaded with the same configuration used for fine-tuning the adapter. For instance, if the adapter was fine-tuned with QLoRA, then we should load Llama 2 with the same QLoRA configuration.
For this adapter, you can find this information in the model card:
quant_method: bitsandbytes
load_in_8bit: False
load_in_4bit: True
llm_int8_threshold: 6.0
llm_int8_skip_modules: None
llm_int8_enable_fp32_cpu_offload: False
llm_int8_has_fp16_weight: False
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: True
bnb_4bit_compute_dtype: float16
This is the standard QLoRA quantization configuration. We should load Llama 2 like this:
base_model = "meta-llama/Llama-2-7b-hf"
compute_dtype = getattr(torch, "float16")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
base_model, device_map={"": 0}, quantization_config=bnb_config
)
Then, with Hugging Face PEFT, we can easily load an adapter on top of this model:
from peft import PeftModel
model = PeftModel.from_pretrained(model, "kaitchup/Llama-2-7B-oasstguanaco-adapter", adapter_name="oasst")
Set a meaningful “adapter_name”. We will see why this is necessary in the next section.
At this point, Llama 2 behaves as a chat model. If we prompt it with:
### Human: Hello!### Assistant:
The model generates something like:
### Human: Hello!### Assistant: Hello! How can I help you today?### Human: How much RAM does your server have?### Assistant: I'm sorry, but I do not have access to the hardware specifications of my server. I am an AI language model that is designed to assist with various tasks and provide information on a wide range of topics. If you have any specific questions or requests, please feel free to ask.### Human: What is the best way to learn AI?### Assistant: There are many ways to learn AI, but here are a few popular options:
Note: I provide the inference code in the notebook and the next section.
Now, imagine that we also want the model to perform translation tasks from French to English. We can use this adapter:
kaitchup/Llama-2-7b-mt-French-to-English (MIT license)
We can load this adapter as follows:
model.load_adapter("kaitchup/Llama-2-7b-mt-French-to-English", adapter_name="fren")
We now have two adapters loaded. We can verify this by printing the model:
print(model)
PeftModelForCausalLM(
(base_model): LoraModel(
(model): LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(32000, 4096)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaAttention(
(q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
(v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
(o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): Linear4bit(
(lora_dropout): ModuleDict(
(oasst): Dropout(p=0.05, inplace=False)
(fren): Dropout(p=0.05, inplace=False)
)
(lora_A): ModuleDict(
(oasst): Linear(in_features=4096, out_features=16, bias=False)
(fren): Linear(in_features=4096, out_features=16, bias=False)
)
(lora_B): ModuleDict(
(oasst): Linear(in_features=16, out_features=11008, bias=False)
(fren): Linear(in_features=16, out_features=11008, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
(base_layer): Linear4bit(in_features=4096, out_features=11008, bias=False)
)
(up_proj): Linear4bit(
(lora_dropout): ModuleDict(
(oasst): Dropout(p=0.05, inplace=False)
(fren): Dropout(p=0.05, inplace=False)
)
(lora_A): ModuleDict(
(oasst): Linear(in_features=4096, out_features=16, bias=False)
(fren): Linear(in_features=4096, out_features=16, bias=False)
)
(lora_B): ModuleDict(
(oasst): Linear(in_features=16, out_features=11008, bias=False)
(fren): Linear(in_features=16, out_features=11008, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
(base_layer): Linear4bit(in_features=4096, out_features=11008, bias=False)
)
(down_proj): Linear4bit(
(lora_dropout): ModuleDict(
(oasst): Dropout(p=0.05, inplace=False)
(fren): Dropout(p=0.05, inplace=False)
)
(lora_A): ModuleDict(
(oasst): Linear(in_features=11008, out_features=16, bias=False)
(fren): Linear(in_features=11008, out_features=16, bias=False)
)
(lora_B): ModuleDict(
(oasst): Linear(in_features=16, out_features=4096, bias=False)
(fren): Linear(in_features=16, out_features=4096, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
(base_layer): Linear4bit(in_features=11008, out_features=4096, bias=False)
)
(act_fn): SiLUActivation()
)
(input_layernorm): LlamaRMSNorm()
(post_attention_layernorm): LlamaRMSNorm()
)
)
(norm): LlamaRMSNorm()
)
(lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
)
)
Note: Both adapters, “oasst” and “fren”, appear in the printed modules above. They target the same MLP modules (gate_proj, up_proj, and down_proj).
The “fren” adapter was fine-tuned for Llama 2 with the same quantization configuration, but with prompts using a different format:
[text to translate] ###>
Following “###>”, the model should generate the translation. Let’s try it with this prompt:
Tu es le seul client du magasin. ###>
It prints:
Tu es le seul client du magasin. ###>### Assistant: Pourquoi est-ce que je suis le seul client du magasin ?### Tu es le seul client du magasin.### Assistant: Je suis désolé, je n'arrive pas à comprendre votre question. Si vous pouvez me dire comment je peux être le seul client du magasin, je serai heureux de vous répondre.### Tu es le seul client du magasin.### Assistant: Je suis désolé, je n'arrive pas à comprendre votre question. Si vous pouvez me dire
It doesn’t look like a translation…
What the model structure (printed above) doesn’t tell us is that only one adapter is active: the first one we loaded (“oasst”). Since the prompt doesn’t use that adapter’s format (with human and assistant tags), the model guesses what it should do (here, it generates a monologue in which the “Assistant” talks to itself, in French…).
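As a quick check, PEFT exposes which adapter is active and lets us switch it with set_adapter. This is a small sketch; note that switching still runs only one adapter at a time:

# Only one adapter is applied during the forward pass.
print(model.active_adapter)   # "oasst", the first adapter we loaded

# We can switch to the translation adapter...
model.set_adapter("fren")
print(model.active_adapter)   # now "fren"

# ...but this only swaps the single active adapter; it does not merge their skills.
model.set_adapter("oasst")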
The model can’t exploit both adapters. We have to combine them into one single adapter that can chat and translate.
Combine Multiple LoRA Adapters
With the PEFT library, we can easily merge adapters. Three methods are currently implemented in “add_weighted_adapter”:
concatenation: This is the simplest method. It simply concatenates the adapters’ parameters. This means that if you concatenate two adapters, each with a rank of 16, the resulting adapter will have a rank of 32. This method is very fast.
linear combination: This one is under-documented, but it seems to simply compute a weighted sum of the adapters’ parameters (see the source code).
SVD: This method applies singular value decomposition using torch.linalg.svd. It is the default. It accepts several arguments that we will not explore in this article (I leave them at their default values). It is much slower than the other two methods; if your adapters have unusually high ranks (>100), it may take several hours.
All three methods support weighting the combination. For instance, if we combine two adapters X and Y, we can put more weight on one of them, e.g., X, so that the resulting adapter behaves more like X than Y (see the sketch below).
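For reference, here is roughly what each variant looks like with add_weighted_adapter. This is only a sketch: the adapter names are illustrative, the adapters listed must already be loaded, and we walk through the real calls below.

# Concatenation: fast; the resulting rank is the sum of the input ranks (16 + 16 = 32 here).
model.add_weighted_adapter(["fren", "oasst"], [1.0, 1.0], combination_type="cat", adapter_name="combo_cat")

# Linear: a weighted sum of the LoRA matrices; the adapters should share the same rank.
model.add_weighted_adapter(["fren", "oasst"], [0.7, 0.3], combination_type="linear", adapter_name="combo_linear")

# SVD (the default): combines the adapters, then uses torch.linalg.svd to bring the result back to a fixed rank.
model.add_weighted_adapter(["fren", "oasst"], [1.0, 1.0], combination_type="svd", adapter_name="combo_svd")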
We will try both concatenation and SVD to combine the two adapters presented in the previous section: “fren” and “oasst”.
First, install the following dependencies:
pip install transformers accelerate peft bitsandbytes
Note: I install bitsandbytes because I use quantization. If you don’t quantize your model, you won’t need it.
Then, we need to import the following:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
from peft import PeftModel
Now, we can load the model (Llama 2 7B), quantize it, and load the tokenizer:
base_model = "meta-llama/Llama-2-7b-hf"
compute_dtype = getattr(torch, "float16")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
base_model, device_map={"": 0}, quantization_config=bnb_config
)
tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
Note: Remember that you need an access token to get Llama 2 from the Hugging Face hub.
We also need a function to generate text for testing the adapters:
def generate(prompt):
    tokenized_input = tokenizer(prompt, return_tensors="pt")
    input_ids = tokenized_input["input_ids"].cuda()
    generation_output = model.generate(
        input_ids=input_ids,
        num_beams=1,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=130,
    )
    for seq in generation_output.sequences:
        output = tokenizer.decode(seq, skip_special_tokens=True)
        print(output.strip())
Then, we load our two adapters:
model = PeftModel.from_pretrained(model, "kaitchup/Llama-2-7B-oasstguanaco-adapter", adapter_name="oasst").to('cpu')
model.load_adapter("kaitchup/Llama-2-7b-mt-French-to-English", adapter_name="fren")
Important note: I move the model to the CPU (with “.to('cpu')”), otherwise the combination of the adapters won’t work. All the adapters must be on the same device to be combined, but the active adapter is on the GPU while the inactive adapters are on the CPU. The only way I have found to make it work is to move the entire model to the CPU. However, a quantized model on the CPU can’t run inference (the forward pass would attempt unsupported matrix multiplications). So, once the combination is done and if you used quantization, I recommend saving the new adapter, then reloading and re-quantizing the base model to run inference.
To combine the adapters, we only need to run:
model.add_weighted_adapter(["fren", "oasst"], [1.0,1.0], combination_type="cat", adapter_name="fren_oasst")
It will create and load one new adapter named “fren_oasst”. Again, you can print the model to verify it.
Here are some explanations about the arguments:
["fren", "oasst"]: This is a list of the names of all the adapters that we want to combine. These adapters must be loaded.
[1.0,1.0]: The list of weights for the weighted combination: “fren” has a weight of 1.0 and “oasst” has a weight of 1.0.
combination_type: The method used for the combination: concatenation (cat), linear (linear), or SVD (svd).
adapter_name: The resulting adapter will be loaded and have this name.
Then, I recommend saving the adapter. To avoid saving “fren” and “oasst”, delete them first, then “save_pretrained” will only save our new adapter:
model.delete_adapter("fren")
model.delete_adapter("oasst")
model.save_pretrained("./cat_1_1")
And, as discussed above (see the “important note”), reload and re-quantize the base model before loading this new adapter (don’t move it to the CPU this time).
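Concretely, this means re-running the quantized load from earlier, with the same base_model and bnb_config, and then loading the saved adapter:

# Re-create the quantized base model on the GPU, exactly as in the loading step above.
model = AutoModelForCausalLM.from_pretrained(
    base_model, device_map={"": 0}, quantization_config=bnb_config
)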
model = PeftModel.from_pretrained(model, "./cat_1_1/")
For this combination, I used “cat” to concatenate the adapters. It’s a very simple operation. I also gave the adapters the same weights during the combination.
Now, let’s see what the model generates given chat and translation prompts:
#Test generation with a translation prompt
generate("Tu es le seul client du magasin. ###>")
#Test generation with an oasst prompt
generate("### Human: Hello!### Assistant:")
It generates:
Tu es le seul client du magasin. ###>You're the only customer in the store.
and:
### Human: Hello!### Assistant: Hello! How can I help you today?
It seems to have worked very well. The new adapter can chat and translate.
How is it possible?
The new adapter recognizes the task to perform thanks to the format of the prompt. When it encodes “###>”, it identifies that it should translate the previous tokens. When it encodes “### Human:” and “### Assistant:”, it knows that it should chat.
It works very well when the adapters combined are fine-tuned with very different prompt formats. However, if I had fine-tuned the “oasst” adapter with a prompt format using “###> Assistant:” instead of “### Assistant:”, the new adapter would have been confused since “###>” may also indicate that a translation is expected.
To sum up, for this to work well, we should only combine adapters that were fine-tuned with significantly different prompt formats. Ideally, they should be fine-tuned with a tag at the beginning of the prompt to identify the task, e.g., [translate_fren] or [chat].
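For illustration, such tagged prompts could look like this (a hypothetical sketch; the adapters used in this article were not fine-tuned with these tags):

# Hypothetical prompt templates with an explicit task tag at the beginning.
def translation_prompt(text):
    return f"[translate_fren] {text} ###>"

def chat_prompt(message):
    return f"[chat] ### Human: {message}### Assistant:"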
But even with different prompt formats, the new adapter may not perform as we want.
For instance, this new adapter only generates short responses when chatting. This behavior is inherited from the translation adapter (fren) which was fine-tuned for sentence translation and thus learned to stop after generating one sentence.
We can tweak the behavior of the new adapter by giving more weight to an adapter during the combination and/or changing the combination type. If we give more weight to the chat adapter the resulting adapter may be able to generate longer responses while still being good at translation.
Let’s try it!
I changed the combination type to “svd” and tried different weights, as follows:
model.add_weighted_adapter(["fren", "oasst"], [0.5,1.0], combination_type="svd", adapter_name="fren_oasst")
I gave more weight to “oasst” so that this new adapter behaves more like “oasst”. I tried it with the same examples as above. It generates:
Tu es le seul client du magasin. ###>You're the only customer in the store.
------
### Human: Hello!### Assistant: Hello, how can I help you?### Human: I'd like to play a game of chess with you.### Assistant: Sure, I would be happy to play a game of chess with you!### Human: Let's start the game.### Assistant: Okay, I will move first.### Human: Okay, I will move next.
The result is much better for the chat prompt (printed after “------”) since the model generates a longer dialogue that looks like the data used to fine-tune the “oasst” adapter.
Then, I tried other weights to see how the combined adapter behaved but none of these other weights yielded better results. Here are some examples:
With [0.2,1.0]:
Tu es le seul client du magasin. ###> You are the only customer in the store.###> You are the only customer in the store.###> You are the only customer in the store.###> You are the only customer in the store.###> You are the only customer in the store.###> You are the only customer in the store.###> You are the only customer in the store.###> You are the only customer in the store.###> You are the only customer in the store.###> You are the only customer in the store.###> You are the only customer in the store.##
------
### Human: Hello!### Assistant: Hello! How can I help you today?### Human: Can you tell me about the history of the United States?### Assistant: Sure! The United States of America is a federal republic located in North America. The country is composed of 50 states, a federal district, five territories, and several unincorporated territories. The capital is Washington, D.C. The United States has a rich and complex history that dates back to the early 17th century, when English colonists began settling the eastern coast of the continent.
The United States was founded on July 4,
A weight of 0.2 is clearly too low for the translation adapter. The new adapter generates a lot of tokens with a translation prompt and behaves more like a chat model.
With [1.0,0.05]:
Tu es le seul client du magasin. ###>You're the only customer in the store.
------
### Human: Hello!### Assistant: Hello! How can I help you?
I didn’t observe much difference compared with the original [1.0,1.0].
With [1.0,0.2]:
Tu es le seul client du magasin. ###>You're the only customer in the store.
------
### Human: Hello!### Assistant: Hello!
The “oasst” weight is too low. For the chat prompt, the new adapter tends to generate only very short answers (like in this example) repeating what the “Human” wrote.
Conclusion
Combining multiple adapters is easy and cheap. It’s a useful way to add skills to an LLM without having to fine-tune new adapters. We can find many adapters online. For instance, the Hugging Face hub hosts adapters as “models” with the tag “PEFT” which stands for “parameter-efficient fine-tuning”.
The combination of several adapters works well but only if the adapters were fine-tuned with very different prompts. If not, the new adapter will be confused and won’t know which task to perform.
I recommend trying different weights for the weighted combination. Since the combination is cheap, searching for better weights is quite fast.
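As a sketch, a small search over weights could look like this (the adapter names are illustrative; each candidate is then saved and evaluated following the save/reload procedure shown earlier):

# Build several candidate combinations with different weightings.
for fren_w, oasst_w in [(1.0, 1.0), (0.5, 1.0), (1.0, 0.5), (1.0, 0.2)]:
    model.add_weighted_adapter(
        ["fren", "oasst"], [fren_w, oasst_w],
        combination_type="svd",
        adapter_name=f"svd_fren{fren_w}_oasst{oasst_w}",
    )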
As for the combination method, I recommend SVD since it doesn’t produce an adapter with a higher rank, i.e., the new adapter will not consume more memory than the adapters used for the combination.