Speculative Decoding for Faster Inference with Mixtral-8x7B and Gemma
Using quantized models for memory efficiency
Larger language models typically deliver better performance, but at the cost of slower inference. For example, Llama 2 70B significantly outperforms Llama 2 7B on downstream tasks, but its inference is approximately 10 times slower.
Many techniques, as well as adjustments to the decoding hyperparameters, can speed up inference for very large LLMs. Speculative decoding, in particular, can be very effective in many use cases.
Speculative decoding uses a small LLM to draft tokens, which a much larger and better LLM then validates or, if needed, corrects. Because the large LLM can check several drafted tokens in a single forward pass, speculative decoding can dramatically speed up inference if the small LLM is accurate enough.
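To make the draft-then-verify loop concrete, here is a minimal, framework-free sketch of greedy speculative decoding in Python. It is not the code used in the notebook: the two "models" are hypothetical toy functions that map a token prefix to a next-token ID, `gamma` (the number of draft tokens per round) is an illustrative name, and verification uses simple greedy matching rather than the rejection-sampling scheme needed when sampling. In a real LLM, the target's predictions at all drafted positions come from one forward pass, so the sketch counts verification rounds as a stand-in for target forward passes.

```python
from typing import Callable, List, Tuple

# Hypothetical toy "model": maps a token prefix to its greedy next-token ID.
NextToken = Callable[[List[int]], int]


def greedy_decode(model: NextToken, prompt: List[int], max_new_tokens: int) -> List[int]:
    """Baseline: one call to the model per generated token."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(model(tokens))
    return tokens


def speculative_decode(target: NextToken, draft: NextToken, prompt: List[int],
                       max_new_tokens: int, gamma: int = 4) -> Tuple[List[int], int]:
    """Greedy speculative decoding. Returns the tokens and the number of
    verification rounds (one target forward pass each in a real LLM)."""
    tokens = list(prompt)
    final_len = len(prompt) + max_new_tokens
    rounds = 0
    while len(tokens) < final_len:
        rounds += 1
        # Draft phase: the small model proposes gamma tokens autoregressively.
        proposal: List[int] = []
        for _ in range(gamma):
            proposal.append(draft(tokens + proposal))
        # Verification phase: compare each draft token with the target's greedy choice.
        accepted: List[int] = []
        for tok in proposal:
            expected = target(tokens + accepted)
            accepted.append(expected)  # equals tok whenever the two models agree
            if expected != tok:
                break  # first mismatch: keep the target's correction, drop the rest
        else:
            # Every draft token was accepted: the same target pass yields one bonus token.
            accepted.append(target(tokens + accepted))
        tokens.extend(accepted)
    return tokens[:final_len], rounds


# Toy target model, and a draft model that is wrong at every third position.
def target_model(prefix: List[int]) -> int:
    return (prefix[-1] * 3 + 1) % 7


def draft_model(prefix: List[int]) -> int:
    return 0 if len(prefix) % 3 == 0 else target_model(prefix)


baseline = greedy_decode(target_model, [1], max_new_tokens=20)
fast, rounds = speculative_decode(target_model, draft_model, [1], max_new_tokens=20)
print(fast == baseline, rounds)  # prints: True 7
```

Because every kept token is exactly what the target would have chosen on its own, the output is identical to plain greedy decoding with the target; only the number of target passes changes (7 instead of 20 in this toy run).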
In this article, I first explain how speculative decoding works. Then, I show how to run speculative decoding with different pairs of models involving Gemma, Mixtral-8x7B, Llama 2, and Pythia, all quantized. I benchmarked inference throughput and memory consumption to identify which configurations work best.
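Whether a given pair of models pays off can be reasoned about with the simplified analysis from the original speculative decoding paper (Leviathan et al., 2023). It assumes each draft token is accepted independently with a fixed probability α, that the draft proposes γ tokens per round, and that one draft forward pass costs a fraction c of one target forward pass. The sketch below encodes that estimate; the function names and the example values of α, γ, and c are illustrative assumptions, not measurements from this article, and the model ignores real-world factors such as quantization overhead, batch size, and memory bandwidth.

```python
def expected_tokens_per_pass(alpha: float, gamma: int) -> float:
    """Expected tokens generated per target forward pass when each of the gamma
    draft tokens is accepted independently with probability alpha."""
    if alpha >= 1.0:
        return float(gamma + 1)  # all drafts accepted, plus the bonus token
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def expected_speedup(alpha: float, gamma: int, c: float) -> float:
    """Idealized walltime speedup over plain decoding with the target alone.
    c = cost of one draft forward pass relative to one target forward pass."""
    return expected_tokens_per_pass(alpha, gamma) / (gamma * c + 1)


# Hypothetical values for illustration only.
for alpha in (0.5, 0.7, 0.9):
    print(f"alpha={alpha}: {expected_speedup(alpha, gamma=4, c=0.1):.2f}x")
```

Under this estimate, a cheap but inaccurate draft model (low α) or an accurate but expensive one (high c) can push the speedup below 1, which is one reason it is worth benchmarking several model pairs.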
The notebook running speculative decoding with Mixtral, Gemma, and other LLMs is available here: