Train Better Llama 3 Embeddings with Simple Contrastive Learning
A simple method to improve the accuracy of RAG systems
In a previous article, we saw how to turn Llama 3 into an embedding model for RAG systems.
The method uses a masked next-token prediction (MNTP) objective for training. MNTP adapts the embeddings of an LLM to the task of encoding text. However, embeddings extracted from LLMs and trained with MNTP alone still significantly underperform embedding models trained from scratch.
To improve the performance of an embedding model, contrastive learning is usually employed. The unsupervised variant only requires in-domain text, while the supervised version needs a dataset annotated for textual entailment.
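To make the idea concrete, here is a minimal sketch of an unsupervised (SimCSE-style) contrastive objective: the same batch of sentences is encoded twice so that dropout yields two slightly different views, the two views of each sentence are pulled together, and the other sentences in the batch serve as negatives. The function name, the temperature value, and the embedding dimension below are illustrative assumptions, not values taken from the article's notebook.

```python
import torch
import torch.nn.functional as F

def contrastive_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.05) -> torch.Tensor:
    """InfoNCE loss over two views of the same batch of sentences.

    z1, z2: (batch_size, hidden_dim) embeddings produced by two forward
    passes of the model with dropout enabled.
    """
    z1 = F.normalize(z1, dim=-1)
    z2 = F.normalize(z2, dim=-1)
    # Cosine similarity between every pair (i, j) in the batch.
    sim = z1 @ z2.T / temperature  # (batch_size, batch_size)
    # The positive pair for row i is column i; every other column is a negative.
    labels = torch.arange(sim.size(0), device=sim.device)
    return F.cross_entropy(sim, labels)

# Toy usage: random tensors stand in for two dropout-perturbed encodings
# of the same sentences (4096 is just a plausible hidden size).
z1, z2 = torch.randn(8, 4096), torch.randn(8, 4096)
print(contrastive_loss(z1, z2))
```

The supervised version follows the same pattern, except that the positive for each sentence comes from an entailment pair rather than from a second dropout pass.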
In this article, I show how to train the embeddings of Llama 3 with simple contrastive learning. I first review how contrastive learning works, and then we see how to use this technique to train better embedding models extracted from LLMs. Contrastive learning can be applied to the embeddings of any LLM.
The notebook training Llama 3 with simple contrastive learning is here: