diff --git a/_posts/2024-11-18-llama-into-torchtune.md b/_posts/2024-11-18-llama-into-torchtune.md
new file mode 100644
index 000000000000..abe9f290987e
--- /dev/null
+++ b/_posts/2024-11-18-llama-into-torchtune.md
@@ -0,0 +1,818 @@
---
layout: blog_detail
title: "Distilling Llama3.1 8B into 1B in torchtune"
author: Linda Wang, Evan Smothers, Kartikay Khandelwal
---

In this blog, we present a case study on distilling a Llama 3.1 8B model into Llama 3.2 1B using torchtune's knowledge distillation recipe. We demonstrate how knowledge distillation (KD) can be used in post-training to improve instruction-following task performance, and we showcase how users can leverage the recipe.


## What is Knowledge Distillation?

[Knowledge Distillation](https://arxiv.org/pdf/1503.02531) is a widely used compression technique that transfers knowledge from a larger (teacher) model to a smaller (student) model. Larger models have more parameters and more capacity for knowledge; however, this larger capacity is also more computationally expensive to deploy. Knowledge distillation can be used to compress the knowledge of a larger model into a smaller one, the idea being that the smaller model's performance improves by learning from the larger model's outputs.


## How does Knowledge Distillation work?

Knowledge is transferred from the teacher to the student model by training on a transfer set, where the student is trained to imitate the token-level probability distributions of the teacher. The assumption is that the teacher model's distribution is similar to that of the transfer dataset. The diagram below is a simplified representation of how KD works.

{:style="width:100%"}


**Figure 1: Simplified representation of knowledge transfer from teacher to student model**

As knowledge distillation for LLMs is an active area of research, papers such as [MiniLLM](https://arxiv.org/pdf/2306.08543), [DistiLLM](https://arxiv.org/pdf/2402.03898), [AKL](https://arxiv.org/pdf/2404.02657), and [Generalized KD](https://arxiv.org/pdf/2306.13649) investigate different loss approaches. In this case study, we focus on the standard cross-entropy (CE) loss combined with the forward [Kullback-Leibler (KL) divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) loss as the baseline. Forward KL divergence minimizes the difference by forcing the student's distribution to align with the entirety of the teacher's distribution.


## Why is Knowledge Distillation useful?

The idea of knowledge distillation is that a smaller model can achieve better performance using a teacher model's outputs as an additional signal than it could by training from scratch or with supervised fine-tuning alone. For instance, the [Llama 3.2 lightweight 1B and 3B text models](https://ai.meta.com/blog/llama-3-2-connect-2024-vision-edge-mobile-devices/) incorporated logits from Llama 3.1 8B and 70B to recover performance after pruning. In addition, for fine-tuning on instruction-following tasks, research in LLM distillation demonstrates that knowledge distillation methods can outperform supervised fine-tuning (SFT) alone.
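To make the baseline loss from the previous section concrete, below is a minimal sketch of the hybrid CE + forward KL objective in PyTorch. This is not torchtune's exact implementation; the helper names, tensor shapes, and the `kd_ratio` weighting convention are assumptions for illustration.

```python
import torch.nn.functional as F

def forward_kl_loss(student_logits, teacher_logits, labels, ignore_index=-100):
    # Teacher probabilities and student log-probabilities over the vocabulary.
    teacher_probs = F.softmax(teacher_logits, dim=-1)
    student_logprobs = F.log_softmax(student_logits, dim=-1)
    # Forward KL reduces to the teacher-weighted cross-entropy once the
    # constant teacher-entropy term is dropped.
    kl = -(teacher_probs * student_logprobs).sum(dim=-1)  # (batch, seq_len)
    # Average only over real target tokens, not padding / ignored positions.
    mask = labels != ignore_index
    return (kl * mask).sum() / mask.sum()

def kd_loss(student_logits, teacher_logits, labels, kd_ratio=0.5):
    # Hybrid objective: (1 - kd_ratio) * CE against ground-truth tokens
    # plus kd_ratio * forward KL against the teacher's distribution.
    # Logits are (batch, seq_len, vocab); cross_entropy expects (batch, vocab, seq_len).
    ce = F.cross_entropy(student_logits.transpose(1, 2), labels, ignore_index=-100)
    kl = forward_kl_loss(student_logits, teacher_logits, labels)
    return (1 - kd_ratio) * ce + kd_ratio * kl
```

The results below, drawn from the papers cited above, compare KD variants against SFT on instruction-following benchmarks: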
| Model | Method | DollyEval (GPT-4 Eval) | Self-Inst (GPT-4 Eval) | S-NI (Rouge-L) |
|---|---|---|---|---|
| Llama 7B | SFT | 73.0 | 69.2 | 32.4 |
| Llama 7B | KD | 73.7 | 70.5 | 33.7 |
| Llama 7B | MiniLLM | 76.4 | 73.1 | 35.5 |
| Llama 1.1B | SFT | 22.1 | - | 27.8 |
| Llama 1.1B | KD | 22.2 | - | 28.1 |
| Llama 1.1B | AKL | 24.4 | - | 31.4 |
| OpenLlama 3B | SFT | 47.3 | 41.7 | 29.3 |
| OpenLlama 3B | KD | 44.9 | 42.1 | 27.9 |
| OpenLlama 3B | SeqKD | 48.1 | 46.0 | 29.1 |
| OpenLlama 3B | DistiLLM | 59.9 | 53.3 | 37.6 |
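torchtune packages this training loop as a ready-made recipe driven from its CLI. The sketch below shows a typical run; the recipe and config names are assumptions based on torchtune's conventions, and `tune ls` lists what is actually available in your install.

```bash
# Download the teacher and student checkpoints from the Hugging Face Hub.
tune download meta-llama/Meta-Llama-3.1-8B-Instruct \
    --output-dir /tmp/Meta-Llama-3.1-8B-Instruct \
    --ignore-patterns "original/consolidated.00.pth"
tune download meta-llama/Llama-3.2-1B-Instruct \
    --output-dir /tmp/Llama-3.2-1B-Instruct \
    --ignore-patterns "original/consolidated.00.pth"

# Launch the LoRA knowledge distillation recipe on a single device.
tune run knowledge_distillation_single_device \
    --config llama3_2/8B_to_1B_KD_lora_single_device
```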
| Supervised fine-tuning | Knowledge distillation |
|---|---|
| *(image)* | *(image)* |
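The tables that follow report accuracies on TruthfulQA, HellaSwag, and CommonsenseQA. Numbers like these can be produced with torchtune's EleutherAI eval harness integration; a hedged sketch, assuming the `eleuther_eval` recipe and these harness task names:

```bash
# Evaluate a checkpoint on the benchmarks reported below.
tune run eleuther_eval --config eleuther_evaluation \
    tasks="[truthfulqa_mc2, hellaswag, commonsense_qa]"
```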
| Model | TruthfulQA (mc2) | hellaswag (acc) | hellaswag (acc_norm) | commonsense (acc) |
|---|---|---|---|---|
| Baseline Llama 3.1 8B | 0.5401 | 0.5911 | 0.7915 | 0.7707 |
| Fine-tuned Llama 3.1 8B using LoRA | 0.5475 | 0.6031 | 0.7951 | 0.7789 |
| Baseline Llama 3.2 1B | 0.4384 | 0.4517 | 0.6064 | 0.5536 |
| Fine-tuned Llama 3.2 1B using LoRA | 0.4492 | 0.4595 | 0.6132 | 0.5528 |
| KD using baseline 8B as teacher | 0.4440 | 0.4576 | 0.6123 | 0.5561 |
| KD using fine-tuned 8B as teacher | 0.4481 | 0.4603 | 0.6157 | 0.5569 |
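The next table additionally varies the student's starting point. Swapping which teacher (and student) checkpoint the recipe uses comes down to config overrides; the `teacher_checkpointer` field and the checkpoint path below are assumptions for illustration.

```bash
# Point the recipe at the LoRA fine-tuned 8B as the teacher
# (hypothetical path to the fine-tuned checkpoint directory).
tune run knowledge_distillation_single_device \
    --config llama3_2/8B_to_1B_KD_lora_single_device \
    teacher_checkpointer.checkpoint_dir=/tmp/Meta-Llama-3.1-8B-Instruct-lora
```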
| Model | TruthfulQA (mc2) | hellaswag (acc) | hellaswag (acc_norm) | commonsense (acc) |
|---|---|---|---|---|
| Baseline Llama 3.1 8B | 0.5401 | 0.5911 | 0.7915 | 0.7707 |
| Fine-tuned Llama 3.1 8B using LoRA | 0.5475 | 0.6031 | 0.7951 | 0.7789 |
| Baseline Llama 3.2 1B | 0.4384 | 0.4517 | 0.6064 | 0.5536 |
| Fine-tuned Llama 3.2 1B using LoRA | 0.4492 | 0.4595 | 0.6132 | 0.5528 |
| KD using baseline 8B and baseline 1B | 0.4440 | 0.4576 | 0.6123 | 0.5561 |
| KD using baseline 8B and fine-tuned 1B | 0.4508 | 0.4480 | 0.6004 | 0.5274 |
| KD using fine-tuned 8B and baseline 1B | 0.4481 | 0.4603 | 0.6157 | 0.5569 |
| KD using fine-tuned 8B and fine-tuned 1B | 0.4713 | 0.4512 | 0.5990 | 0.5233 |
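The following table sweeps the learning rate while keeping the fine-tuned 8B teacher and baseline 1B student fixed. In torchtune, hyperparameters can be overridden from the command line without editing the config file, e.g. (assuming the optimizer is exposed under `optimizer` in the config):

```bash
# Try a different learning rate for the same teacher/student pairing.
tune run knowledge_distillation_single_device \
    --config llama3_2/8B_to_1B_KD_lora_single_device \
    optimizer.lr=1e-4
```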
| Model | learning rate | TruthfulQA (mc2) | hellaswag (acc) | hellaswag (acc_norm) | commonsense (acc) |
|---|---|---|---|---|---|
| Baseline Llama 3.1 8B | - | 0.5401 | 0.5911 | 0.7915 | 0.7707 |
| Fine-tuned Llama 3.1 8B using LoRA | - | 0.5475 | 0.6031 | 0.7951 | 0.7789 |
| Baseline Llama 3.2 1B | - | 0.4384 | 0.4517 | 0.6064 | 0.5536 |
| Fine-tuned Llama 3.2 1B using LoRA | - | 0.4492 | 0.4595 | 0.6132 | 0.5528 |
| KD using fine-tuned 8B and baseline 1B | 3e-4 | 0.4481 | 0.4603 | 0.6157 | 0.5569 |
| KD using fine-tuned 8B and baseline 1B | 1e-3 | 0.4453 | 0.4535 | 0.6071 | 0.5258 |
| KD using fine-tuned 8B and baseline 1B | 1e-4 | 0.4489 | 0.4606 | 0.6156 | 0.5586 |
| KD using fine-tuned 8B and baseline 1B | 1e-5 | 0.4547 | 0.4548 | 0.6114 | 0.5487 |
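The final table varies `kd_ratio`, the weight on the forward KL term in the hybrid loss sketched earlier, at a fixed learning rate of 3e-4. Assuming it is exposed as a top-level config field, it can be overridden the same way:

```bash
# Weight the forward KL term more heavily relative to the CE term.
tune run knowledge_distillation_single_device \
    --config llama3_2/8B_to_1B_KD_lora_single_device \
    kd_ratio=0.75
```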
| Model | kd_ratio (lr=3e-4) | TruthfulQA (mc2) | hellaswag (acc) | hellaswag (acc_norm) | commonsense (acc) |
|---|---|---|---|---|---|
| Baseline Llama 3.1 8B | - | 0.5401 | 0.5911 | 0.7915 | 0.7707 |
| Fine-tuned Llama 3.1 8B using LoRA | - | 0.5475 | 0.6031 | 0.7951 | 0.7789 |
| Baseline Llama 3.2 1B | - | 0.4384 | 0.4517 | 0.6064 | 0.5536 |
| Fine-tuned Llama 3.2 1B using LoRA | - | 0.4492 | 0.4595 | 0.6132 | 0.5528 |
| KD using fine-tuned 8B and baseline 1B | 0.25 | 0.4485 | 0.4595 | 0.6155 | 0.5602 |
| KD using fine-tuned 8B and baseline 1B | 0.5 | 0.4481 | 0.4603 | 0.6157 | 0.5569 |
| KD using fine-tuned 8B and baseline 1B | 0.75 | 0.4543 | 0.4630 | 0.6189 | 0.5643 |
| KD using fine-tuned 8B and baseline 1B | 1.0 | 0.4537 | 0.4641 | 0.6177 | 0.5717 |