Skip to content

Commit a757a84

Browse files
committed
Update on "adapt CI tests to use compiled_rmsnorm"
[ghstack-poisoned]
2 parents ec51c4f + 5124363 commit a757a84

38 files changed

+864
-477
lines changed

.github/workflows/integration_test_4gpu.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,6 @@ jobs:
3939
4040
python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
4141
python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
42+
USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
4243
mkdir artifacts-to-be-uploaded
4344
python ./test_runner.py artifacts-to-be-uploaded --ngpu 4

README.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ Our guiding principles when building `torchtitan`:
1818

1919
[![Welcome to torchtitan!](assets/images/titan_play_video.png)](https://youtu.be/ee5DOEqD35I?si=_B94PbVv0V5ZnNKE "Welcome to torchtitan!")
2020

21+
### Dive into the code
22+
23+
You may want to see how the model is defined or how parallelism techniques are applied. For a guided tour, see these files first:
24+
* [train.py](https://github.com/pytorch/torchtitan/blob/main/train.py) - the main training loop and high-level setup code
25+
* [torchtitan/parallelisms/parallelize_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py) - helpers for applying Data / Tensor / Pipeline Parallelisms to the model
26+
* [torchtitan/checkpoint.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py) - utils for saving/loading distributed checkpoints
27+
* [torchtitan/models/llama/model.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py) - the Llama model definition (shared for Llama2 and Llama3 variants)
28+
2129
## Pre-Release Updates:
2230
#### (4/25/2024): `torchtitan` is now public but in a pre-release state and under development.
2331
Currently we showcase pre-training **Llama 3 and Llama 2** LLMs of various sizes from scratch. `torchtitan` is tested and verified with the PyTorch nightly version `torch-2.4.0.dev20240412`. (We recommend the latest PyTorch nightly.)
@@ -66,7 +74,7 @@ Once you have confirmed access, you can run the following command to download th
6674
```bash
6775
# Get your HF token from https://huggingface.co/settings/tokens
6876

69-
# llama3 tokenizer.model
77+
# llama3 or 3.1 tokenizer.model
7078
python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3-8B --tokenizer_path "original" --hf_token=...
7179

7280
# llama2 tokenizer.model

create_seed_checkpoint.sh

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818

1919
set -ex
2020

21-
export USE_LIBUV=1
22-
TRAINER_DIR=${1:-/home/$USER/local/torchtitan}
2321
NGPU=1
2422
LOG_RANK=0
2523
CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}

estimation.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,19 @@
99
import os
1010

1111
import torch
12-
import torch.nn.functional as F
1312
from torch._guards import active_fake_mode
1413
from torch._subclasses.fake_tensor import FakeTensorMode
15-
from torch.distributed import destroy_process_group
1614
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
17-
from torch.distributed.tensor.parallel import loss_parallel
1815
from torch.testing._internal.distributed.fake_pg import FakeStore
1916

2017
from torchtitan.config_manager import JobConfig
21-
from torchtitan.datasets import create_tokenizer
22-
from torchtitan.float8_linear import build_fp8_linear
23-
from torchtitan.logging_utils import init_logger, logger
24-
from torchtitan.lr_scheduling import get_lr_schedulers
18+
from torchtitan.datasets import build_tokenizer
19+
from torchtitan.float8_linear import Float8Handler
20+
from torchtitan.logging import init_logger, logger
2521
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
22+
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
2623
from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
27-
from train import build_optimizers
24+
from train import get_train_context
2825

2926

3027
def estimate_memory(job_config: JobConfig):
@@ -61,16 +58,18 @@ def estimate_memory(job_config: JobConfig):
6158
logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
6259
job_config.model.norm_type = "rmsnorm"
6360

64-
if job_config.training.compile:
61+
if job_config.training.compile or job_config.experimental.enable_compiled_autograd:
6562
logger.info("Compile mode is not supported yet. Switching to eager mode.")
6663
job_config.training.compile = False
64+
job_config.experimental.enable_compiled_autograd = False
6765

6866
parallel_dims = ParallelDims(
6967
dp=job_config.training.data_parallel_degree,
7068
tp=job_config.training.tensor_parallel_degree,
7169
pp=job_config.experimental.pipeline_parallel_degree,
7270
world_size=world_size,
7371
enable_loss_parallel=job_config.training.enable_loss_parallel,
72+
dp_type=job_config.training.data_parallel_type,
7473
)
7574

7675
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
@@ -93,16 +92,18 @@ def estimate_memory(job_config: JobConfig):
9392

9493
# build tokenizer
9594
tokenizer_type = model_name_to_tokenizer[model_name]
96-
tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
95+
tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
9796

98-
# loss_parallel enables dispatching to efficient loss operators
99-
loss_parallel_ctx = (
100-
loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext
97+
train_context = get_train_context(
98+
parallel_dims.loss_parallel_enabled,
99+
job_config.experimental.enable_compiled_autograd,
101100
)
102101

103102
# loss fn can be shared by pipeline-parallel or non-pp execution
104103
def loss_fn(pred, labels):
105-
return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
104+
return torch.nn.functional.cross_entropy(
105+
pred.flatten(0, 1), labels.flatten(0, 1)
106+
)
106107

107108
# build model (using meta init)
108109
model_cls = model_name_to_cls[model_name]
@@ -123,9 +124,10 @@ def loss_fn(pred, labels):
123124
with torch.device("meta"):
124125
whole_model = model_cls.from_model_args(model_config)
125126

126-
# apply fp8 linear module swap
127-
if job_config.training.fp8_linear:
128-
build_fp8_linear(whole_model, job_config)
127+
# a no-op handler if fp8 is not enabled
128+
float8_handler = Float8Handler(job_config, parallel_dims)
129+
# swap to Float8Linear based on fp8 config
130+
float8_handler.convert_to_float8_training(whole_model)
129131

130132
# apply PT-D DP/TP parallelisms and activation checkpointing
131133
model_parts = [whole_model]
@@ -143,7 +145,7 @@ def loss_fn(pred, labels):
143145

144146
# build optimizer after applying parallelisms to the model
145147
optimizers = build_optimizers(model_parts, job_config)
146-
lr_schedulers = get_lr_schedulers(optimizers.optimizers, job_config)
148+
lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
147149

148150
for model in model_parts:
149151
model.train()
@@ -170,7 +172,7 @@ def loss_fn(pred, labels):
170172
for iter_idx in range(2):
171173
input_ids, labels = batch
172174
# train step
173-
with loss_parallel_ctx():
175+
with train_context():
174176
pred = whole_model(input_ids)
175177
loss = loss_fn(pred, labels)
176178
del pred
@@ -181,9 +183,14 @@ def loss_fn(pred, labels):
181183
torch.nn.utils.clip_grad_norm_(
182184
model.parameters(), job_config.training.max_norm, foreach=True
183185
)
186+
# sync float8 amaxes and scales
187+
float8_handler.sync_float8_amax_and_scale_history(model)
184188
# optimizer step
185189
optimizers.step()
186190
lr_schedulers.step()
191+
# calculate float8 dynamic amax/scale for all parameters for FSDP2
192+
# it issues a single all-reduce for all parameters at once for better performance
193+
float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model)
187194
optimizers.zero_grad()
188195
print(f"Peak Memory at iter: {iter_idx}")
189196
fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
@@ -217,4 +224,4 @@ def loss_fn(pred, labels):
217224
try:
218225
estimate_memory(config)
219226
finally:
220-
destroy_process_group()
227+
torch.distributed.destroy_process_group()

multinode_trainer.slurm

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond"
5353
export NCCL_BUFFSIZE=2097152
5454
#export TORCH_DIST_INIT_BARRIER=1
5555
export FI_EFA_SET_CUDA_SYNC_MEMOPS=0
56-
#export USE_LIBUV=1
5756
CONFIG_FILE=${CONFIG_FILE:-"./train_configs/llama2_13b.toml"}
5857

5958
dcgmi profile --pause

run_llama_train.sh

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,39 +7,18 @@
77

88
set -ex
99

10-
# libUV is a scalable backend for TCPStore which is used in processGroup
11-
# rendezvous. This is the recommended backend for distributed training.
12-
export USE_LIBUV=1
13-
TRAINER_DIR=${TRAINER_DIR:-/home/$USER/local/torchtitan}
14-
1510
# use envs as local overrides for convenience
1611
# e.g.
1712
# LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh
18-
1913
NGPU=${NGPU:-"8"}
20-
NNODES=${NNODES:-"1"}
21-
22-
# by default log just rank 0 output,
2314
LOG_RANK=${LOG_RANK:-0}
24-
25-
2615
CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}
2716

2817
overrides=""
2918
if [ $# -ne 0 ]; then
3019
overrides="$*"
3120
fi
3221

33-
# Check if --estimate.memory=True is in the arguments
34-
if echo "$overrides" | grep -q -- "--memory_estimation.enabled"; then
35-
# Calculate WORLD_SIZE as the product of NGPU and NNODES
36-
# Export WORLD_SIZE and LOCAL_RANK
37-
export WORLD_SIZE=$((NGPU * NNODES))
38-
export LOCAL_RANK=0
39-
python estimation.py --job.config_file ${CONFIG_FILE} $overrides
40-
else
41-
# Call train.py if not in estimation mode
42-
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
43-
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
44-
train.py --job.config_file ${CONFIG_FILE} $overrides
45-
fi
22+
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
23+
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
24+
train.py --job.config_file ${CONFIG_FILE} $overrides

run_memory_estimation.sh

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#!/usr/bin/bash
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
set -ex
9+
10+
# use envs as local overrides for convenience
11+
# e.g.
12+
# NGPU=4 ./run_memory_estimation.sh
13+
NGPU=${NGPU:-"8"}
14+
NNODES=${NNODES:-"1"}
15+
CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}
16+
17+
overrides=""
18+
if [ $# -ne 0 ]; then
19+
overrides="$*"
20+
fi
21+
22+
# Calculate WORLD_SIZE as the product of NGPU and NNODES
23+
# Export WORLD_SIZE and LOCAL_RANK
24+
export WORLD_SIZE=$((NGPU * NNODES))
25+
export LOCAL_RANK=0
26+
python estimation.py --job.config_file ${CONFIG_FILE} --memory_estimation.enabled $overrides

test/datasets/test_checkpoint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import torch
88
from torchtitan.datasets.hf_datasets import build_hf_data_loader
9-
from torchtitan.datasets.tokenizer import create_tokenizer
9+
from torchtitan.datasets.tokenizer import build_tokenizer
1010

1111

1212
class TestCheckpoint:
@@ -42,7 +42,7 @@ def _build_dataloader(
4242
self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank
4343
):
4444
tokenizer_type = "tiktoken"
45-
tokenizer = create_tokenizer("tiktoken", "./test/assets/test_tiktoken.model")
45+
tokenizer = build_tokenizer("tiktoken", "./test/assets/test_tiktoken.model")
4646
return build_hf_data_loader(
4747
dataset_name=dataset_name,
4848
dataset_path=dataset_path,

test_runner.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,21 @@ def build_test_list():
4646
"""
4747
integration_tests_flavors = defaultdict(list)
4848
integration_tests_flavors["debug_model.toml"] = [
49+
OverrideDefinitions(
50+
[
51+
[
52+
"--checkpoint.enable_checkpoint",
53+
"--experimental.pipeline_parallel_degree 4",
54+
"--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
55+
"--experimental.pipeline_parallel_schedule flexible_interleaved_1f1b",
56+
"--model.norm_type rmsnorm", # fused_rmsnorm throws cuda context error with pp
57+
],
58+
],
59+
"PP looped flexible 1f1b test",
60+
"pp_looped_flexible_1f1b",
61+
requires_seed_checkpoint=True,
62+
ngpu=4,
63+
),
4964
OverrideDefinitions(
5065
[
5166
[
@@ -253,7 +268,7 @@ def build_test_list():
253268
"--experimental.pipeline_parallel_degree 4",
254269
"--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
255270
"--experimental.pipeline_parallel_schedule interleaved_1f1b",
256-
"--model.norm_type rmsnorm", # compiled_rmsnorm / fused_rmsnorm throws cuda context error with pp
271+
"--model.norm_type rmsnorm", # compiled_rmsnorm / fused_rmsnorm crashes with PP
257272
],
258273
],
259274
"PP looped 1f1b test",
@@ -281,6 +296,16 @@ def build_test_list():
281296
"fsdp2_mem_tracker",
282297
ngpu=4,
283298
),
299+
OverrideDefinitions(
300+
[
301+
[
302+
"--training.data_parallel_type ddp",
303+
]
304+
],
305+
"DDP",
306+
"ddp",
307+
ngpu=4,
308+
),
284309
]
285310
return integration_tests_flavors
286311

@@ -312,6 +337,8 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
312337

313338
for override_arg in test_flavor.override_args:
314339
cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh"
340+
if test_name == "fsdp2_mem_tracker":
341+
cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_memory_estimation.sh"
315342
cmd += " " + dump_folder_arg
316343
cmd += " " + model_flavor_arg
317344
if override_arg:

0 commit comments

Comments
 (0)