Commit 60b1173

Make freqs_cis a persistent buffer for pp init
Currently, the plan is to use a 'seed checkpoint' to initialize the pipeline-parallel model chunks after moving them from the meta device to cuda/empty. Non-persistent buffers are incompatible with this approach: they are missing from the checkpoint and thus require manual initialization. An alternative is to manually run the initializer for just the non-persistent buffers after loading a seed checkpoint, but making the buffer persistent is nearly equivalent and requires fewer code changes.

ghstack-source-id: b482284
Pull Request resolved: #201
1 parent d758406 commit 60b1173

File tree

1 file changed: +9 −3 lines

torchtrain/models/llama/model.py

Lines changed: 9 additions & 3 deletions
@@ -309,9 +309,15 @@ def __init__(self, model_args: ModelArgs):
         super().__init__()
         self.model_args = model_args
         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
-        self.register_buffer(
-            "freqs_cis", self._precompute_freqs_cis(), persistent=False
-        )
+
+        # TODO persistent should be set to false, since this buffer can be recomputed.
+        # however, we set it to true for 2 reasons.  (1) due to pytorch/pytorch#123411,
+        # compile or pipeline-tracer will not correctly handle non-persistent buffers,
+        # so we need to fix that.  (2) if we initialize pipeline-parallel models from
+        # a seed checkpoint rather than calling init_weights, we need freqs_cis to be
+        # initialized by the checkpoint, or we need to add a separate initializer for
+        # just the non-persistent buffers that is called after loading checkpoints.
+        self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
 
     def _precompute_freqs_cis(self):
         return precompute_freqs_cis(
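The core issue the commit works around can be demonstrated in isolation: a buffer registered with `persistent=False` is excluded from `state_dict()`, so loading a seed checkpoint cannot initialize it. A minimal sketch (the `Rotary` module and its dummy buffer contents are hypothetical stand-ins for the real `freqs_cis` setup):

```python
import torch
import torch.nn as nn


class Rotary(nn.Module):
    """Hypothetical minimal module illustrating the persistence tradeoff."""

    def __init__(self, persistent: bool):
        super().__init__()
        # Stand-in for freqs_cis; the real model computes this via
        # precompute_freqs_cis() rather than torch.arange.
        self.register_buffer("freqs_cis", torch.arange(4.0), persistent=persistent)


# Non-persistent: the buffer is missing from the checkpoint, so a model
# materialized from meta device would need a separate manual initializer.
print("freqs_cis" in Rotary(persistent=False).state_dict())  # False

# Persistent: the buffer round-trips through save/load, so the seed
# checkpoint alone fully initializes the module.
print("freqs_cis" in Rotary(persistent=True).state_dict())  # True
```

This is why the commit flips `persistent` to `True` rather than adding a post-load initialization pass for the non-persistent buffers.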

0 commit comments