I don't get the memory consumption. For the 8B model, "paged AdamW 8-bit" needs more memory than "paged AdamW 32-bit"?
I'm not sure, but I think the optimizer states are moved from GPU to CPU as blocks attached to their model layer or module. Say the AdamW 32-bit states of layer 20 occupy 6 GB, but only 4 GB remains free on a 48 GB GPU. With paged AdamW, these states are moved together to the CPU, so the GPU memory occupied is 44 GB.
For AdamW 8-bit, same situation, but the optimizer states are 4 times smaller (8-bit vs. 32-bit). The AdamW 8-bit states of layer 20 occupy only 1.5 GB. Since there is enough free memory, they can stay on the GPU, and the GPU memory occupied is 45.5 GB, i.e., more than with paged AdamW 32-bit.
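To make those sizes concrete, here is a rough back-of-the-envelope sketch in Python. It only assumes the standard AdamW layout: two state tensors per model parameter (the first and second moment estimates), stored at 4 bytes per value in the 32-bit version and roughly 1 byte per value in the 8-bit version (ignoring the small per-block quantization constants).

```python
def adamw_state_gib(n_params: float, bytes_per_state_value: int) -> float:
    """Approximate AdamW optimizer state size in GiB.

    AdamW keeps two state tensors per model parameter (the first and
    second moment estimates), so the total is 2 * n_params * bytes per value.
    """
    return 2 * n_params * bytes_per_state_value / 1024**3

n_params = 8e9  # roughly the parameter count of Llama 3.1 8B

print(f"AdamW 32-bit states: ~{adamw_state_gib(n_params, 4):.0f} GiB")  # ~60 GiB
print(f"AdamW 8-bit states:  ~{adamw_state_gib(n_params, 1):.0f} GiB")  # ~15 GiB
```

The same 4x ratio applies per layer, which is why the 8-bit states of a single layer can still fit on the GPU when the 32-bit states cannot.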
Thank you, great explanation. I am now curious which other tricks can be used to fit Llama 3.1 8B on a 24 GB GPU for full fine-tuning. (I saw that Torchtune allows it.)
For full fine-tuning an 8B model with 24 GB of GPU memory, you need:
- gradient checkpointing
- FlashAttention
- bfloat16
- paged AdamW 8-bit
- batch size of 1 (or 2)
- short sequence length (less than 1024, maybe 512 or 256)
So yes, it's possible (see the sketch below), but it won't perform well for tasks processing long sequences.
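To show how these pieces fit together, here is a minimal sketch using Hugging Face Transformers rather than Torchtune (which exposes equivalent options in its recipes). The model name, output directory, learning rate, and accumulation steps are illustrative placeholders, not a tested recipe.

```python
import torch
from transformers import AutoModelForCausalLM, TrainingArguments

model_name = "meta-llama/Llama-3.1-8B"  # placeholder; any ~8B causal LM

# bfloat16 weights + FlashAttention 2, both set at load time
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # requires flash-attn to be installed
)

# The sequence length itself is controlled by your tokenization pipeline:
# truncate to 512 (or 256) tokens when preparing the dataset.
training_args = TrainingArguments(
    output_dir="./llama31-8b-full-ft",  # placeholder
    per_device_train_batch_size=1,      # batch size of 1
    gradient_accumulation_steps=8,      # recover a larger effective batch size
    gradient_checkpointing=True,        # recompute activations to save memory
    bf16=True,                          # bfloat16 mixed precision
    optim="paged_adamw_8bit",           # paged AdamW 8-bit (bitsandbytes backend)
    learning_rate=1e-5,                 # placeholder value
    num_train_epochs=1,
)

# Pass `model` and `training_args` to a Trainer (or TRL's SFTTrainer)
# together with your tokenized dataset to launch the run.
```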
Great article. I have a question: Do the data types used by the optimizer need to match the data types used for training the model? For example, if I train the model using BF16 or FP32, do these need to be the same as the data type used by the optimizer? If not, theoretically, in any fine-tuning scenario, would using a quantized (8-bit) optimizer be the best choice?
The data types don't need to be the same; they are independent.
Indeed, AdamW 8-bit is probably our best option.
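As a minimal sketch of this independence (assuming bitsandbytes and a CUDA GPU; the layer size and learning rate are arbitrary), the model below lives in bfloat16 while the optimizer keeps its own state in 8-bit:

```python
import torch
import bitsandbytes as bnb

# A toy bfloat16 "model": weights and gradients are bf16...
model = torch.nn.Linear(4096, 4096, dtype=torch.bfloat16, device="cuda")

# ...while the optimizer stores its state in 8-bit, independently of the model dtype.
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-5)

x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
loss = model(x).float().pow(2).mean()  # compute the toy loss in fp32 for stability
loss.backward()
optimizer.step()
optimizer.zero_grad()
```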
When doing pre-training rather than SFT, is the conclusion that "AdamW 8-bit is probably our best option" still valid?
Good question! I don't do pre-training often enough to be 100% sure, but I believe 8-bit AdamW might also work well for pre-training.