training/DeepSpeed-SuperOffload/README.md

# SuperOffload Fine-Tuning Examples

This directory shows how to fine‑tune popular large language models using [DeepSpeed](https://www.deepspeed.ai/) ZeRO Stage 3 with **SuperOffload**. SuperOffload is an optimized CPU offloading engine for full‑parameter training on emerging “Superchips” (NVIDIA GH200 / GB200, AMD MI300A) that provide very high CPU↔GPU bandwidth. It enables:

* 1× GH200: GPT-OSS-20B, Qwen3-14B, Phi-4
* 2× GH200: Seed-OSS-36B, Qwen3-30B-A3B
* 4× GH200: Llama-70B

With typical sequence lengths and batch sizes, SuperOffload can deliver up to ~500 TFLOPS on GH200, roughly 50% higher throughput than ZeRO-Offload.

## Quick Start

### 1. Install dependencies

```bash
pip install -r requirements.txt
```

### 2. No custom model code required

All examples use Hugging Face Transformers and DeepSpeed ZeRO Stage 3; no custom modeling code is required.

### 3. Enable SuperOffload (one line)

Add the `super_offload` flag to the `offload_optimizer` block in the ZeRO Stage 3 DeepSpeed config:

```jsonc
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true,
"ratio": 0.90,
"super_offload": true,
"cpuadam_cores_perc": 0.90
}
}
```

To fall back to ZeRO-Offload, remove `"super_offload": true` (and optionally `cpuadam_cores_perc`).
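
If you generate configs programmatically, the same fallback can be scripted. A minimal sketch, assuming `jq` is installed and that `ds_config.json` / `ds_config_zerooffload.json` are illustrative file names:

```bash
# Strip the SuperOffload-specific keys to produce a ZeRO-Offload fallback config.
jq 'del(.zero_optimization.offload_optimizer.super_offload,
        .zero_optimization.offload_optimizer.cpuadam_cores_perc)' \
   ds_config.json > ds_config_zerooffload.json
```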

### 4. Run a fine-tuning script

Fine-tune GPT-OSS-20B (1× GH200):

```bash
bash finetune_gpt-oss-20b_1gpu.sh superoffload
```

Fine-tune Qwen3-14B (1× GH200):

```bash
bash finetune_qwen3-14b_1gpu.sh superoffload
```

Fine-tune Phi-4 (1× GH200):

```bash
bash finetune_phi-4_1gpu.sh superoffload
```

Fine-tune Llama 8B (1× GH200):

```bash
bash finetune_llama-8b_1gpu.sh superoffload
```

Fine-tune Seed-OSS-36B (2× GH200):

```bash
bash finetune_seed-oss-36b_2gpu.sh superoffload
```

Fine-tune Llama 70B (4× GH200):

```bash
bash finetune_llama-70b_4gpu.sh superoffload
```

Switch to ZeRO-Offload by replacing `superoffload` with `zerooffload` in the first argument.

Each script optionally accepts a second argument for batch size (default 4):

```bash
bash finetune_qwen3-14b_1gpu.sh superoffload 8
```

Logs, generated DeepSpeed configs, and outputs are written alongside the script (e.g. `qwen3-14b_superoffload_output/`).
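
For example, after `bash finetune_gpt-oss-20b_1gpu.sh superoffload` you should find artifacts named after the model and mode (paths taken from the script below):

```bash
ls gpt-oss-20b_superoffload_output/          # training outputs
cat gpt-oss-20b_superoffload_config.json     # DeepSpeed config generated by the script
```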


> If a script is missing for a larger model, copy an existing one, change `MODEL_NAME`, and update output naming.
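
A hedged sketch of that adaptation, using the GPT-OSS-20B script as the starting point (`org/my-model` is a hypothetical model name):

```bash
cp finetune_gpt-oss-20b_1gpu.sh finetune_my-model_1gpu.sh
# Update MODEL_NAME, then the output/config naming used throughout the script.
sed -i 's|openai/gpt-oss-20b|org/my-model|' finetune_my-model_1gpu.sh
sed -i 's|gpt-oss-20b|my-model|g' finetune_my-model_1gpu.sh
# Review model-specific flags (e.g. --leaf_module "GptOssExperts") before running.
```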


## Notes

* NUMA binding is required for efficient training on GH200. Each GPU is paired with a CPU, and the training process should run on the CPU directly associated with its GPU; this affinity delivers higher CPU–GPU bandwidth and throughput. DeepSpeed exposes this through a single flag: add `--bind_cores_to_rank` when launching (see the sketch after this list).
* Memory System Resource Partitioning and Monitoring (MPAM) is essential for optimal throughput. SuperOffload overlaps GPU execution with CPU-based Adam execution; MPAM reduces interference between the two, yielding smoother execution and higher throughput.
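
A minimal sketch of a NUMA-bound launch, mirroring `finetune_llama-70b_4gpu.sh` (remaining training flags omitted; see that script for the full command):

```bash
# Each rank is pinned to the cores of the CPU paired with its GPU.
deepspeed --num_gpus=4 --bind_cores_to_rank finetune_zero3.py \
    --deepspeed_config=llama-3.3-70b-instruct_superoffload_config.json \
    --model_name meta-llama/Llama-3.3-70B-Instruct
```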

## Citation

If you use SuperOffload, please cite:

```bibtex
@inproceedings{superoffload,
author = {Xinyu Lian and Masahiro Tanaka and Olatunji Ruwase and Minjia Zhang},
title = "{SuperOffload: Unleashing the Power of Large-Scale LLM Training on Superchips}",
year = {2026},
booktitle = {Proceedings of the 31st ACM International Conference on Architectural Support for Programming Languages and Operating Systems (ASPLOS '26)}
}
```
training/DeepSpeed-SuperOffload/finetune_gpt-oss-20b_1gpu.sh
#!/bin/bash
set -e

echo "================================================"
echo "GPT-OSS-20B Fine-tuning with DeepSpeed on 1 GPU"
echo "================================================"

# MODE: "superoffload" or "zerooffload"
MODE=$1
BATCH_SIZE=${2:-4}

if [ "$MODE" != "superoffload" ] && [ "$MODE" != "zerooffload" ]; then
    echo "Usage: $0 {superoffload|zerooffload} [batch_size]"
    exit 1
fi

SCRIPT_DIR=$(dirname "$0")
MODEL_NAME="openai/gpt-oss-20b"
OUTPUT_DIR="${SCRIPT_DIR}/gpt-oss-20b_${MODE}_output"
DS_CONFIG_JSON="${SCRIPT_DIR}/gpt-oss-20b_${MODE}_config.json"

mkdir -p "$OUTPUT_DIR"

# Training configuration (edit in place as needed)
ACTIVATION_CHECKPOINTING=true
SAVE_CHECKPOINT=false
MAX_LENGTH=8192
LOG_INTERVAL=1
DATASET_NAME="tatsu-lab/alpaca"
DATASET_PERCENTAGE=10.0
USE_WANDB=false
WANDB_PROJECT="gpt-oss-20b"
WANDB_RUN_NAME="gpt-oss-20b-$MODE"
DETERMINISTIC=false
BENCH_STEPS=10
WARMUP_STEPS=20

EPOCHS=1
LR=1e-5
WARMUP=0.05
WEIGHT_DECAY=0.01
SEED=42
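# NOTE: WARMUP is defined but never passed to finetune_zero3.py below;
# only WARMUP_STEPS reaches the training script.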

ACTIVATION_CHECKPOINTING_FLAG=""
if [ "$ACTIVATION_CHECKPOINTING" = "true" ]; then
ACTIVATION_CHECKPOINTING_FLAG="--activation_checkpointing"
fi

SAVE_CHECKPOINT_ARG=""
if [ "$SAVE_CHECKPOINT" = "true" ]; then
SAVE_CHECKPOINT_ARG="--save_checkpoint"
fi

WANDB_FLAG=""
if [ "$USE_WANDB" = "true" ]; then
WANDB_FLAG="--use_wandb"
fi

DETERMINISTIC_FLAG=""
if [ "$DETERMINISTIC" = "true" ]; then
DETERMINISTIC_FLAG="--deterministic"
fi

# Create DeepSpeed configuration file
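# "superoffload" layers super_offload + cpuadam_cores_perc on top of standard CPU
# optimizer offload; "zerooffload" emits the plain ZeRO-Offload baseline config.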
if [ "$MODE" = "superoffload" ]; then
cat > "$DS_CONFIG_JSON" << EOF
{
"train_batch_size": $BATCH_SIZE,
"gradient_accumulation_steps": 1,
"bf16": { "enabled": true },
"zero_optimization": {
"stage": 3,
"overlap_comm": false,
"reduce_bucket_size": 8e8,
"sub_group_size": 8e8,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true,
"ratio": 0.90,
"super_offload": true,
"cpuadam_cores_perc": 0.90
}
},
"wall_clock_breakdown": true
}
EOF

elif [ "$MODE" = "zerooffload" ]; then
cat > "$DS_CONFIG_JSON" << EOF
{
"train_batch_size": $BATCH_SIZE,
"gradient_accumulation_steps": 1,
"bf16": { "enabled": true },
"zero_optimization": {
"stage": 3,
"overlap_comm": false,
"reduce_bucket_size": 8e8,
"sub_group_size": 8e8,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
}
},
"wall_clock_breakdown": true
}
EOF
fi

# Set number of GPUs
GPUS_PER_NODE=1

CMD="deepspeed --num_gpus=$GPUS_PER_NODE finetune_zero3.py \
--deepspeed_config=$DS_CONFIG_JSON \
--model_name $MODEL_NAME \
--leaf_module GptOssExperts \
--num_train_epochs $EPOCHS \
--lr $LR \
--batch_size $BATCH_SIZE \
--weight_decay $WEIGHT_DECAY \
--output_dir $OUTPUT_DIR \
--seed $SEED \
--max_length $MAX_LENGTH \
--log_interval $LOG_INTERVAL \
--dataset_name $DATASET_NAME \
--dataset_percentage $DATASET_PERCENTAGE \
--bench_steps $BENCH_STEPS \
--warmup_steps $WARMUP_STEPS \
--attn_implementation eager \
$ACTIVATION_CHECKPOINTING_FLAG \
$SAVE_CHECKPOINT_ARG \
$WANDB_FLAG \
--wandb_project $WANDB_PROJECT \
--wandb_run_name $WANDB_RUN_NAME \
$DETERMINISTIC_FLAG"

echo "Starting training with MODE $MODE"
echo "================================================"
eval "$CMD"

echo "================================================"
echo "Training completed"
echo "================================================"
training/DeepSpeed-SuperOffload/finetune_llama-70b_4gpu.sh
#!/bin/bash
set -e

echo "================================================"
echo "Llama-3.3-70B-Instruct Fine-tuning with DeepSpeed on 4 GPU"
echo "================================================"

# MODE: "superoffload" or "zerooffload"
MODE=$1
BATCH_SIZE=${2:-4}

if [ "$MODE" != "superoffload" ] && [ "$MODE" != "zerooffload" ]; then
    echo "Usage: $0 {superoffload|zerooffload} [batch_size]"
    exit 1
fi

SCRIPT_DIR=$(dirname "$0")
MODEL_NAME="meta-llama/Llama-3.3-70B-Instruct"
OUTPUT_DIR="${SCRIPT_DIR}/llama-3.3-70b-instruct_${MODE}_output"
DS_CONFIG_JSON="${SCRIPT_DIR}/llama-3.3-70b-instruct_${MODE}_config.json"

mkdir -p "$OUTPUT_DIR"

# Training configuration (edit in place as needed)
ACTIVATION_CHECKPOINTING=true
SAVE_CHECKPOINT=false
MAX_LENGTH=4096
LOG_INTERVAL=1
DATASET_NAME="tatsu-lab/alpaca"
DATASET_PERCENTAGE=10.0
USE_WANDB=false
WANDB_PROJECT="llama-3.3-70b-instruct"
WANDB_RUN_NAME="llama-3.3-70b-instruct-$MODE"
DETERMINISTIC=false
BENCH_STEPS=10
WARMUP_STEPS=20

EPOCHS=1
LR=1e-5
WARMUP=0.05
WEIGHT_DECAY=0.01
SEED=42
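# NOTE: WARMUP is defined but never passed to finetune_zero3.py below;
# only WARMUP_STEPS reaches the training script.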

ACTIVATION_CHECKPOINTING_FLAG=""
if [ "$ACTIVATION_CHECKPOINTING" = "true" ]; then
ACTIVATION_CHECKPOINTING_FLAG="--activation_checkpointing"
fi

SAVE_CHECKPOINT_ARG=""
if [ "$SAVE_CHECKPOINT" = "true" ]; then
SAVE_CHECKPOINT_ARG="--save_checkpoint"
fi

WANDB_FLAG=""
if [ "$USE_WANDB" = "true" ]; then
WANDB_FLAG="--use_wandb"
fi

DETERMINISTIC_FLAG=""
if [ "$DETERMINISTIC" = "true" ]; then
DETERMINISTIC_FLAG="--deterministic"
fi

# Create DeepSpeed configuration file
if [ "$MODE" = "superoffload" ]; then
cat > "$DS_CONFIG_JSON" << EOF
{
"train_batch_size": $BATCH_SIZE,
"gradient_accumulation_steps": 1,
"bf16": { "enabled": true },
"zero_optimization": {
"stage": 3,
"overlap_comm": false,
"reduce_bucket_size": 4e8,
"sub_group_size": 4e8,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true,
"ratio": 0.90,
"super_offload": true,
"cpuadam_cores_perc": 0.90
}
},
"wall_clock_breakdown": true
}
EOF

elif [ "$MODE" = "zerooffload" ]; then
cat > "$DS_CONFIG_JSON" << EOF
{
"train_batch_size": $BATCH_SIZE,
"gradient_accumulation_steps": 1,
"bf16": { "enabled": true },
"zero_optimization": {
"stage": 3,
"overlap_comm": false,
"reduce_bucket_size": 4e8,
"sub_group_size": 4e8,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
}
},
"wall_clock_breakdown": true
}
EOF
fi

GPUS_PER_NODE=4

CMD="deepspeed --num_gpus=$GPUS_PER_NODE --bind_cores_to_rank finetune_zero3.py \
--deepspeed_config=$DS_CONFIG_JSON \
--model_name $MODEL_NAME \
--num_train_epochs $EPOCHS \
--lr $LR \
--batch_size $BATCH_SIZE \
--weight_decay $WEIGHT_DECAY \
--output_dir $OUTPUT_DIR \
--seed $SEED \
--max_length $MAX_LENGTH \
--log_interval $LOG_INTERVAL \
--dataset_name $DATASET_NAME \
--dataset_percentage $DATASET_PERCENTAGE \
--bench_steps $BENCH_STEPS \
--warmup_steps $WARMUP_STEPS \
$ACTIVATION_CHECKPOINTING_FLAG \
$SAVE_CHECKPOINT_ARG \
$WANDB_FLAG \
--wandb_project $WANDB_PROJECT \
--wandb_run_name $WANDB_RUN_NAME \
$DETERMINISTIC_FLAG"

echo "Starting training with MODE $MODE"
echo "================================================"
eval "$CMD"