@@ -8,9 +8,13 @@

 import json
 import os
-from pathlib import Path
+from typing import Dict, Tuple

 import torch
+from executorch.examples.models.checkpoint import (
+    get_checkpoint_dtype,
+    get_default_model_resource_dir,
+)

 from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer

@@ -30,48 +34,31 @@ def convert_to_llama_checkpoint(**kwargs):

 class Llama2Model(EagerModelBase):
     def __init__(self, **kwargs):
-        import pkg_resources
-
-        # default path to the resource file
-        # It currently supports 3 ways of specifying the checkpoint location:
-        # 1. Using default path locates in examples/models/llama2/params
-        # 2. Passing in the checkpoint path and params via kwargs
-        # 3. Using the path from pkg_resources, only works with buck2
-        try:
-            # The 3rd way, if we can import this path, we are running with buck2, all resources can be accessed with pkg_resources.resource_filename
-            # pyre-ignore
-            from executorch.examples.models.llama2 import params
-
-            ckpt_dir = Path(
-                pkg_resources.resource_filename(
-                    "executorch.examples.models.llama2", "params"
-                )
-            )
-        except:
-            # The 1st way
-            ckpt_dir = Path(__file__).absolute().parent / "params"
-
-        # Check if checkpoint_dir was provided for a sharded checkpoint.
-        checkpoint_dir = kwargs.get("checkpoint_dir", None)
+        resource_dir = get_default_model_resource_dir(__file__)

         # Use single checkpoint file.
-        checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
+        checkpoint_path = kwargs.get(
+            "checkpoint", resource_dir / "demo_rand_params.pth"
+        )
+        params_path = kwargs.get("params", resource_dir / "demo_config.json")

-        params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
+        # Check if checkpoint_dir was provided for a sharded checkpoint.
+        checkpoint_dir = kwargs.get("checkpoint_dir", None)

         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
         self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.args = kwargs.get("args", None)
+
         # The example is using a dummy small model with random weights for demo purpose only.
-        # Follow the instruction in https://github.com/facebookresearch/llama to download the model
+        # Follow the instruction in https://github.com/facebookresearch/llama to download the model.
         device = "cpu"
         # flake8: noqa: TOR102
         cps = []
+        # Load sharded checkpoint.
         if checkpoint_dir is not None:
             # Load multiple checkpoint; ignore the single path.
             checkpoint_path = None
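Note: `get_default_model_resource_dir(__file__)` replaces the pkg_resources/`Path` lookup deleted in the hunk above. Below is a hedged sketch of what such a helper can look like, mirroring the removed logic; the real implementation lives in `executorch/examples/models/checkpoint.py` and may differ, and the `_sketch` name is used only to mark this as illustrative.

# Hedged sketch only: mirrors the resolution logic removed in the hunk above,
# not necessarily the helper actually shipped in executorch/examples/models/checkpoint.py.
from pathlib import Path


def get_default_model_resource_dir_sketch(model_file: str) -> Path:
    try:
        # Under buck2 the params folder is packaged as a resource and resolved
        # via pkg_resources; importing the params package probes for that case.
        import pkg_resources

        from executorch.examples.models.llama2 import params  # noqa: F401

        return Path(
            pkg_resources.resource_filename(
                "executorch.examples.models.llama2", "params"
            )
        )
    except ImportError:
        # Outside buck2, fall back to the params/ directory next to the model file.
        return Path(model_file).absolute().parent / "params"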
@@ -98,8 +85,11 @@ def __init__(self, **kwargs):
                 else:
                     # Do not duplicate layers shared between each checkpoint.
                     checkpoint[key] = cps[0][key]
+        # Load single checkpoint.
         else:
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
+
+        # If given checkpoint is fairseq, convert to llama checkpoint.
         fairseq2_checkpoint = kwargs.get("fairseq2", False)
         if fairseq2_checkpoint:
             print("Using fairseq2 checkpoint")
@@ -108,12 +98,12 @@ def __init__(self, **kwargs):
             # NB: some checkpoint contains a "model" field, which is the actual weights dict
             checkpoint = checkpoint["model"]

+        # Check if user gave a fairseq2 checkpoint unknowingly without specifying --fairseq2.
         if (not fairseq2_checkpoint) and checkpoint.get(
             "final_proj.weight", None
         ) is not None:
-            print(
+            raise ValueError(
                 """
-
                 ************************************************************
                 This looks like a Fairseq2 checkpoint (based on the presence
                 of `final_proj.weight`.
@@ -125,44 +115,28 @@ def __init__(self, **kwargs):
                 """
             )

-        # get checkpoint dtype
-        self.dtype = None
-        if len(checkpoint) > 0:
-            first_key = next(iter(checkpoint))
-            first = checkpoint[first_key]
-            self.dtype = first.dtype
-            mismatched_dtypes = [
-                (key, value.dtype)
-                for key, value in checkpoint.items()
-                if value.dtype != self.dtype
-            ]
-            if len(mismatched_dtypes) > 0:
-                print(
-                    f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
-                )
+        # Get checkpoint dtype.
+        self.dtype = get_checkpoint_dtype(checkpoint)
+
         with open(params_path, "r") as f:
             params = json.loads(f.read())
         output_prune_map = None
         if self.output_prune_map_path is not None:
             with open(self.output_prune_map_path, "r") as f:
                 output_prune_map = json.load(f)
-            # change keys from string to int (json only supports string keys)
+            # Change keys from string to int (json only supports string keys).
             output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
-        max_seq_len = self.max_seq_len
-        max_batch_size = 1
+
         model_args: ModelArgs = ModelArgs(
-            max_seq_len=max_seq_len,
-            max_batch_size=max_batch_size,
+            max_seq_len=self.max_seq_len,
+            max_batch_size=1,
             use_kv_cache=self.use_kv_cache,
             use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
             generate_full_logits=self.generate_full_logits,
             output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
         )
-        if kwargs.get("fairseq2", False):
-            print("Using fairseq2 checkpoint")
-            checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
         if kwargs.get("verbose", False):
             print("============= weights ================")
             print("{key} : {weights.numel()} : {weights.size()}")
@@ -234,13 +208,13 @@ def __init__(self, **kwargs):
             print(unexpected)
             print("============= /unexpected ================")

-        # prune the output layer if output_prune_map is provided
+        # Prune the output layer if output_prune_map is provided
         if output_prune_map is not None:
             from .source_transformation.prune_output import prune_output_vocab

             self.model_ = prune_output_vocab(self.model_, output_prune_map)

-    def get_eager_model(self):
+    def get_eager_model(self) -> torch.nn.Module:
         if self.dtype:
             # convert to the type of the provided checkpoint
             # input and output are torch.long, so signature unchanged
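For context, the trailing lines above hint at what `get_eager_model` does with `self.dtype`: the eager module is cast to the checkpoint dtype while token-id inputs and outputs stay `torch.long`. A minimal, hypothetical sketch of that cast follows; the method's real body lies outside this diff, and `cast_to_checkpoint_dtype` is an illustrative name, not a function from the codebase.

# Hedged sketch only: illustrates the dtype cast described by the comments above.
from typing import Optional

import torch


def cast_to_checkpoint_dtype(model: torch.nn.Module, dtype: Optional[torch.dtype]) -> torch.nn.Module:
    if dtype:
        # Parameters and buffers take the checkpoint dtype; inputs and outputs
        # are torch.long token ids, so the call signature is unchanged.
        return model.to(dtype)
    return model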