Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions torchtrain/models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,9 +313,15 @@ def __init__(self, model_args: ModelArgs):
super().__init__()
self.model_args = model_args
self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
self.register_buffer(
"freqs_cis", self._precompute_freqs_cis(), persistent=False
)

# TODO: `persistent` should be set to False, since this buffer can be recomputed.
# However, we set it to True for two reasons: (1) due to pytorch/pytorch#123411,
# compile and the pipeline tracer do not correctly handle non-persistent buffers,
# so that needs to be fixed first; and (2) if we initialize pipeline-parallel models
# from a seed checkpoint rather than calling init_weights, we need freqs_cis to be
# initialized by the checkpoint, or we need a separate initializer for just the
# non-persistent buffers that is called after loading checkpoints.
self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)

def _precompute_freqs_cis(self):
return precompute_freqs_cis(
Expand Down