Commit 8f2e31a

Add support for seed checkpoint creation for meta-init flow

ghstack-source-id: 5179c55
Pull Request resolved: #172

1 parent: dca7657

File tree

3 files changed: +66 -2 lines changed


run_llama_train.sh

Lines changed: 1 addition & 1 deletion

@@ -21,4 +21,4 @@ CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}
 
 torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-train.py --job.config_file ${CONFIG_FILE}
+train.py --job.config_file ${CONFIG_FILE} --training.checkpoint_folder /data/users/whc/torchtrain

seed_checkpoint.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import os
+
+import torch.distributed.checkpoint as DCP
+
+from torchtrain.config_manager import JobConfig
+from torchtrain.datasets import create_tokenizer
+from torchtrain.float8_linear import build_fp8_linear
+from torchtrain.logging_utils import init_logger, logger
+from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config
+
+_is_local_logging = True
+if "SLURM_JOB_ID" in os.environ:
+    _is_local_logging = False
+
+
+def main(job_config: JobConfig):
+    init_logger()
+
+    model_name = job_config.model.name
+
+    # build tokenizer
+    tokenizer_type = model_name_to_tokenizer[model_name]
+    tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
+
+    # build model (using meta init)
+    model_cls = model_name_to_cls[model_name]
+    model_config = models_config[model_name][job_config.model.flavor]
+    model_config.vocab_size = tokenizer.n_words
+    logger.info(f"Building {model_name} {job_config.model.flavor} with {model_config}")
+    model = model_cls.from_model_args(model_config)
+
+    # apply fp8 linear module swap
+    if job_config.training.fp8_linear:
+        build_fp8_linear(model, job_config)
+
+    model.reset_parameters()
+
+    checkpoint_id = os.path.join(job_config.training.checkpoint_folder, "step-0")
+    logger.info(f"Creating seed (step-0) checkpoint in {checkpoint_id}")
+    DCP.save(
+        state_dict={
+            "model": model.state_dict(),
+        },
+        checkpoint_id=checkpoint_id,
+    )
+
+
+"""
+1. How do I serialize enough info about the model config to ensure I don't try to load an incompatible checkpoint later?
+   - Maybe skip this. Users are responsible for managing their checkpoints, and we can partially help by managing their 'dump folder'.
+
+2. Would I apply fp8 before creating the seed checkpoint or not? I think probably before.
+3. Can I skip the optimizer in the seed file? I think so. The optimizer can later create its states from the model post-sharding.
+"""
+if __name__ == "__main__":
+    config = JobConfig()
+    config.parse_args()
+    main(config)
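
Not shown in this commit is how a trainer consumes the seed checkpoint. Below is a minimal sketch of the meta-init pattern the script enables (my illustration, not code from this PR), assuming a recent PyTorch where DCP accepts a checkpoint_id; MyModel and the checkpoint path are hypothetical stand-ins.

import torch
import torch.distributed.checkpoint as dcp

# Build the model on the meta device: shapes and dtypes are recorded, but no
# real storage is allocated and no init kernels run.
with torch.device("meta"):
    model = MyModel()  # hypothetical model class

# Materialize real (uninitialized) storage on the target device, then fill it
# from the step-0 seed checkpoint written by seed_checkpoint.py. dcp.load
# writes in place into the tensors of the state_dict it is given.
model.to_empty(device="cuda")
dcp.load({"model": model.state_dict()}, checkpoint_id="/path/to/checkpoints/step-0")

In this repo, the actual loading path is CheckpointManager.load in torchtrain/checkpoint.py, patched below to handle the optimizer-less step-0 case.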

torchtrain/checkpoint.py

Lines changed: 4 additions & 1 deletion
@@ -123,6 +123,7 @@ def save(self, curr_step: int, force: bool = False) -> None:
         )
 
     def load(self, step: int = -1) -> bool:
+        logger.info(f"Trying to load a checkpoint from '{self.folder}'")
         if not self.folder:
             return False
         if not os.path.isdir(self.folder):
@@ -140,10 +141,12 @@ def load(self, step: int = -1) -> bool:
             return False
         step = max(step_counts)
 
+        # We won't have optimizer states to load if we are loading a seed checkpoint.
+        states = {"model": self.states["model"]} if step == 0 else self.states
         logger.info(f"Loading the checkpoint at step {step}")
         begin = time.monotonic()
         dcp.load(
-            self.states,
+            states,
             checkpoint_id=self.create_checkpoint_id(step),
         )
         logger.info(
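
The step == 0 branch works because DCP loads only the keys present in the state_dict it is handed. A self-contained sketch of that behavior (my illustration, not part of the commit; assumes a recent PyTorch where DCP runs single-process against a local filesystem path):

import tempfile

import torch
import torch.distributed.checkpoint as dcp

with tempfile.TemporaryDirectory() as ckpt_dir:
    # Save a seed-style checkpoint that contains only model weights.
    model = torch.nn.Linear(4, 4)
    dcp.save(state_dict={"model": model.state_dict()}, checkpoint_id=ckpt_dir)

    # Load into a fresh module, requesting only the "model" entry. Handing DCP
    # the full training states dict would also request the missing optimizer
    # entry and fail, which is exactly what the step-0 special case avoids.
    fresh = torch.nn.Linear(4, 4)
    dcp.load({"model": fresh.state_dict()}, checkpoint_id=ckpt_dir)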
