17
17
#include " Loader.h"
18
18
19
19
#include " glow/Base/Tensor.h"
20
+ #include " glow/Converter/TypeAToTypeBFunctionConverter.h"
20
21
#include " glow/ExecutionEngine/ExecutionEngine.h"
21
22
#include " glow/IR/IR.h"
22
23
#include " glow/Quantization/Serialization.h"
@@ -89,13 +90,14 @@ llvm::cl::opt<std::string> loadProfileFileOpt(
89
90
llvm::cl::value_desc(" profile.yaml" ), llvm::cl::Optional,
90
91
llvm::cl::cat(loaderCat));
91
92
92
- llvm::cl::list<std::string> doNotQuantizeNodesOpt (
93
- " do_not_quantize_nodes " ,
93
+ llvm::cl::list<std::string> keepOriginalPrecisionForNodesOpt (
94
+ " keep-original-precision-for-nodes " ,
94
95
llvm::cl::desc (
95
96
" Use to specify the name of nodes (e.g. Add, Div, etc.) that should "
96
- " not be quantized. All nodes of the listed kinds would not be "
97
- " quantized; e.g. if Add is specififed and there are multiple Add nodes "
98
- " in the input loaded model, none would be quantized." ),
97
+ " be kept as is when conversion/quantization is requested. "
98
+ " All nodes of the listed kinds will be kept as is;"
99
+ " e.g. if Add is specified and there are multiple Add nodes "
100
+ " in the input loaded model, none would be quantized/converted." ),
99
101
llvm::cl::value_desc(" NodeNames (e.g. Add,Div)" ), llvm::cl::ZeroOrMore,
100
102
llvm::cl::CommaSeparated, llvm::cl::cat(loaderCat));
101
103
@@ -123,6 +125,11 @@ llvm::cl::opt<bool> dumpGraphOpt("dumpGraph",
123
125
llvm::cl::desc (" Prints Graph to stdout" ),
124
126
llvm::cl::cat(modelExportCat));
125
127
128
+ llvm::cl::opt<bool >
129
+ convertToFP16 (" convert-to-fp16" ,
130
+ llvm::cl::desc (" Run all floating-point computation in fp16." ),
131
+ llvm::cl::init(false ), llvm::cl::cat(loaderCat));
132
+
126
133
/// Emit a bundle into the specified output directory.
127
134
llvm::cl::opt<std::string>
128
135
emitBundle (" emit-bundle" ,
@@ -217,6 +224,16 @@ void Loader::compile(Context &ctx) {
217
224
F_ = ::profileQuantization (ctx, F_);
218
225
}
219
226
227
+ // By default, when converting models, all nodes that can be
228
+ // converted are converted. However, some models may need to
229
+ // keep higher precision for some nodes to prevent significant accuracy loss.
230
+ // Those nodes are gathered via the keepOriginalPrecisionForNodesOpt
231
+ // option and passed to the related conversion function.
232
+ KindSet keepOriginalPrecisionForNodes;
233
+ for (llvm::StringRef kindName : keepOriginalPrecisionForNodesOpt) {
234
+ keepOriginalPrecisionForNodes.insert (getKindFromNodeName (kindName));
235
+ }
236
+
220
237
// Load the quantization profile and transform the graph.
221
238
if (!loadProfileFileOpt.empty ()) {
222
239
// The profiled graph was optimized before it was instrumented. In this
@@ -233,25 +250,24 @@ void Loader::compile(Context &ctx) {
233
250
std::string oldName = F_->getName ();
234
251
F_->setName (" old" );
235
252
236
- // By default, when quantizing loaded models, all nodes that can be
237
- // quantized are quantized. However, some models that are loaded may need to
238
- // keep higher precision for some nodes to prevent high accuracy loss. This
239
- // set is passed into quantizeFunction() to prevent quantization.
240
- KindSet doNotQuantizeKinds;
241
- for (llvm::StringRef kindName : doNotQuantizeNodesOpt) {
242
- doNotQuantizeKinds.insert (getKindFromNodeName (kindName));
243
- }
244
-
245
253
// Quantize the graph based on the captured profile.
246
- auto *Q = quantization::quantizeFunction (EE_, quantizationInfos, F_,
247
- oldName, doNotQuantizeKinds );
254
+ auto *Q = quantization::quantizeFunction (
255
+ EE_, quantizationInfos, F_, oldName, keepOriginalPrecisionForNodes );
248
256
249
257
// Erase the original function so that the redundant variables that are only
250
258
// referenced by the original function will be removed.
251
259
Q->getParent ()->eraseFunction (F_);
252
260
F_ = Q;
253
261
}
254
262
263
+ if (convertToFP16) {
264
+ TypeAToTypeBFunctionConverter converter (*F_, ElemKind::FloatTy,
265
+ ElemKind::Float16Ty,
266
+ &keepOriginalPrecisionForNodes);
267
+ converter.convert ();
268
+ ::optimize (F_, glow::CompilationMode::Infer);
269
+ }
270
+
255
271
if (emittingBundle ()) {
256
272
// Emit IR for the graph, compile it and save as a bundle.
257
273
EE_.save (CompilationMode::Infer, F_, emitBundle, networkName);
0 commit comments