
Commit 4745070

jackzhxng authored and facebook-github-bot committed

Llama2 model cleanup (#5859)
Summary:
- Removes redundant steps in the Llama2 export
- Factors out checkpointing to be shared with future Llama models (namely 3.2 multimodal)
- Comments and orders code more clearly

PR chain:
- [Add kwarg example inputs to eager model base](#5765)
- **YOU ARE HERE ~>** [Llama2 model cleanup](#5859)
- [Accept model type parameter in export_llama](#5910)
- [Export TorchTune llama3_2_vision in ET](#5911)
- [Add et version of TorchTune MHA for swapping with custom op](#5912)

Pull Request resolved: #5859

Test Plan: Ensure export + eval is similar before and after for Stories 110M:
```
python -m examples.models.llama2.eval_llama -c <checkpoint.pth> -p <params.json> -t <tokenizer.model/bin> -d fp32 --max_seq_len 2048 --limit 1000
```
Before:
```
wikitext: {'word_perplexity,none': 14464.645927166595, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 5.99788806086652, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 2.5844545973083983, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
```
After:
```
wikitext: {'word_perplexity,none': 14464.299192404438, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 5.997861173678705, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 2.584448130015399, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
```
Reviewed By: malfet, dbort

Differential Revision: D64145852

Pulled By: dvorjackz

fbshipit-source-id: daeee834955e154e7c8262ce776bd3039991027d
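As a quick sanity check (not part of the commit), the before/after wikitext metrics agree to roughly five significant figures; a minimal comparison sketch using the numbers reported above:

```python
# Sanity-check sketch (not from the commit): compare the before/after
# wikitext metrics reported in the test plan above.
before = {"word_perplexity": 14464.645927166595, "byte_perplexity": 5.99788806086652}
after = {"word_perplexity": 14464.299192404438, "byte_perplexity": 5.997861173678705}

for metric, b in before.items():
    a = after[metric]
    rel = abs(a - b) / b
    # Both metrics move by less than 0.003%, i.e. within eval noise.
    print(f"{metric}: before={b:.6f} after={a:.6f} relative_diff={rel:.2e}")
```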
1 parent 4e8f609 commit 4745070

File tree

3 files changed: +103 −55 lines changed

examples/models/checkpoint.py

Lines changed: 73 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from pathlib import Path
from typing import Any, Dict, Optional


def get_default_model_resource_dir(model_file_path: str) -> Path:
    """
    Get the default path to resource files (which contain files such as the
    checkpoint and param files), either:
    1. Uses the path from pkg_resources, only works with buck2
    2. Uses default path located in examples/models/llama2/params

    Expected to be called from within a `model.py` file located in a
    `executorch/examples/models/<model_name>` directory.

    Args:
        model_file_path: The file path to the eager model definition.
            For example, `executorch/examples/models/llama2/model.py`,
            where `executorch/examples/models/llama2` contains all
            the llama2-related files.

    Returns:
        The path to the resource directory containing checkpoint, params, etc.
    """

    try:
        import pkg_resources

        # 1st way: If we can import this path, we are running with buck2 and all resources can be accessed with pkg_resources.
        # pyre-ignore
        from executorch.examples.models.llama2 import params  # noqa

        # Get the model name from the cwd, assuming that this module is called from a path such as
        # examples/models/<model_name>/model.py.
        model_name = Path(model_file_path).parent.name
        resource_dir = Path(
            pkg_resources.resource_filename(
                f"executorch.examples.models.{model_name}", "params"
            )
        )
    except:
        # 2nd way.
        resource_dir = Path(model_file_path).absolute().parent / "params"

    return resource_dir


def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[str]:
    """
    Get the dtype of the checkpoint, returning None if the checkpoint is empty.
    """
    dtype = None
    if len(checkpoint) > 0:
        first_key = next(iter(checkpoint))
        first = checkpoint[first_key]
        dtype = first.dtype
        mismatched_dtypes = [
            (key, value.dtype)
            for key, value in checkpoint.items()
            if value.dtype != dtype
        ]
        if len(mismatched_dtypes) > 0:
            raise ValueError(
                f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
            )
    return dtype
```
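A minimal usage sketch of the two new helpers (the paths and the toy state dict below are illustrative, not from the commit):

```python
# Hypothetical usage sketch of the new helpers; paths and the toy
# checkpoint are illustrative only.
import torch

from executorch.examples.models.checkpoint import (
    get_checkpoint_dtype,
    get_default_model_resource_dir,
)

# Resolves .../examples/models/llama2/params whether running under buck2
# (via pkg_resources) or directly from a source checkout (the fallback).
resource_dir = get_default_model_resource_dir("examples/models/llama2/model.py")
print(resource_dir)

# A toy, uniformly-typed state dict: the helper returns its dtype.
checkpoint = {"tok_embeddings.weight": torch.zeros(4, 4, dtype=torch.float32)}
print(get_checkpoint_dtype(checkpoint))  # torch.float32

# A mixed-dtype checkpoint now raises instead of just printing a warning.
checkpoint["output.weight"] = torch.zeros(4, 4, dtype=torch.float16)
try:
    get_checkpoint_dtype(checkpoint)
except ValueError as e:
    print("mixed dtypes rejected:", e)
```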

examples/models/llama2/TARGETS

Lines changed: 1 addition & 0 deletions
```diff
@@ -46,6 +46,7 @@ runtime.python_library(
         "//caffe2:torch",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models/llama2:llama_transformer",
+        "//executorch/examples/models:checkpoint",
     ],
 )
 
```
examples/models/llama2/model.py

Lines changed: 29 additions & 55 deletions
```diff
@@ -8,9 +8,13 @@
 
 import json
 import os
-from pathlib import Path
+from typing import Dict, Tuple
 
 import torch
+from executorch.examples.models.checkpoint import (
+    get_checkpoint_dtype,
+    get_default_model_resource_dir,
+)
 
 from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer
 
@@ -30,48 +34,31 @@ def convert_to_llama_checkpoint(**kwargs):
 
 class Llama2Model(EagerModelBase):
     def __init__(self, **kwargs):
-        import pkg_resources
-
-        # default path to the resource file
-        # It currently supports 3 ways of specifying the checkpoint location:
-        # 1. Using default path locates in examples/models/llama2/params
-        # 2. Passing in the checkpoint path and params via kwargs
-        # 3. Using the path from pkg_resources, only works with buck2
-        try:
-            # The 3rd way, if we can import this path, we are running with buck2, all resources can be accessed with pkg_resources.resource_filename
-            # pyre-ignore
-            from executorch.examples.models.llama2 import params
-
-            ckpt_dir = Path(
-                pkg_resources.resource_filename(
-                    "executorch.examples.models.llama2", "params"
-                )
-            )
-        except:
-            # The 1st way
-            ckpt_dir = Path(__file__).absolute().parent / "params"
-
-        # Check if checkpoint_dir was provided for a sharded checkpoint.
-        checkpoint_dir = kwargs.get("checkpoint_dir", None)
+        resource_dir = get_default_model_resource_dir(__file__)
 
         # Use single checkpoint file.
-        checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
+        checkpoint_path = kwargs.get(
+            "checkpoint", resource_dir / "demo_rand_params.pth"
+        )
+        params_path = kwargs.get("params", resource_dir / "demo_config.json")
 
-        params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
+        # Check if checkpoint_dir was provided for a sharded checkpoint.
+        checkpoint_dir = kwargs.get("checkpoint_dir", None)
 
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
         self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.args = kwargs.get("args", None)
+
         # The example is using a dummy small model with random weights for demo purpose only.
-        # Follow the instruction in https://github.com/facebookresearch/llama to download the model
+        # Follow the instruction in https://github.com/facebookresearch/llama to download the model.
         device = "cpu"
         # flake8: noqa: TOR102
         cps = []
+        # Load sharded checkpoint.
         if checkpoint_dir is not None:
             # Load multiple checkpoint; ignore the single path.
             checkpoint_path = None
@@ -98,8 +85,11 @@ def __init__(self, **kwargs):
             else:
                 # Do not duplicate layers shared between each checkpoint.
                 checkpoint[key] = cps[0][key]
+        # Load single checkpoint.
         else:
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
+
+        # If given checkpoint is fairseq, convert to llama checkpoint.
         fairseq2_checkpoint = kwargs.get("fairseq2", False)
         if fairseq2_checkpoint:
             print("Using fairseq2 checkpoint")
@@ -108,12 +98,12 @@ def __init__(self, **kwargs):
             # NB: some checkpoint contains a "model" field, which is the actual weights dict
             checkpoint = checkpoint["model"]
 
+        # Check if user gave a fairseq2 checkpoint unknowingly without specifying --fairseq2.
         if (not fairseq2_checkpoint) and checkpoint.get(
             "final_proj.weight", None
         ) is not None:
-            print(
+            raise ValueError(
                 """
-
 ************************************************************
 This looks like a Fairseq2 checkpoint (based on the presence
 of `final_proj.weight`.
@@ -125,44 +115,28 @@ def __init__(self, **kwargs):
 """
             )
 
-        # get checkpoint dtype
-        self.dtype = None
-        if len(checkpoint) > 0:
-            first_key = next(iter(checkpoint))
-            first = checkpoint[first_key]
-            self.dtype = first.dtype
-            mismatched_dtypes = [
-                (key, value.dtype)
-                for key, value in checkpoint.items()
-                if value.dtype != self.dtype
-            ]
-            if len(mismatched_dtypes) > 0:
-                print(
-                    f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
-                )
+        # Get checkpoint dtype.
+        self.dtype = get_checkpoint_dtype(checkpoint)
+
         with open(params_path, "r") as f:
             params = json.loads(f.read())
         output_prune_map = None
         if self.output_prune_map_path is not None:
             with open(self.output_prune_map_path, "r") as f:
                 output_prune_map = json.load(f)
-            # change keys from string to int (json only supports string keys)
+            # Change keys from string to int (json only supports string keys).
             output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
-        max_seq_len = self.max_seq_len
-        max_batch_size = 1
+
         model_args: ModelArgs = ModelArgs(
-            max_seq_len=max_seq_len,
-            max_batch_size=max_batch_size,
+            max_seq_len=self.max_seq_len,
+            max_batch_size=1,
             use_kv_cache=self.use_kv_cache,
             use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
             generate_full_logits=self.generate_full_logits,
             output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
         )
-        if kwargs.get("fairseq2", False):
-            print("Using fairseq2 checkpoint")
-            checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
         if kwargs.get("verbose", False):
             print("============= weights ================")
             print("{key} : {weights.numel()} : {weights.size()}")
@@ -234,13 +208,13 @@ def __init__(self, **kwargs):
             print(unexpected)
             print("============= /unexpected ================")
 
-        # prune the output layer if output_prune_map is provided
+        # Prune the output layer if output_prune_map is provided
         if output_prune_map is not None:
             from .source_transformation.prune_output import prune_output_vocab
 
             self.model_ = prune_output_vocab(self.model_, output_prune_map)
 
-    def get_eager_model(self):
+    def get_eager_model(self) -> torch.nn.Module:
         if self.dtype:
             # convert to the type of the provided checkpoint
             # input and output are torch.long, so signature unchanged
```
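For context, a hypothetical sketch of how the refactored constructor resolves its inputs (the paths are placeholders and the kwargs mirror those read in the diff above; this is not part of the commit):

```python
# Hypothetical sketch (paths are placeholders): with no checkpoint/params
# kwargs, the constructor falls back to the bundled demo resources.
from executorch.examples.models.llama2.model import Llama2Model

# Uses resource_dir / "demo_rand_params.pth" and resource_dir / "demo_config.json".
demo = Llama2Model(use_kv_cache=True, max_seq_len=128)

# Or point it at a real single-file checkpoint and params json.
model = Llama2Model(
    checkpoint="/path/to/checkpoint.pth",
    params="/path/to/params.json",
    use_kv_cache=True,
    max_seq_len=2048,
)
module = model.get_eager_model()  # torch.nn.Module, per the new annotation
```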
