diff --git a/training/DeepSpeed-ZenFlow/benchmark/README.md b/training/DeepSpeed-ZenFlow/benchmark/README.md new file mode 100644 index 000000000..1b2385104 --- /dev/null +++ b/training/DeepSpeed-ZenFlow/benchmark/README.md @@ -0,0 +1,74 @@ +# ZenFlow Benchmark Example + + +Please install DeepSpeed via pip install deepspeed if you haven't already done so. + +```bash +pip install -r requirements.txt +``` + + +The script `zf_benchmark.py ` demonstrates how to offload the state of a model. Here is the example usage. + +```python +$ deepspeed --num_gpus=4 zf_benchmark.py --hidden_dim 4096 --nlayers 4 --iteration 5 --pin_memory_opts 1 --topk_ratios 0.1 --update_intervals 2 --overlap_steps +... +time (ms) | selective_optimizer_update: 19.20 | selective_optimizer_process: 28.80 | selective_optimizer_sync: 0.05 +time (ms) | fwd_microstep: 54.76 | bwd_microstep: 122.95 | bwd_inner_microstep: 12.22 | bwd_allreduce_microstep: 103.64 | step_microstep: 0.34 +Step 0 time: 178.66ms +time (ms) | optimizer_allgather: 26.19 | optimizer_gradients: 26.06 | optimizer_step: 128.20 +time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 0.57 | selective_optimizer_step: 1.48 | selective_optimizer_sync: 0.00 +time (ms) | fwd_microstep: 0.38 | bwd_microstep: 57.88 | bwd_inner_microstep: 1.06 | bwd_allreduce_microstep: 56.50 | step_microstep: 183.27 +time (ms) | fwd: 55.15 | bwd: 180.82 | bwd_inner: 13.28 | bwd_allreduce: 160.15 | step: 183.61 +Step 1 time: 242.16ms +time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 1.58 | selective_optimizer_step: 0.00 | selective_optimizer_sync: 0.00 +time (ms) | fwd_microstep: 0.30 | bwd_microstep: 16.73 | bwd_inner_microstep: 1.39 | bwd_allreduce_microstep: 14.96 | step_microstep: 0.20 +Step 2 time: 17.60ms +time (ms) | optimizer_allgather: 0.65 | optimizer_gradients: 16.95 | optimizer_step: 108.45 +time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 0.56 | selective_optimizer_step: 1.42 | selective_optimizer_sync: 0.00 +time (ms) | fwd_microstep: 0.29 | bwd_microstep: 36.65 | bwd_inner_microstep: 0.95 | bwd_allreduce_microstep: 35.51 | step_microstep: 128.57 +time (ms) | fwd: 0.59 | bwd: 53.39 | bwd_inner: 2.33 | bwd_allreduce: 50.48 | step: 128.77 +Step 3 time: 166.10ms +time (ms) | selective_optimizer_update: 0.00 | selective_optimizer_process: 1.57 | selective_optimizer_step: 0.00 | selective_optimizer_sync: 0.00 +time (ms) | fwd_microstep: 0.31 | bwd_microstep: 15.47 | bwd_inner_microstep: 1.33 | bwd_allreduce_microstep: 13.97 | step_microstep: 0.23 +... +[Summary] pin_memory=False topk_ratio=0.1 update_interval=2 overlap_step=False avg_accumulation_step=16.77ms avg_update_step=171.38ms +``` + +`run_benchmark.sh` shows how to run the script with different configurations. The script outputs the time for offloading and loading the states. + +```python +$ ./run_benchmark.sh +... ++---------+--------------+-------------------+----------------+--------------+-----------------+----------------+----------------+-------------------------------------+ +| trial | topk_ratio | update_interval | overlap_step | pin_memory | avg_step (ms) | avg_bwd (ms) | avg_fwd (ms) | avg_selective_optimizer_step (ms) | +|---------+--------------+-------------------+----------------+--------------+-----------------+----------------+----------------+-------------------------------------| +| 1 | 0.1 | 2 | False | False | 24.0153 | 12.8377 | 1.91733 | 0.247 | +| 1 | 0.1 | 2 | False | False | 22.8293 | 12.5187 | 1.73767 | 0.258333 | +| 1 | 0.1 | 2 | False | True | 21.6523 | 10.2863 | 1.97767 | 0.250333 | +| 1 | 0.1 | 4 | False | False | 14.2108 | 10.9072 | 1.2436 | 0.1484 | +| 1 | 0.1 | 4 | False | False | 13.6408 | 10.8386 | 1.2208 | 0.1456 | +| 1 | 0.1 | 4 | False | True | 12.863 | 9.0592 | 1.2148 | 0.1464 |... +``` + + +**Notes:** Each row in the table represents the average performance metrics for a specific configuration of ZenFlow’s offloading setup, defined by: + +- **`topk_ratio`**: The fraction of parameters selected for offloading during each update. +- **`update_interval`**: How often (in steps) the offloading state is updated. +- **`overlap_step`**: Whether overlapping offloading with computation is enabled. +- **`pin_memory`**: Whether pinned host memory is used to speed up data transfer between CPU and GPU. + +The performance metrics include: + +- **`avg_step (ms)`**: Total time per training step — the primary measure of end-to-end training performance. +- **`avg_bwd (ms)`**: Time spent in the backward pass, including gradient computation and allreduce. +- **`avg_fwd (ms)`**: Time spent in the forward pass. +- **`avg_selective_optimizer_step (ms)`**: Time spent in the selective optimizer step — indicates overhead introduced by ZenFlow’s offloading logic. + +**Tips for Analysis:** + +- Lower **`avg_step`** means faster training. +- Comparing configurations helps identify performance trade-offs (e.g., `pin_memory=True` often reduces transfer latency). +- A higher **`update_interval`** typically reduces offloading frequency and overhead. +- Enabling **`overlap_step=True`** can further hide offloading latency behind computation when the model update phase is longer. \ No newline at end of file diff --git a/training/DeepSpeed-ZenFlow/benchmark/output_table.py b/training/DeepSpeed-ZenFlow/benchmark/output_table.py new file mode 100644 index 000000000..20e932f08 --- /dev/null +++ b/training/DeepSpeed-ZenFlow/benchmark/output_table.py @@ -0,0 +1,79 @@ +import re +from collections import defaultdict +import pandas as pd +from tabulate import tabulate + +def parse_log_file(log_file_path): + with open(log_file_path, 'r') as f: + lines = f.readlines() + + # Regex patterns + trial_header_re = re.compile( + r"\[Trial (\d+)] pin_memory=(\d), topk=([\d.]+), update=(\d+), overlap_step=(\d+) \(MASTER_PORT=\d+\)" + ) + time_metrics_re = re.compile(r"\|\s*([^:|]+):\s*([\d.]+)") + + trials = [] + current_config = None + current_step_metrics = [] + + def finalize_trial(): + if current_config and current_step_metrics: + # Get all unique keys + all_keys = set() + for step in current_step_metrics: + all_keys.update(step.keys()) + # Aggregate and average + agg = {k: 0.0 for k in all_keys} + for step in current_step_metrics: + for k in all_keys: + agg[k] += step.get(k, 0.0) + avg = {f"avg_{k}": agg[k] / len(current_step_metrics) for k in all_keys} + trials.append({**current_config, **avg, "num_steps": len(current_step_metrics)}) + + for line in lines: + header_match = trial_header_re.search(line) + if header_match: + finalize_trial() + trial_id, pin_memory, topk, update, overlap = header_match.groups() + current_config = { + "trial": int(trial_id), + "pin_memory": bool(int(pin_memory)), + "topk_ratio": float(topk), + "update_interval": int(update), + "overlap_step": bool(int(overlap)) + } + current_step_metrics = [] + continue + + if "[Rank 0]" in line and "time (ms)" in line: + metrics = {k.strip(): float(v) for k, v in time_metrics_re.findall(line)} + current_step_metrics.append(metrics) + + finalize_trial() + return pd.DataFrame(trials) + +if __name__ == "__main__": + + log_file = "zf_benchmark.log" + df = parse_log_file(log_file) + df = df.sort_values(by=["topk_ratio", "overlap_step", "update_interval", "pin_memory"]) + cols_to_display = [ + "trial", "topk_ratio", "update_interval", "overlap_step", "pin_memory", + "avg_step", "avg_bwd", "avg_fwd", "avg_selective_optimizer_step" + ] + + headers_with_units = { + "trial": "trial", + "pin_memory": "pin_memory", + "update_interval": "update_interval", + "overlap_step": "overlap_step", + "topk_ratio": "topk_ratio", + "avg_step": "avg_step (ms)", + "avg_bwd": "avg_bwd (ms)", + "avg_fwd": "avg_fwd (ms)", + "avg_selective_optimizer_step": "avg_selective_optimizer_step (ms)" + + } + headers = [headers_with_units[col] for col in cols_to_display] + print(tabulate(df[cols_to_display], headers=headers, tablefmt="psql", showindex=False)) diff --git a/training/DeepSpeed-ZenFlow/benchmark/requirements.txt b/training/DeepSpeed-ZenFlow/benchmark/requirements.txt new file mode 100644 index 000000000..3b55f553b --- /dev/null +++ b/training/DeepSpeed-ZenFlow/benchmark/requirements.txt @@ -0,0 +1,7 @@ +torch>=2.5.1 +deepspeed>=0.16.0 +datasets>=2.14.1 +transformers>=4.37.2 +numpy>=1.21.0 +tabulate +pandas \ No newline at end of file diff --git a/training/DeepSpeed-ZenFlow/benchmark/run_benchmark.sh b/training/DeepSpeed-ZenFlow/benchmark/run_benchmark.sh new file mode 100644 index 000000000..e79ec9031 --- /dev/null +++ b/training/DeepSpeed-ZenFlow/benchmark/run_benchmark.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +NGPUS=2 +HIDDEN_SIZE=4096 +NUM_LAYERS=4 +TRIALS=1 + +PIN_MEMORY_OPTS=(0 1) +TOPK_RATIOS=(0.1 0.2) +UPDATE_INTERVALS=(2 4) +OVERLAP_STEPS=(1 0) + +for pin_memory in "${PIN_MEMORY_OPTS[@]}"; do + for topk in "${TOPK_RATIOS[@]}"; do + for update in "${UPDATE_INTERVALS[@]}"; do + for overlap in "${OVERLAP_STEPS[@]}"; do + for ((trial=0; trial<$TRIALS; trial++)); do + # Generate a random port between 20000 and 65000 + MASTER_PORT=$((20000 + RANDOM % 45000)) + echo "[Trial $((trial+1))] pin_memory=$pin_memory, topk=$topk, update=$update, overlap_step=$overlap (MASTER_PORT=$MASTER_PORT)" | tee -a zf_benchmark.log + deepspeed --master_port $MASTER_PORT \ + --num_gpus=$NGPUS \ + zf_benchmark.py \ + --hidden_dim $HIDDEN_SIZE \ + --nlayers $NUM_LAYERS \ + --iteration 5 \ + --pin_memory_opts $pin_memory \ + --topk_ratios $topk \ + --update_intervals $update \ + --overlap_steps $overlap | tee -a zf_benchmark.log + done + done + done + done +done +python output_table.py diff --git a/training/DeepSpeed-ZenFlow/benchmark/zf_benchmark.py b/training/DeepSpeed-ZenFlow/benchmark/zf_benchmark.py new file mode 100644 index 000000000..f09b82a9b --- /dev/null +++ b/training/DeepSpeed-ZenFlow/benchmark/zf_benchmark.py @@ -0,0 +1,150 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import argparse +import torch +import deepspeed.comm as dist +import time + +import deepspeed + +class SimpleModel(torch.nn.Module): + + def __init__(self, hidden_dim, empty_grad=False, nlayers=1): + super(SimpleModel, self).__init__() + self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(nlayers)]) + if empty_grad: + self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + + def forward(self, x, y): + for l in self.linears: + x = l(x) + return self.cross_entropy_loss(x, y) + + +def random_dataset(total_samples, hidden_dim, device, dtype): + train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype) + train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim) + train_dataset = torch.utils.data.TensorDataset(train_data, train_label) + return train_dataset + + +def random_dataloader(model, total_samples, hidden_dim, device, dtype): + batch_size = model.train_micro_batch_size_per_gpu() + train_dataset = random_dataset(total_samples, hidden_dim, device, dtype=dtype) + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size) + return train_loader + + +def run_model(model, config_dict, hidden_dim, dtype, pin_memory, topk_ratio, update_interval, overlap_step, iteration): + + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + + + data_loader = random_dataloader(model=model, + total_samples=iteration, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) + + time_step_list = [] + accumulation_step_time_list = [] + update_step_time_list = [] + + dist.barrier() + for i, batch in enumerate(data_loader): + step_start_time = time.time() + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + step_end_time = time.time() + step_time = step_end_time - step_start_time + if dist.get_rank() == 0: + print(f"Step {i} time: {step_time*1000:.2f}ms") + if i >= update_interval: + time_step_list.append(step_time) + if (i + 1) % update_interval == 0: + update_step_time_list.append(step_time) + else: + accumulation_step_time_list.append(step_time) + + if dist.get_rank() == 0: + with open("zenflow_report.log", "a") as f: + msg = f"{1 if pin_memory else 0}," \ + f"{topk_ratio}," \ + f"{update_interval}," \ + f"{overlap_step}," \ + f"{sum(accumulation_step_time_list) / len(accumulation_step_time_list):.2f}," \ + f"{sum(update_step_time_list) / len(update_step_time_list):.2f}" + f.write(f"{msg}\n") + print(f"[Summary] pin_memory={pin_memory} topk_ratio={topk_ratio} update_interval={update_interval} overlap_step={overlap_step} avg_accumulation_step={sum(accumulation_step_time_list) * 1000 / len(accumulation_step_time_list):.2f}ms avg_update_step={sum(update_step_time_list) * 1000 / len(update_step_time_list):.2f}ms") + + model.destroy() + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--nlayers", type=int, default=1) + parser.add_argument("--hidden_dim", type=int, default=1024) + parser.add_argument("--dtype", choices=['torch.bfloat16', 'torch.float16', 'torch.float32'], default='torch.bfloat16') + parser.add_argument("--iteration", type=int, default=5) + parser.add_argument("--local_rank", type=int, default=-1) + + parser.add_argument("--pin_memory_opts", type=int, required=True) + parser.add_argument("--topk_ratios", type=float, required=True) + parser.add_argument("--update_intervals", type=int, required=True) + parser.add_argument("--overlap_steps", type=int, required=True) + + # Optional: explicitly receive master_port (though deepspeed handles it via env) + parser.add_argument("--master_port", type=int, default=None) + + args = parser.parse_args() + dtype = eval(args.dtype) + + + pin_memory = bool(args.pin_memory_opts) + topk_ratio = args.topk_ratios + update_interval = args.update_intervals + overlap_step = bool(args.overlap_steps) + total_iteration = args.iteration * update_interval + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "cpu", + "pin_memory": pin_memory + }, + "zenflow": { + "topk_ratio": topk_ratio, + "update_interval": update_interval, + "full_warm_up_rounds": 0, + "overlap_step": overlap_step + }, + }, + "wall_clock_breakdown": True, + "zero_allow_untested_optimizer": True + } + + if dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + model = SimpleModel(args.hidden_dim, nlayers=args.nlayers) + run_model(model, config_dict, args.hidden_dim, dtype, + pin_memory, topk_ratio, update_interval, overlap_step, + total_iteration) + + +if __name__ == "__main__": + main() diff --git a/training/DeepSpeed-ZenFlow/finetuning/README.md b/training/DeepSpeed-ZenFlow/finetuning/README.md new file mode 100644 index 000000000..6594bee7c --- /dev/null +++ b/training/DeepSpeed-ZenFlow/finetuning/README.md @@ -0,0 +1,87 @@ + +# ZenFlow Llama-2 Fine-Tuning Example + +This project demonstrates how to fine-tune a [Llama-2](https://huggingface.co/meta-llama) model using [DeepSpeed](https://www.deepspeed.ai/) with **ZenFlow**, a stall-free offloading engine for large-scale model training. + +## Quick Start + +1. **Install dependencies** + +```bash +pip install -r requirements.txt +``` + +2. **Configure training** + +Edit `zf_config.json` to enable ZenFlow: + +```json +"zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + }, + "zenflow": { + "topk_ratio": 0.1, + "update_interval": 4, + "full_warm_up_rounds": 0, + "overlap_step": true + } +} +``` + +3. **Run fine-tuning** + +```bash +bash finetune_llama.sh +``` + +This runs LLaMA-2 fine-tuning using DeepSpeed + ZenFlow, saving checkpoints to `./alpaca_output`. + +## Example Output + +Below is a sample log showing step time and loss values. You can see significant speedup after the first full step: + +``` +ZenFlowCPUAdam initialized with overlap step. +Step 5, Loss: 1.2599, Time: 719.58ms +Step 6, Loss: 0.9847, Time: 702.81ms <-- gradient accumulation with overlapped update +Step 7, Loss: 0.6220, Time: 705.50ms +Step 8, Loss: 0.5173, Time: 1912.92ms <-- full optimizer step of remaining part and update parameters +Step 9, Loss: 0.4557, Time: 890.60ms +Step 10, Loss: 0.3882, Time: 740.11ms +Step 11, Loss: 0.3627, Time: 731.95ms +Step 12, Loss: 0.3341, Time: 2221.18ms +Step 13, Loss: 0.2453, Time: 1061.80ms +``` + +## Key Insight +Steps like 5,6 and 7 are accumulation steps where ZenFlow overlaps part of the optimizer step in the background. These steps remain fast (~700ms). + +Steps 8 performs the remaining part of optimizer step and updates parameters to the GPU (2–2.2s). + +Without ZenFlow, a full update would take nearly 4 seconds, and ZenFlow distributes half of this cost across earlier accumulation steps via asynchronous overlap. + +This demonstrates how ZenFlow hides much of the CPU offload cost, enabling near stall-free training. Crucially, ZenFlow not only overlaps the CPU optimizer step but also maintains training progress on the GPU by immediately updating the most important gradients. + +## Notes + +- To change model, batch size, or epochs, modify `finetune_llama.sh`. +- All DeepSpeed and ZenFlow options are controlled via `zf_config.json`. + +## Citation + +To cite DeepSpeed Chat, please cite our [arxiv report](https://arxiv.org/abs/2505.12242): + +```bib +@misc{lan2025zenflowenablingstallfreeoffloading, + title={ZenFlow: Enabling Stall-Free Offloading Training via Asynchronous Updates}, + author={Tingfeng Lan and Yusen Wu and Bin Ma and Zhaoyuan Su and Rui Yang and Tekin Bicer and Masahiro Tanaka and Olatunji Ruwase and Dong Li and Yue Cheng}, + year={2025}, + eprint={2505.12242}, + archivePrefix={arXiv}, + primaryClass={cs.DC}, + url={https://arxiv.org/abs/2505.12242}, +} +``` diff --git a/training/DeepSpeed-ZenFlow/finetuning/finetune_llama.py b/training/DeepSpeed-ZenFlow/finetuning/finetune_llama.py new file mode 100644 index 000000000..6978008ab --- /dev/null +++ b/training/DeepSpeed-ZenFlow/finetuning/finetune_llama.py @@ -0,0 +1,112 @@ +import torch +import time +import deepspeed +import argparse +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + default_data_collator +) +import random +import numpy as np +from deepspeed import comm as dist + +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + +def preprocess_alpaca(example, tokenizer, max_length=512): + prompt = f"### Instruction:\n{example['instruction']}\n\n" + if example.get("input", ""): + prompt += f"### Input:\n{example['input']}\n\n" + prompt += f"### Response:\n{example['output']}" + tokenized = tokenizer(prompt, truncation=True, max_length=max_length, padding="max_length") + tokenized["labels"] = tokenized["input_ids"].copy() + return tokenized + +def main(args): + set_seed(args.seed) + + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.bfloat16) + + # Load Alpaca 52K dataset + dataset = load_dataset("tatsu-lab/alpaca") + + tokenized_dataset = dataset["train"].map(lambda x: preprocess_alpaca(x, tokenizer), batched=False) + + # Create DataLoader - let DeepSpeed handle the actual batching + train_dataloader = DataLoader( + tokenized_dataset, + batch_size=1, # This will be overridden by DeepSpeed config + collate_fn=default_data_collator, + shuffle=True + ) + + # DeepSpeed will automatically parse the config file passed via --deepspeed argument + model_engine, optimizer, train_dataloader, lr_scheduler = deepspeed.initialize( + args=args, + model=model, + model_parameters=model.parameters(), + training_data=tokenized_dataset, + collate_fn=default_data_collator + ) + + model_engine.train() + global_step = 0 + + for epoch in range(args.num_train_epochs): + if dist.get_rank() == 0: + print(f"Starting epoch {epoch + 1}/{args.num_train_epochs}") + + for step, batch in enumerate(train_dataloader): + step_start_time = time.time() + batch = {k: v.to(model_engine.device) for k, v in batch.items()} + outputs = model_engine(**batch) + loss = outputs.loss + + model_engine.backward(loss) + model_engine.step() + + step_time = time.time() - step_start_time + global_step += 1 + + if dist.get_rank() == 0: # Print every 10 steps + print(f"Step {global_step}, Loss: {loss.item():.4f}, Time: {step_time*1000:.0f}ms") + + # Save model using DeepSpeed's save_checkpoint method + if dist.get_rank() == 0: + model_engine.save_checkpoint(args.output_dir) + tokenizer.save_pretrained(args.output_dir) + print("Training complete!") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_name", type=str, required=True) + parser.add_argument('--local_rank', + type=int, + default=-1, + help='local rank passed from distributed launcher') + parser.add_argument("--lr", type=float, required=True) + parser.add_argument("--batch_size", type=int, required=True) + parser.add_argument("--weight_decay", type=float, default=0.01) + parser.add_argument("--warmup", type=float, default=0.01) + parser.add_argument("--num_train_epochs", type=int, default=3) + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument("--seed", type=int, default=42) + parser = deepspeed.add_config_arguments(parser) + args = parser.parse_args() + + main(args) \ No newline at end of file diff --git a/training/DeepSpeed-ZenFlow/finetuning/finetune_llama.sh b/training/DeepSpeed-ZenFlow/finetuning/finetune_llama.sh new file mode 100644 index 000000000..abd753a74 --- /dev/null +++ b/training/DeepSpeed-ZenFlow/finetuning/finetune_llama.sh @@ -0,0 +1,46 @@ +#!/bin/bash +export CUDA_DEVICE_MAX_CONNECTIONS=1 +GPUS_PER_NODE=2 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE * $NNODES)) + +# Model parameters +MODEL_NAME="meta-llama/Llama-2-7b-hf" +OUTPUT_DIR="./alpaca_output" +EPOCHS=3 +SEED=42 + +# ZenFlow config file path +DS_CONFIG_JSON="./zf_config.json" + +# Note: LR, batch_size, weight_decay are defined in the config file +# These parameters are kept for fallback only +LR=2e-5 +BATCH_SIZE=8 +WARMUP=0.03 +WEIGHT_DECAY=0.01 + +# Create output directory if it doesn't exist +mkdir -p $OUTPUT_DIR + +# DeepSpeed command +if [ -f "$DS_CONFIG_JSON" ]; then + echo "[INFO] Using DeepSpeed config file: $DS_CONFIG_JSON" + CMD="deepspeed --num_gpus=$GPUS_PER_NODE finetune_llama.py \ + --deepspeed_config=$DS_CONFIG_JSON \ + --model_name $MODEL_NAME \ + --num_train_epochs $EPOCHS \ + --lr $LR \ + --batch_size $BATCH_SIZE \ + --weight_decay $WEIGHT_DECAY \ + --output_dir $OUTPUT_DIR \ + --seed $SEED" +else + echo "[ERROR] DeepSpeed config file not found: $DS_CONFIG_JSON" + exit 1 +fi + +echo "[INFO] Running DeepSpeed training with ZenFlow:" +echo $CMD +eval $CMD \ No newline at end of file diff --git a/training/DeepSpeed-ZenFlow/finetuning/requirements.txt b/training/DeepSpeed-ZenFlow/finetuning/requirements.txt new file mode 100644 index 000000000..21c220e92 --- /dev/null +++ b/training/DeepSpeed-ZenFlow/finetuning/requirements.txt @@ -0,0 +1,5 @@ +torch>=2.5.1 +deepspeed>=0.16.0 +datasets>=2.14.1 +transformers>=4.37.2 +numpy>=1.21.0 diff --git a/training/DeepSpeed-ZenFlow/finetuning/zf_config.json b/training/DeepSpeed-ZenFlow/finetuning/zf_config.json new file mode 100644 index 000000000..bf492923c --- /dev/null +++ b/training/DeepSpeed-ZenFlow/finetuning/zf_config.json @@ -0,0 +1,30 @@ +{ + "train_batch_size": 8, + "bf16": { "enabled": true }, + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + }, + "zenflow": { + "topk_ratio": 0.1, + "update_interval": 4, + "full_warm_up_rounds": 0, + "overlap_step": true + } + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": 2e-5, + "betas": [0.9, 0.999], + "eps": 1e-8, + "weight_decay": 0.01 + } + }, + "gradient_accumulation_steps": 1, + "gradient_clipping": 1.0, + "zero_allow_untested_optimizer": true +} +