
Commit e457deb

Bring LLaMa 3.1 405B to TorchTitan family (#481)
With the official launch of the Llama 3.1 model, we want to add its config to TorchTitan. There is, of course, more work to be done, but we want to proceed incrementally, so more PRs will follow. For now, we tried the current config (TP=8, FSDP=16) on 128 GPUs. The perf numbers are wps: 109, mfu: 29%.

Loss curve for 3000 steps with 600 warmup steps (lr = 0.8e-4):
<img width="1037" alt="image" src="https://github.com/user-attachments/assets/f57dd3fa-07d8-4ef4-8f68-8f7a08e9652e">

Loss curve for 3000 steps with 600 warmup steps (lr = 1.1e-4):
![image](https://github.com/user-attachments/assets/429b9738-94cb-4b37-90ef-049a5587ddd0)
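As a rough cross-check (an illustration, not part of the commit), the reported mfu can be approximated from the per-GPU wps using ~6 FLOPs per parameter per token for the dense matmuls. The H100 BF16 peak (~989 TFLOPS) and the interpretation of wps as tokens/sec per GPU are assumptions here; attention FLOPs account for most of the remaining gap to 29%.

```python
# Back-of-the-envelope MFU check (assumptions: wps is tokens/sec per GPU,
# H100 BF16 dense peak is ~989 TFLOPS, ~6 FLOPs per parameter per token).
params = 405e9         # Llama 3.1 405B parameter count
wps_per_gpu = 109      # from the commit message
peak_flops = 989e12    # per-GPU peak, FLOPs/sec

matmul_flops_per_sec = 6 * params * wps_per_gpu
print(f"MFU (matmuls only) ~= {matmul_flops_per_sec / peak_flops:.1%}")  # ~26.8%
# Adding attention FLOPs brings the estimate close to the reported 29%.
```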

File tree

4 files changed: +65 −3 lines changed


README.md

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ Once you have confirmed access, you can run the following command to download th
 ```bash
 # Get your HF token from https://huggingface.co/settings/tokens
 
-# llama3 tokenizer.model
+# llama3 or 3.1 tokenizer.model
 python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3-8B --tokenizer_path "original" --hf_token=...
 
 # llama2 tokenizer.model

torchtitan/datasets/download_tokenizer.py

Lines changed: 2 additions & 2 deletions
@@ -20,8 +20,8 @@ def hf_download(
 
     try:
         hf_hub_download(
-            repo_id,
-            tokenizer_path,
+            repo_id=repo_id,
+            filename=tokenizer_path,
             local_dir=local_dir,
             local_dir_use_symlinks=False,
             token=hf_token,
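For reference, a minimal standalone sketch of the same `huggingface_hub` call with the arguments passed by keyword, mirroring the README example above; the exact file path and local directory are illustrative and may differ from what the script constructs internally:

```python
from huggingface_hub import hf_hub_download

# Hypothetical standalone equivalent of the fixed call above; repo and
# file names mirror the README example rather than the script's internals.
hf_hub_download(
    repo_id="meta-llama/Meta-Llama-3-8B",
    filename="original/tokenizer.model",
    local_dir="torchtitan/datasets/tokenizer",
    local_dir_use_symlinks=False,
    token="hf_...",  # your HF access token
)
```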

torchtitan/models/llama/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -48,4 +48,13 @@
         multiple_of=4096,
         rope_theta=500000,
     ),
+    "405B": ModelArgs(
+        dim=16384,
+        n_layers=126,
+        n_heads=128,
+        n_kv_heads=8,
+        ffn_dim_multiplier=1.2,
+        multiple_of=4096,
+        rope_theta=500000,
+    ),
 }
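For context (not part of the diff), `ffn_dim_multiplier` and `multiple_of` determine the FFN hidden size; a minimal sketch of that arithmetic for the 405B flavor, assuming the standard Llama-style FeedForward sizing rule used by the model code:

```python
def ffn_hidden_size(dim: int, ffn_dim_multiplier: float, multiple_of: int) -> int:
    """Llama-style FFN sizing: 2/3 of 4*dim, scaled, rounded up to a multiple."""
    hidden = int(2 * (4 * dim) / 3)
    hidden = int(ffn_dim_multiplier * hidden)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

# 405B flavor: dim=16384, ffn_dim_multiplier=1.2, multiple_of=4096
print(ffn_hidden_size(16384, 1.2, 4096))  # 53248
```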

train_configs/llama3_405b.toml

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+# torchtitan Config.toml
+# NOTE: this toml config is a preset for 128 H100 GPUs.
+
+[job]
+dump_folder = "./outputs"
+description = "Llama 3 405B training"
+
+[profiling]
+enable_profiling = true
+save_traces_folder = "profile_trace"
+profile_freq = 100
+
+[metrics]
+log_freq = 10
+enable_tensorboard = true
+save_tb_folder = "tb"
+
+[model]
+name = "llama3"
+flavor = "405B"
+norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
+tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
+
+[optimizer]
+name = "AdamW"
+lr = 0.8e-4
+
+[training]
+batch_size = 2
+seq_len = 8192
+warmup_steps = 600  # lr scheduler warm up, normally 20% of the train steps
+max_norm = 1.0  # grad norm clipping
+steps = 3000
+data_parallel_degree = -1
+tensor_parallel_degree = 8  # 8-way TP
+enable_float8_linear = false
+compile = false
+dataset = "c4"
+
+[experimental]
+pipeline_parallel_degree = 1
+
+[checkpoint]
+enable_checkpoint = false
+folder = "checkpoint"
+interval_type = "steps"
+interval = 500
+model_weights_only = false
+export_dtype = "float32"
+async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]
+
+[activation_checkpoint]
+mode = 'full'  # ['none', 'selective', 'full']
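As a side note (an illustration, not code from this commit), `data_parallel_degree = -1` leaves the data-parallel size to be inferred from whatever remains after the other parallelisms; a minimal sketch of that arithmetic for this preset, assuming the 128-GPU world size from the NOTE at the top:

```python
# Hypothetical illustration of how the degrees in this preset combine on 128 GPUs.
world_size = 128        # 128 H100s, per the NOTE in the config
tensor_parallel = 8     # tensor_parallel_degree
pipeline_parallel = 1   # pipeline_parallel_degree

data_parallel = world_size // (tensor_parallel * pipeline_parallel)
print(data_parallel)    # 16 -> the FSDP=16 mentioned in the commit message
```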
