Torch Compile (torch.compile) was first introduced with PyTorch 2.0, but it took several updates and optimizations before it could reliably support most large language models (LLMs). In fact, I started drafting this article about 8 months ago, but back then, most of my tests failed, either due to errors from torch.compile or because it simply wasn't effective enough. I kept testing regularly until I finally saw enough successful results to share.
My initial goal was to benchmark torch.compile for fine-tuning LLMs, but, unfortunately, it's still not working well enough for this purpose. After fixing several errors, I had to give up, as it seems that the torch_compile argument in Transformers' TrainingArguments is still incompatible with most LLMs and training configurations.
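For reference, this is a minimal sketch of how that flag is set; the output directory and batch size are placeholder values, and in my tests this kind of configuration still failed for most models.

```python
from transformers import TrainingArguments

# Placeholder training configuration; torch_compile=True asks the Trainer
# to wrap the model with torch.compile before training.
training_args = TrainingArguments(
    output_dir="./results",          # placeholder path
    per_device_train_batch_size=4,   # placeholder value
    torch_compile=True,
)
```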
However, when it comes to inference, torch.compile can genuinely speed up decoding with only a small increase in memory usage.
In this article, we'll go over how torch.compile works and measure its impact on inference performance with LLMs. To use torch.compile in your code, you only need to add a single line. For this article, I tested it with Llama 3.2 and also tried it with bitsandbytes quantization, using two different GPUs: Google Colab's L4 and A100.
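To illustrate, here is a minimal sketch of what that single line looks like for generation. The checkpoint name and the 4-bit bitsandbytes settings are assumptions, and compiling model.forward (rather than the whole module) is one common way to make generate() run through the compiled graph; whether it works smoothly depends on your PyTorch, Transformers, and bitsandbytes versions.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "meta-llama/Llama-3.2-1B-Instruct"  # assumed checkpoint; any Llama 3.2 model should work

# Optional 4-bit quantization with bitsandbytes (assumed settings)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)

# The single added line: compile the forward pass.
# The first generation call is slow while the graph is compiled;
# subsequent calls benefit from the speedup.
model.forward = torch.compile(model.forward)

inputs = tokenizer("Explain torch.compile in one sentence.", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```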
I've created a notebook demonstrating how to use torch.compile and benchmarking its performance here: