19
19
#include " glow/ExecutionEngine/ExecutionEngine.h"
20
20
#include " glow/Graph/Graph.h"
21
21
#include " glow/Importer/Caffe2ModelLoader.h"
22
-
22
+ # include " glow/Importer/ONNXModelLoader.h "
23
23
#include " gtest/gtest.h"
24
24
25
25
#include " llvm/ADT/StringMap.h"
@@ -40,7 +40,8 @@ class LoaderTest : public ::testing::Test {
40
40
using namespace glow ;
41
41
42
42
namespace {
43
- const dim_t BATCH_SIZE = 8 ;
43
+ const dim_t BATCH_SIZE = 3 ;
44
+ const dim_t BATCH_SIZE_TFLITE = 1 ;
44
45
const size_t MINI_BATCH_SIZE = 2 ;
45
46
} // namespace
46
47
@@ -51,18 +52,23 @@ class testLoaderExtension : public LoaderExtension {
51
52
static size_t index_;
52
53
static Loader *loader_;
53
54
static PlaceholderBindings *bindings_;
55
+ static ProtobufLoader *protobufLoader_;
56
+ static TFLiteModelLoader *tfliteloader_;
54
57
static bool destructed_;
55
58
56
59
testLoaderExtension () {
57
60
stage_ = 0 ;
58
61
index_ = 0 ;
59
62
loader_ = nullptr ;
60
63
bindings_ = nullptr ;
64
+ protobufLoader_ = nullptr ;
61
65
destructed_ = false ;
62
66
}
63
67
64
68
// / Called once after ONNX or Caffe2 model loading.
65
69
virtual void postModelLoad (Loader &loader, PlaceholderBindings &bindings,
70
+ ProtobufLoader &protobufLoader,
71
+ llvm::StringMap<Placeholder *> &outputMap,
66
72
llvm::ArrayRef<TypeRef> inputImageType) {
67
73
68
74
size_t compilationBatchSize = inputImageType[0 ]->dims ()[0 ];
@@ -72,6 +78,7 @@ class testLoaderExtension : public LoaderExtension {
72
78
// To check params are correctly set.
73
79
loader_ = &loader;
74
80
bindings_ = &bindings;
81
+ protobufLoader_ = &protobufLoader;
75
82
EXPECT_EQ (BATCH_SIZE, compilationBatchSize);
76
83
}
77
84
// / Called once at the beginning of the mini-batch inference.
@@ -98,6 +105,22 @@ class testLoaderExtension : public LoaderExtension {
98
105
index_ = minibatchIndex;
99
106
EXPECT_EQ (MINI_BATCH_SIZE, minibatchSize);
100
107
}
108
+ // / Called once after TFLite.
109
+ virtual void postModelLoad (Loader &loader, PlaceholderBindings &bindings,
110
+ TFLiteModelLoader &tfliteLoader,
111
+ llvm::StringMap<Placeholder *> &outputMap,
112
+ llvm::ArrayRef<TypeRef> inputImageType) {
113
+
114
+ size_t compilationBatchSize = inputImageType[0 ]->dims ()[0 ];
115
+ // To check the method was executed.
116
+ stage_ = 4 ;
117
+
118
+ // To check params are correctly set.
119
+ loader_ = &loader;
120
+ bindings_ = &bindings;
121
+ tfliteloader_ = &tfliteLoader;
122
+ EXPECT_EQ (BATCH_SIZE_TFLITE, compilationBatchSize);
123
+ }
101
124
virtual ~testLoaderExtension () { destructed_ = true ; }
102
125
};
103
126
@@ -113,7 +136,8 @@ class secondTestLoaderExtension : public LoaderExtension {
113
136
}
114
137
115
138
// / Called once after ONNX or Caffe2 model loading.
116
- virtual void postModelLoad (Loader &, PlaceholderBindings &,
139
+ virtual void postModelLoad (Loader &, PlaceholderBindings &, ProtobufLoader &,
140
+ llvm::StringMap<Placeholder *> &,
117
141
llvm::ArrayRef<TypeRef> inputImageType) {
118
142
stage_ = 1 ;
119
143
}
@@ -127,6 +151,12 @@ class secondTestLoaderExtension : public LoaderExtension {
127
151
size_t ) {
128
152
stage_ = 3 ;
129
153
}
154
+ virtual void postModelLoad (Loader &, PlaceholderBindings &,
155
+ TFLiteModelLoader &,
156
+ llvm::StringMap<Placeholder *> &,
157
+ llvm::ArrayRef<TypeRef> inputImageType) {
158
+ stage_ = 4 ;
159
+ }
130
160
131
161
virtual ~secondTestLoaderExtension () { destructed_ = true ; }
132
162
};
@@ -136,12 +166,14 @@ int testLoaderExtension::stage_;
136
166
size_t testLoaderExtension::index_;
137
167
Loader *testLoaderExtension::loader_;
138
168
PlaceholderBindings *testLoaderExtension::bindings_;
169
+ ProtobufLoader *testLoaderExtension::protobufLoader_;
170
+ TFLiteModelLoader *testLoaderExtension::tfliteloader_;
139
171
bool testLoaderExtension::destructed_;
140
172
int secondTestLoaderExtension::stage_;
141
173
bool secondTestLoaderExtension::destructed_;
142
174
143
- // / This test simulates what can be a Glow applciation (like image_classifier).
144
- TEST_F (LoaderTest, LoaderExtension ) {
175
+ // / This test simulates what can be a Glow application (like image_classifier).
176
+ TEST_F (LoaderTest, LoaderExtensionCaffe2 ) {
145
177
{
146
178
std::unique_ptr<ExecutionContext> exContext =
147
179
glow::make_unique<ExecutionContext>();
@@ -182,10 +214,11 @@ TEST_F(LoaderTest, LoaderExtension) {
182
214
183
215
// Get bindings and call post model load extensions.
184
216
ASSERT_EQ (testLoaderExtension::stage_, 0 );
185
- loader.postModelLoad (bindings, &inputData.getType ());
217
+ loader.postModelLoad (bindings, caffe2LD, outputMap, &inputData.getType ());
186
218
ASSERT_EQ (testLoaderExtension::stage_, 1 );
187
219
ASSERT_EQ (testLoaderExtension::loader_, &loader);
188
220
ASSERT_EQ (testLoaderExtension::bindings_, &bindings);
221
+ ASSERT_EQ (testLoaderExtension::protobufLoader_, &caffe2LD);
189
222
ASSERT_EQ (secondTestLoaderExtension::stage_, 1 );
190
223
191
224
// Allocate tensors to back all inputs and outputs.
@@ -209,6 +242,171 @@ TEST_F(LoaderTest, LoaderExtension) {
209
242
ASSERT_EQ (testLoaderExtension::index_, miniBatchIndex);
210
243
ASSERT_EQ (testLoaderExtension::loader_, &loader);
211
244
ASSERT_EQ (testLoaderExtension::bindings_, &bindings);
245
+ ASSERT_EQ (testLoaderExtension::protobufLoader_, &caffe2LD);
246
+ ASSERT_EQ (secondTestLoaderExtension::stage_, 2 );
247
+
248
+ // Perform the inference execution for a minibatch.
249
+ loader.runInference (exContext.get (), BATCH_SIZE);
250
+
251
+ // Minibatch inference initialization of loader extensions.
252
+ loader.inferEndMiniBatch (bindings, miniBatchIndex, MINI_BATCH_SIZE);
253
+ ASSERT_EQ (testLoaderExtension::stage_, 3 );
254
+ ASSERT_EQ (testLoaderExtension::index_, miniBatchIndex);
255
+ ASSERT_EQ (testLoaderExtension::loader_, &loader);
256
+ ASSERT_EQ (testLoaderExtension::bindings_, &bindings);
257
+ ASSERT_EQ (testLoaderExtension::protobufLoader_, &caffe2LD);
258
+ ASSERT_EQ (secondTestLoaderExtension::stage_, 3 );
259
+ }
260
+
261
+ // Extension object not destructed yet.
262
+ ASSERT_EQ (testLoaderExtension::destructed_, false );
263
+ ASSERT_EQ (secondTestLoaderExtension::destructed_, false );
264
+ } // End of the loader scope.
265
+
266
+ // Check that extensions were properly destructed by the Loader destruction.
267
+ ASSERT_EQ (testLoaderExtension::destructed_, true );
268
+ ASSERT_EQ (secondTestLoaderExtension::destructed_, true );
269
+ }
270
+
271
+ TEST_F (LoaderTest, LoaderExtensionTFlite) {
272
+ {
273
+ std::unique_ptr<ExecutionContext> exContext =
274
+ glow::make_unique<ExecutionContext>();
275
+ PlaceholderBindings &bindings = *exContext->getPlaceholderBindings ();
276
+ llvm::StringMap<Placeholder *> outputMap;
277
+
278
+ // Create a loader object.
279
+ Loader loader;
280
+
281
+ // Register Loader extensions.
282
+ loader.registerExtension (
283
+ std::unique_ptr<LoaderExtension>(new testLoaderExtension ()));
284
+ loader.registerExtension (
285
+ std::unique_ptr<LoaderExtension>(new secondTestLoaderExtension ()));
286
+
287
+ // Load a model
288
+ std::string NetFilename (GLOW_DATA_PATH
289
+ " tests/models/tfliteModels/abs.tflite" );
290
+
291
+ Tensor inputData (ElemKind::FloatTy, {BATCH_SIZE_TFLITE, 10 });
292
+ TFLiteModelLoader LD (NetFilename, loader.getFunction ());
293
+
294
+ // Check the model was loaded.
295
+ EXPECT_EQ (loader.getFunction ()->getNodes ().size (), 2 );
296
+
297
+ // Get bindings and call post model load extensions.
298
+ ASSERT_EQ (testLoaderExtension::stage_, 0 );
299
+ loader.postModelLoad (bindings, LD, outputMap, &inputData.getType ());
300
+ ASSERT_EQ (testLoaderExtension::stage_, 4 );
301
+ ASSERT_EQ (testLoaderExtension::loader_, &loader);
302
+ ASSERT_EQ (testLoaderExtension::bindings_, &bindings);
303
+ ASSERT_EQ (testLoaderExtension::tfliteloader_, &LD);
304
+ ASSERT_EQ (secondTestLoaderExtension::stage_, 4 );
305
+
306
+ // Allocate tensors to back all inputs and outputs.
307
+ bindings.allocate (loader.getModule ()->getPlaceholders ());
308
+
309
+ // Compile the model.
310
+ CompilationContext cctx = loader.getCompilationContext ();
311
+ cctx.bindings = &bindings;
312
+ loader.compile (cctx);
313
+
314
+ // Load data to input placeholders.
315
+ updateInputPlaceholdersByName (bindings, loader.getModule (), {" input" },
316
+ {&inputData});
317
+
318
+ // Run mini-batches.
319
+ for (size_t miniBatchIndex = 0 ; miniBatchIndex < BATCH_SIZE;
320
+ miniBatchIndex += MINI_BATCH_SIZE) {
321
+ // Minibatch inference initialization of loader extensions.
322
+ loader.inferInitMiniBatch (bindings, miniBatchIndex, MINI_BATCH_SIZE);
323
+ ASSERT_EQ (testLoaderExtension::stage_, 2 );
324
+ ASSERT_EQ (testLoaderExtension::index_, miniBatchIndex);
325
+ ASSERT_EQ (testLoaderExtension::loader_, &loader);
326
+ ASSERT_EQ (testLoaderExtension::bindings_, &bindings);
327
+ ASSERT_EQ (testLoaderExtension::tfliteloader_, &LD);
328
+ ASSERT_EQ (secondTestLoaderExtension::stage_, 2 );
329
+
330
+ // Perform the inference execution for a minibatch.
331
+ loader.runInference (exContext.get (), BATCH_SIZE);
332
+
333
+ // Minibatch inference initialization of loader extensions.
334
+ loader.inferEndMiniBatch (bindings, miniBatchIndex, MINI_BATCH_SIZE);
335
+ ASSERT_EQ (testLoaderExtension::stage_, 3 );
336
+ ASSERT_EQ (testLoaderExtension::index_, miniBatchIndex);
337
+ ASSERT_EQ (testLoaderExtension::loader_, &loader);
338
+ ASSERT_EQ (testLoaderExtension::bindings_, &bindings);
339
+ ASSERT_EQ (testLoaderExtension::tfliteloader_, &LD);
340
+ ASSERT_EQ (secondTestLoaderExtension::stage_, 3 );
341
+ }
342
+
343
+ // Extension object not destructed yet.
344
+ ASSERT_EQ (testLoaderExtension::destructed_, false );
345
+ ASSERT_EQ (secondTestLoaderExtension::destructed_, false );
346
+ } // End of the loader scope.
347
+
348
+ // Check that extensions were properly destructed by the Loader destruction.
349
+ ASSERT_EQ (testLoaderExtension::destructed_, true );
350
+ }
351
+
352
+ TEST_F (LoaderTest, LoaderExtensionONNX) {
353
+ {
354
+ std::unique_ptr<ExecutionContext> exContext =
355
+ glow::make_unique<ExecutionContext>();
356
+ PlaceholderBindings &bindings = *exContext->getPlaceholderBindings ();
357
+ llvm::StringMap<Placeholder *> outputMap;
358
+
359
+ // Create a loader object.
360
+ Loader loader;
361
+
362
+ // Register Loader extensions.
363
+ loader.registerExtension (
364
+ std::unique_ptr<LoaderExtension>(new testLoaderExtension ()));
365
+ loader.registerExtension (
366
+ std::unique_ptr<LoaderExtension>(new secondTestLoaderExtension ()));
367
+
368
+ // Load a model
369
+ std::string NetFilename (GLOW_DATA_PATH
370
+ " tests/models/onnxModels/clip.onnxtxt" );
371
+
372
+ Tensor inputData (ElemKind::FloatTy, {BATCH_SIZE, 3 });
373
+ ONNXModelLoader LD (NetFilename, {" x" }, {&inputData.getType ()},
374
+ *loader.getFunction ());
375
+
376
+ // Check the model was loaded.
377
+ EXPECT_EQ (loader.getFunction ()->getNodes ().size (), 2 );
378
+
379
+ // Get bindings and call post model load extensions.
380
+ ASSERT_EQ (testLoaderExtension::stage_, 0 );
381
+ loader.postModelLoad (bindings, LD, outputMap, &inputData.getType ());
382
+ ASSERT_EQ (testLoaderExtension::stage_, 1 );
383
+ ASSERT_EQ (testLoaderExtension::loader_, &loader);
384
+ ASSERT_EQ (testLoaderExtension::bindings_, &bindings);
385
+ ASSERT_EQ (testLoaderExtension::protobufLoader_, &LD);
386
+ ASSERT_EQ (secondTestLoaderExtension::stage_, 1 );
387
+
388
+ // Allocate tensors to back all inputs and outputs.
389
+ bindings.allocate (loader.getModule ()->getPlaceholders ());
390
+
391
+ // Compile the model.
392
+ CompilationContext cctx = loader.getCompilationContext ();
393
+ cctx.bindings = &bindings;
394
+ loader.compile (cctx);
395
+
396
+ // Load data to input placeholders.
397
+ updateInputPlaceholdersByName (bindings, loader.getModule (), {" x" },
398
+ {&inputData});
399
+
400
+ // Run mini-batches.
401
+ for (size_t miniBatchIndex = 0 ; miniBatchIndex < BATCH_SIZE;
402
+ miniBatchIndex += MINI_BATCH_SIZE) {
403
+ // Minibatch inference initialization of loader extensions.
404
+ loader.inferInitMiniBatch (bindings, miniBatchIndex, MINI_BATCH_SIZE);
405
+ ASSERT_EQ (testLoaderExtension::stage_, 2 );
406
+ ASSERT_EQ (testLoaderExtension::index_, miniBatchIndex);
407
+ ASSERT_EQ (testLoaderExtension::loader_, &loader);
408
+ ASSERT_EQ (testLoaderExtension::bindings_, &bindings);
409
+ ASSERT_EQ (testLoaderExtension::protobufLoader_, &LD);
212
410
ASSERT_EQ (secondTestLoaderExtension::stage_, 2 );
213
411
214
412
// Perform the inference execution for a minibatch.
@@ -220,6 +418,7 @@ TEST_F(LoaderTest, LoaderExtension) {
220
418
ASSERT_EQ (testLoaderExtension::index_, miniBatchIndex);
221
419
ASSERT_EQ (testLoaderExtension::loader_, &loader);
222
420
ASSERT_EQ (testLoaderExtension::bindings_, &bindings);
421
+ ASSERT_EQ (testLoaderExtension::protobufLoader_, &LD);
223
422
ASSERT_EQ (secondTestLoaderExtension::stage_, 3 );
224
423
}
225
424
0 commit comments