
Commit ca441f5

Antlera and JoshWoo2003 committed
Add Llama-2 fine-tuning scripts and configuration for ZenFlow
- Introduced `finetune_llama.py` for fine-tuning the Llama-2 model using DeepSpeed and ZenFlow.
- Added `finetune_llama.sh` for automated training setup with environment variables and the DeepSpeed command.
- Added `zf_config.json` example for DeepSpeed configuration with ZenFlow optimizations.

Signed-off-by: Tingfeng Lan <[email protected]>
Co-authored-by: Yusen Wu <[email protected]>
1 parent e877698 commit ca441f5

File tree: 5 files changed (+273, -0 lines)
README.md (80 additions, 0 deletions)
# ZenFlow Llama-2 Fine-Tuning Example

This project demonstrates how to fine-tune a [Llama-2](https://huggingface.co/meta-llama) model using [DeepSpeed](https://www.deepspeed.ai/) with **ZenFlow**, a stall-free offloading engine for large-scale model training.

## Quick Start

1. **Install dependencies**

   ```bash
   pip install -r requirements.txt
   ```

2. **Configure training**

   Edit `zf_config.json` to enable ZenFlow:

   ```json
   "zero_optimization": {
     "stage": 2,
     "offload_optimizer": {
       "device": "cpu",
       "pin_memory": true
     },
     "zenflow": {
       "topk_ratio": 0.1,
       "update_interval": 4,
       "full_warm_up_rounds": 0,
       "overlap_step": true
     }
   }
   ```
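
   Roughly speaking (see the ZenFlow paper cited below for the precise definitions): `topk_ratio` sets the fraction of gradient coordinates treated as important and updated immediately, `update_interval` sets how many steps pass between full updates of the remaining coordinates, `full_warm_up_rounds` sets how many initial steps use full synchronous updates, and `overlap_step` overlaps the CPU optimizer step with GPU compute.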

3. **Run fine-tuning**

   ```bash
   bash finetune_llama.sh
   ```

   This runs Llama-2 fine-tuning using DeepSpeed + ZenFlow, saving checkpoints to `./alpaca_output`.
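
If you later need to resume from one of these checkpoints, a minimal sketch (assuming the same model and `zf_config.json` used for training; this snippet is illustrative and not part of this commit) is to rebuild the engine and call DeepSpeed's `load_checkpoint`:

```python
# Illustrative resume sketch, not included in this commit.
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
engine, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="zf_config.json",  # same config used for training
)
# Loads the latest checkpoint tag written by model_engine.save_checkpoint().
load_path, client_state = engine.load_checkpoint("./alpaca_output")
```

As with training, this would typically be launched with `deepspeed` on the same number of GPUs so that each rank can reload its ZeRO optimizer partition.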

## Example Output

Below is a sample log showing step time and loss values. Most steps complete in roughly 700 ms, while the periodically slower steps (for example, steps 8 and 12) coincide with the full optimizer update that ZenFlow applies every `update_interval` steps:

```
ZenFlowCPUAdam initialized with overlap step.
Step 5, Loss: 1.2599, Time: 719.58ms
Step 6, Loss: 0.9847, Time: 702.81ms
Step 7, Loss: 0.6220, Time: 705.50ms
Step 8, Loss: 0.5173, Time: 1912.92ms
Step 9, Loss: 0.4557, Time: 890.60ms
Step 10, Loss: 0.3882, Time: 740.11ms
Step 11, Loss: 0.3627, Time: 731.95ms
Step 12, Loss: 0.3341, Time: 2221.18ms
Step 13, Loss: 0.2453, Time: 1061.80ms
```

ZenFlow reduces optimizer-induced stalls by overlapping CPU computation and GPU execution.
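
For intuition, here is a toy, self-contained sketch of that idea (this is not ZenFlow's implementation: the names are made up, it uses plain SGD-style updates instead of ZenFlow's CPU Adam, and the deferred update here runs synchronously rather than overlapped): the top-k largest-magnitude gradient coordinates are applied immediately, while the rest are accumulated and applied only every `update_interval` steps.

```python
import torch

def toy_selective_update(param, grad, deferred, step,
                         topk_ratio=0.1, update_interval=4, lr=2e-5):
    """Toy sketch of selective + deferred updates (not ZenFlow's actual code)."""
    flat_grad = grad.view(-1)
    flat_param = param.data.view(-1)
    k = max(1, int(flat_grad.numel() * topk_ratio))

    # Apply the k largest-magnitude ("important") gradient coordinates right away.
    top_idx = flat_grad.abs().topk(k).indices
    flat_param[top_idx] -= lr * flat_grad[top_idx]

    # Accumulate the remaining coordinates for a deferred update.
    rest = flat_grad.clone()
    rest[top_idx] = 0.0
    deferred += rest

    # Every `update_interval` steps, apply the accumulated update.
    # (ZenFlow performs this part in the CPU optimizer, overlapped with GPU work.)
    if step % update_interval == 0:
        flat_param -= lr * deferred
        deferred.zero_()

# Toy usage with a random parameter and random "gradients".
p = torch.nn.Parameter(torch.randn(1024))
deferred = torch.zeros(p.numel())
for step in range(1, 9):
    toy_selective_update(p, torch.randn(1024), deferred, step)
```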

## Notes

- To change the model or the number of epochs, edit `MODEL_NAME` and `EPOCHS` in `finetune_llama.sh`; the effective batch size is set by `train_batch_size` in `zf_config.json`.
- All DeepSpeed and ZenFlow options are controlled via `zf_config.json`.

## Citation

To cite ZenFlow, please cite our [arXiv report](https://arxiv.org/abs/2505.12242):

```bib
@misc{lan2025zenflowenablingstallfreeoffloading,
      title={ZenFlow: Enabling Stall-Free Offloading Training via Asynchronous Updates},
      author={Tingfeng Lan and Yusen Wu and Bin Ma and Zhaoyuan Su and Rui Yang and Tekin Bicer and Dong Li and Yue Cheng},
      year={2025},
      eprint={2505.12242},
      archivePrefix={arXiv},
      primaryClass={cs.DC},
      url={https://arxiv.org/abs/2505.12242},
}
```
finetune_llama.py (112 additions, 0 deletions)
import torch
import time
import deepspeed
import argparse
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    default_data_collator
)
import random
import numpy as np
from deepspeed import comm as dist

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def preprocess_alpaca(example, tokenizer, max_length=512):
    prompt = f"### Instruction:\n{example['instruction']}\n\n"
    if example.get("input", ""):
        prompt += f"### Input:\n{example['input']}\n\n"
    prompt += f"### Response:\n{example['output']}"
    tokenized = tokenizer(prompt, truncation=True, max_length=max_length, padding="max_length")
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized

def main(args):
    set_seed(args.seed)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.bfloat16)

    # Load the Alpaca 52K dataset
    dataset = load_dataset("tatsu-lab/alpaca")

    tokenized_dataset = dataset["train"].map(lambda x: preprocess_alpaca(x, tokenizer), batched=False)

    # DeepSpeed builds the training dataloader from `training_data` and parses the
    # config file passed via --deepspeed_config, which also controls the batch size.
    model_engine, optimizer, train_dataloader, lr_scheduler = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=model.parameters(),
        training_data=tokenized_dataset,
        collate_fn=default_data_collator
    )

    model_engine.train()
    global_step = 0

    for epoch in range(args.num_train_epochs):
        if dist.get_rank() == 0:
            print(f"Starting epoch {epoch + 1}/{args.num_train_epochs}")

        for step, batch in enumerate(train_dataloader):
            step_start_time = time.time()
            batch = {k: v.to(model_engine.device) for k, v in batch.items()}
            outputs = model_engine(**batch)
            loss = outputs.loss

            model_engine.backward(loss)
            model_engine.step()

            step_time = time.time() - step_start_time
            global_step += 1

            if dist.get_rank() == 0:  # Log loss and step time from rank 0
                print(f"Step {global_step}, Loss: {loss.item():.4f}, Time: {step_time*1000:.0f}ms")

    # Save with DeepSpeed's save_checkpoint; every rank must call it so that each
    # rank writes its own ZeRO partition of the optimizer state.
    model_engine.save_checkpoint(args.output_dir)
    if dist.get_rank() == 0:
        tokenizer.save_pretrained(args.output_dir)
        print("Training complete!")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument('--local_rank',
                        type=int,
                        default=-1,
                        help='local rank passed from distributed launcher')
    parser.add_argument("--lr", type=float, required=True)
    parser.add_argument("--batch_size", type=int, required=True)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--warmup", type=float, default=0.01)
    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--seed", type=int, default=42)
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()

    main(args)
finetune_llama.sh (46 additions, 0 deletions)
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
GPUS_PER_NODE=2
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE * $NNODES))

# Model parameters
MODEL_NAME="meta-llama/Llama-2-7b-hf"
OUTPUT_DIR="./alpaca_output"
EPOCHS=3
SEED=42

# ZenFlow config file path
DS_CONFIG_JSON="./zf_config.json"

# Note: LR, batch_size, weight_decay are defined in the config file
# These parameters are kept for fallback only
LR=2e-5
BATCH_SIZE=32
WARMUP=0.03
WEIGHT_DECAY=0.01

# Create output directory if it doesn't exist
mkdir -p $OUTPUT_DIR

# DeepSpeed command
if [ -f "$DS_CONFIG_JSON" ]; then
    echo "[INFO] Using DeepSpeed config file: $DS_CONFIG_JSON"
    CMD="deepspeed --num_gpus=$GPUS_PER_NODE finetune_llama.py \
        --deepspeed_config=$DS_CONFIG_JSON \
        --model_name $MODEL_NAME \
        --num_train_epochs $EPOCHS \
        --lr $LR \
        --batch_size $BATCH_SIZE \
        --weight_decay $WEIGHT_DECAY \
        --output_dir $OUTPUT_DIR \
        --seed $SEED"
else
    echo "[ERROR] DeepSpeed config file not found: $DS_CONFIG_JSON"
    exit 1
fi

echo "[INFO] Running DeepSpeed training with ZenFlow:"
echo $CMD
eval $CMD
requirements.txt (5 additions, 0 deletions)
torch>=2.5.1
deepspeed>=0.16.0
datasets>=2.14.1
transformers>=4.37.2
numpy>=1.21.0
zf_config.json (30 additions, 0 deletions)
{
  "train_batch_size": 32,
  "bf16": { "enabled": true },
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "zenflow": {
      "topk_ratio": 0.1,
      "update_interval": 4,
      "full_warm_up_rounds": 0,
      "overlap_step": true
    }
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 2e-5,
      "betas": [0.9, 0.999],
      "eps": 1e-8,
      "weight_decay": 0.01
    }
  },
  "gradient_accumulation_steps": 1,
  "gradient_clipping": 1.0,
  "zero_allow_untested_optimizer": true
}
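
As a quick sanity check before launching (illustrative only, not part of this commit), the config can be loaded and the ZenFlow block printed:

```python
import json

# Load the DeepSpeed config and show the ZenFlow settings it will apply.
with open("zf_config.json") as f:
    cfg = json.load(f)

print(cfg["zero_optimization"]["zenflow"])
```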
