import logging
 
import torch
- from executorch.backends.arm.arm_backend import generate_ethosu_compile_spec
 
+ from executorch.backends.arm.arm_backend import generate_ethosu_compile_spec
from executorch.backends.arm.arm_partitioner import ArmPartitioner
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig
 
+ from ..models import MODEL_NAME_TO_MODEL
+ from ..models.model_factory import EagerModelFactory
from ..portable.utils import export_to_edge, save_pte_program
 
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
- logging.basicConfig(level=logging.INFO, format=FORMAT)
-
- # TODO: When we have a more reliable quantization flow through to
- # Vela, and use the models in their original form with a
- # quantization step in our example. This will take the models
- # from examples/models/ and quantize then export to delegate.
+ logging.basicConfig(level=logging.WARNING, format=FORMAT)
+
+ # Quantize model if required using the standard export quantization flow.
+ # For now we're using the xnnpack quantizer as this produces reasonable
+ # output for our arithmetic behaviour.
+ from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+     get_symmetric_quantization_config,
+     XNNPACKQuantizer,
+ )
+
+
+ def quantize(model, example_inputs):
+     """This is the official recommended flow for quantization in pytorch 2.0 export"""
+     logging.info("Quantizing Model...")
+     logging.debug(f"Original model: {model}")
+     quantizer = XNNPACKQuantizer()
+     # if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel
+     operator_config = get_symmetric_quantization_config(is_per_channel=False)
+     quantizer.set_global(operator_config)
+     m = prepare_pt2e(model, quantizer)
+     # calibration
+     m(*example_inputs)
+     m = convert_pt2e(m)
+     logging.debug(f"Quantized model: {m}")
+     # make sure we can export to flat buffer
+     return m
 
# Two simple models
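For context, the `quantize()` helper above follows the standard PT2E (PyTorch 2 export) post-training quantization sequence: capture the graph, attach a quantizer, insert observers with `prepare_pt2e`, run representative data through the prepared model, then fold the observers into quantize/dequantize ops with `convert_pt2e`. Below is a minimal, self-contained sketch of that same sequence on a toy module; `TinyModel` and its input shape are made up for illustration, while the quantizer configuration mirrors the helper above.

```python
# Hypothetical toy module; only the PT2E calls mirror the quantize() helper above.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.softmax(x, dim=-1)


example_inputs = (torch.randn(1, 8),)

# Pre-autograd capture, as done later in this script before quantize() is called.
captured = torch._export.capture_pre_autograd_graph(TinyModel(), example_inputs)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))

prepared = prepare_pt2e(captured, quantizer)  # insert observers
prepared(*example_inputs)                     # calibrate on representative inputs
quantized = convert_pt2e(prepared)            # rewrite to quantize/dequantize ops
```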
@@ -93,7 +116,7 @@ def forward(self, x):
    "-m",
    "--model_name",
    required=True,
-     help=f"Provide model name. Valid ones: {list(models.keys())}",
+     help=f"Provide model name. Valid ones: {set(list(models.keys()) + list(MODEL_NAME_TO_MODEL.keys()))}",
)
parser.add_argument(
    "-d",
@@ -103,10 +126,22 @@ def forward(self, x):
    default=False,
    help="Flag for producing ArmBackend delegated model",
)
+ parser.add_argument(
+     "-q",
+     "--quantize",
+     action="store_true",
+     required=False,
+     default=False,
+     help="Produce a quantized model",
+ )
 
args = parser.parse_args()
 
- if args.model_name not in models.keys():
+ # support models defined within this file or examples/models/ lists
+ if (
+     args.model_name not in models.keys()
+     and args.model_name not in MODEL_NAME_TO_MODEL.keys()
+ ):
    raise RuntimeError(f"Model {args.model_name} is not a valid name.")
 
if (
@@ -116,28 +151,47 @@ def forward(self, x):
):
    raise RuntimeError(f"Model {args.model_name} cannot be delegated.")
 
- model = models[args.model_name]()
- example_inputs = models[args.model_name].example_input
+ # 1. pick model from one of the supported lists
+ model = None
+ example_inputs = None
+
+ # 1.a. models in this file
+ if args.model_name in models.keys():
+     model = models[args.model_name]()
+     example_inputs = models[args.model_name].example_input
+ # 1.b. models in examples/models/
+ # IFF the model is not in our local models
+ elif args.model_name in MODEL_NAME_TO_MODEL.keys():
+     logging.warning(
+         "Using a model from examples/models; not all of these are currently supported"
+     )
+     model, example_inputs, _ = EagerModelFactory.create_model(
+         *MODEL_NAME_TO_MODEL[args.model_name]
+     )
 
model = model.eval()
 
# pre-autograd export. eventually this will become torch.export
model = torch._export.capture_pre_autograd_graph(model, example_inputs)
 
+ # Quantize if required
+ if args.quantize:
+     model = quantize(model, example_inputs)
+
edge = export_to_edge(
    model,
    example_inputs,
    edge_compile_config=EdgeCompileConfig(
        _check_ir_validity=False,
    ),
)
- logging.info(f"Exported graph:\n{edge.exported_program().graph}")
+ logging.debug(f"Exported graph:\n{edge.exported_program().graph}")
 
if args.delegate is True:
    edge = edge.to_backend(
        ArmPartitioner(generate_ethosu_compile_spec("ethos-u55-128"))
    )
-     logging.info(f"Lowered graph:\n{edge.exported_program().graph}")
+     logging.debug(f"Lowered graph:\n{edge.exported_program().graph}")
 
exec_prog = edge.to_executorch(
    config=ExecutorchBackendConfig(extract_constant_segment=False)
@@ -146,4 +200,4 @@ def forward(self, x):
model_name = f"{args.model_name}" + (
    "_arm_delegate" if args.delegate is True else ""
)
- save_pte_program(exec_prog.buffer, model_name)
+ save_pte_program(exec_prog, model_name)
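Taken together, the flow this file implements after the change is: select or build the model, capture it pre-autograd, optionally quantize it, export to the edge dialect, optionally delegate supported subgraphs to the Arm backend, and serialize the program. A condensed sketch of that path is below; it assumes `model` and `example_inputs` already exist (argument parsing and model selection elided) and reuses the calls that appear in the script, with the relative import assuming the same examples package layout.

```python
# Condensed sketch of the delegated export path in this script.
# Assumes `model` and `example_inputs` come from the selection logic above;
# the relative import assumes this code lives in the same examples package.
import torch
from executorch.backends.arm.arm_backend import generate_ethosu_compile_spec
from executorch.backends.arm.arm_partitioner import ArmPartitioner
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig

from ..portable.utils import export_to_edge, save_pte_program

model = model.eval()
model = torch._export.capture_pre_autograd_graph(model, example_inputs)
# model = quantize(model, example_inputs)  # optional, when -q/--quantize is given

edge = export_to_edge(
    model,
    example_inputs,
    edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
# Partition and lower supported subgraphs to the Ethos-U55 backend.
edge = edge.to_backend(ArmPartitioner(generate_ethosu_compile_spec("ethos-u55-128")))

exec_prog = edge.to_executorch(
    config=ExecutorchBackendConfig(extract_constant_segment=False)
)
save_pte_program(exec_prog, "model_arm_delegate")  # serializes the program to a .pte file
```

With the new flag in place, driving the script with `-m <model_name> -d -q` (script invocation and model names assumed here) should then produce a quantized, Arm-delegated program file.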