Skip to content

Commit 682bb08

Browse files
Pass loader and protobuf loader to postmodel load. (#5325)
Summary: Enabling passing loader to the postModelLoader, for customer-specific processing. With TFLite not being protobuf based, this required changes to postModelLoad call. Tests added to ImageLoaderTest. Pull Request resolved: #5325 Reviewed By: hl475 Differential Revision: D27605168 Pulled By: jfix71 fbshipit-source-id: 85488c9bc02505c11a53e71a63e8a43dbeaed8eb
1 parent 2d2b99c commit 682bb08

File tree

4 files changed

+251
-17
lines changed

4 files changed

+251
-17
lines changed

tests/unittests/LoaderTest.cpp

Lines changed: 205 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#include "glow/ExecutionEngine/ExecutionEngine.h"
2020
#include "glow/Graph/Graph.h"
2121
#include "glow/Importer/Caffe2ModelLoader.h"
22-
22+
#include "glow/Importer/ONNXModelLoader.h"
2323
#include "gtest/gtest.h"
2424

2525
#include "llvm/ADT/StringMap.h"
@@ -40,7 +40,8 @@ class LoaderTest : public ::testing::Test {
4040
using namespace glow;
4141

4242
namespace {
43-
const dim_t BATCH_SIZE = 8;
43+
const dim_t BATCH_SIZE = 3;
44+
const dim_t BATCH_SIZE_TFLITE = 1;
4445
const size_t MINI_BATCH_SIZE = 2;
4546
} // namespace
4647

@@ -51,18 +52,23 @@ class testLoaderExtension : public LoaderExtension {
5152
static size_t index_;
5253
static Loader *loader_;
5354
static PlaceholderBindings *bindings_;
55+
static ProtobufLoader *protobufLoader_;
56+
static TFLiteModelLoader *tfliteloader_;
5457
static bool destructed_;
5558

5659
testLoaderExtension() {
5760
stage_ = 0;
5861
index_ = 0;
5962
loader_ = nullptr;
6063
bindings_ = nullptr;
64+
protobufLoader_ = nullptr;
6165
destructed_ = false;
6266
}
6367

6468
/// Called once after ONNX or Caffe2 model loading.
6569
virtual void postModelLoad(Loader &loader, PlaceholderBindings &bindings,
70+
ProtobufLoader &protobufLoader,
71+
llvm::StringMap<Placeholder *> &outputMap,
6672
llvm::ArrayRef<TypeRef> inputImageType) {
6773

6874
size_t compilationBatchSize = inputImageType[0]->dims()[0];
@@ -72,6 +78,7 @@ class testLoaderExtension : public LoaderExtension {
7278
// To check params are correctly set.
7379
loader_ = &loader;
7480
bindings_ = &bindings;
81+
protobufLoader_ = &protobufLoader;
7582
EXPECT_EQ(BATCH_SIZE, compilationBatchSize);
7683
}
7784
/// Called once at the beginning of the mini-batch inference.
@@ -98,6 +105,22 @@ class testLoaderExtension : public LoaderExtension {
98105
index_ = minibatchIndex;
99106
EXPECT_EQ(MINI_BATCH_SIZE, minibatchSize);
100107
}
108+
/// Called once after TFLite.
109+
virtual void postModelLoad(Loader &loader, PlaceholderBindings &bindings,
110+
TFLiteModelLoader &tfliteLoader,
111+
llvm::StringMap<Placeholder *> &outputMap,
112+
llvm::ArrayRef<TypeRef> inputImageType) {
113+
114+
size_t compilationBatchSize = inputImageType[0]->dims()[0];
115+
// To check the method was executed.
116+
stage_ = 4;
117+
118+
// To check params are correctly set.
119+
loader_ = &loader;
120+
bindings_ = &bindings;
121+
tfliteloader_ = &tfliteLoader;
122+
EXPECT_EQ(BATCH_SIZE_TFLITE, compilationBatchSize);
123+
}
101124
virtual ~testLoaderExtension() { destructed_ = true; }
102125
};
103126

@@ -113,7 +136,8 @@ class secondTestLoaderExtension : public LoaderExtension {
113136
}
114137

115138
/// Called once after ONNX or Caffe2 model loading.
116-
virtual void postModelLoad(Loader &, PlaceholderBindings &,
139+
virtual void postModelLoad(Loader &, PlaceholderBindings &, ProtobufLoader &,
140+
llvm::StringMap<Placeholder *> &,
117141
llvm::ArrayRef<TypeRef> inputImageType) {
118142
stage_ = 1;
119143
}
@@ -127,6 +151,12 @@ class secondTestLoaderExtension : public LoaderExtension {
127151
size_t) {
128152
stage_ = 3;
129153
}
154+
virtual void postModelLoad(Loader &, PlaceholderBindings &,
155+
TFLiteModelLoader &,
156+
llvm::StringMap<Placeholder *> &,
157+
llvm::ArrayRef<TypeRef> inputImageType) {
158+
stage_ = 4;
159+
}
130160

131161
virtual ~secondTestLoaderExtension() { destructed_ = true; }
132162
};
@@ -136,12 +166,14 @@ int testLoaderExtension::stage_;
136166
size_t testLoaderExtension::index_;
137167
Loader *testLoaderExtension::loader_;
138168
PlaceholderBindings *testLoaderExtension::bindings_;
169+
ProtobufLoader *testLoaderExtension::protobufLoader_;
170+
TFLiteModelLoader *testLoaderExtension::tfliteloader_;
139171
bool testLoaderExtension::destructed_;
140172
int secondTestLoaderExtension::stage_;
141173
bool secondTestLoaderExtension::destructed_;
142174

143-
/// This test simulates what can be a Glow applciation (like image_classifier).
144-
TEST_F(LoaderTest, LoaderExtension) {
175+
/// This test simulates what can be a Glow application (like image_classifier).
176+
TEST_F(LoaderTest, LoaderExtensionCaffe2) {
145177
{
146178
std::unique_ptr<ExecutionContext> exContext =
147179
glow::make_unique<ExecutionContext>();
@@ -182,10 +214,11 @@ TEST_F(LoaderTest, LoaderExtension) {
182214

183215
// Get bindings and call post model load extensions.
184216
ASSERT_EQ(testLoaderExtension::stage_, 0);
185-
loader.postModelLoad(bindings, &inputData.getType());
217+
loader.postModelLoad(bindings, caffe2LD, outputMap, &inputData.getType());
186218
ASSERT_EQ(testLoaderExtension::stage_, 1);
187219
ASSERT_EQ(testLoaderExtension::loader_, &loader);
188220
ASSERT_EQ(testLoaderExtension::bindings_, &bindings);
221+
ASSERT_EQ(testLoaderExtension::protobufLoader_, &caffe2LD);
189222
ASSERT_EQ(secondTestLoaderExtension::stage_, 1);
190223

191224
// Allocate tensors to back all inputs and outputs.
@@ -209,6 +242,171 @@ TEST_F(LoaderTest, LoaderExtension) {
209242
ASSERT_EQ(testLoaderExtension::index_, miniBatchIndex);
210243
ASSERT_EQ(testLoaderExtension::loader_, &loader);
211244
ASSERT_EQ(testLoaderExtension::bindings_, &bindings);
245+
ASSERT_EQ(testLoaderExtension::protobufLoader_, &caffe2LD);
246+
ASSERT_EQ(secondTestLoaderExtension::stage_, 2);
247+
248+
// Perform the inference execution for a minibatch.
249+
loader.runInference(exContext.get(), BATCH_SIZE);
250+
251+
// Minibatch inference initialization of loader extensions.
252+
loader.inferEndMiniBatch(bindings, miniBatchIndex, MINI_BATCH_SIZE);
253+
ASSERT_EQ(testLoaderExtension::stage_, 3);
254+
ASSERT_EQ(testLoaderExtension::index_, miniBatchIndex);
255+
ASSERT_EQ(testLoaderExtension::loader_, &loader);
256+
ASSERT_EQ(testLoaderExtension::bindings_, &bindings);
257+
ASSERT_EQ(testLoaderExtension::protobufLoader_, &caffe2LD);
258+
ASSERT_EQ(secondTestLoaderExtension::stage_, 3);
259+
}
260+
261+
// Extension object not destructed yet.
262+
ASSERT_EQ(testLoaderExtension::destructed_, false);
263+
ASSERT_EQ(secondTestLoaderExtension::destructed_, false);
264+
} // End of the loader scope.
265+
266+
// Check that extensions were properly destructed by the Loader destruction.
267+
ASSERT_EQ(testLoaderExtension::destructed_, true);
268+
ASSERT_EQ(secondTestLoaderExtension::destructed_, true);
269+
}
270+
271+
TEST_F(LoaderTest, LoaderExtensionTFlite) {
272+
{
273+
std::unique_ptr<ExecutionContext> exContext =
274+
glow::make_unique<ExecutionContext>();
275+
PlaceholderBindings &bindings = *exContext->getPlaceholderBindings();
276+
llvm::StringMap<Placeholder *> outputMap;
277+
278+
// Create a loader object.
279+
Loader loader;
280+
281+
// Register Loader extensions.
282+
loader.registerExtension(
283+
std::unique_ptr<LoaderExtension>(new testLoaderExtension()));
284+
loader.registerExtension(
285+
std::unique_ptr<LoaderExtension>(new secondTestLoaderExtension()));
286+
287+
// Load a model
288+
std::string NetFilename(GLOW_DATA_PATH
289+
"tests/models/tfliteModels/abs.tflite");
290+
291+
Tensor inputData(ElemKind::FloatTy, {BATCH_SIZE_TFLITE, 10});
292+
TFLiteModelLoader LD(NetFilename, loader.getFunction());
293+
294+
// Check the model was loaded.
295+
EXPECT_EQ(loader.getFunction()->getNodes().size(), 2);
296+
297+
// Get bindings and call post model load extensions.
298+
ASSERT_EQ(testLoaderExtension::stage_, 0);
299+
loader.postModelLoad(bindings, LD, outputMap, &inputData.getType());
300+
ASSERT_EQ(testLoaderExtension::stage_, 4);
301+
ASSERT_EQ(testLoaderExtension::loader_, &loader);
302+
ASSERT_EQ(testLoaderExtension::bindings_, &bindings);
303+
ASSERT_EQ(testLoaderExtension::tfliteloader_, &LD);
304+
ASSERT_EQ(secondTestLoaderExtension::stage_, 4);
305+
306+
// Allocate tensors to back all inputs and outputs.
307+
bindings.allocate(loader.getModule()->getPlaceholders());
308+
309+
// Compile the model.
310+
CompilationContext cctx = loader.getCompilationContext();
311+
cctx.bindings = &bindings;
312+
loader.compile(cctx);
313+
314+
// Load data to input placeholders.
315+
updateInputPlaceholdersByName(bindings, loader.getModule(), {"input"},
316+
{&inputData});
317+
318+
// Run mini-batches.
319+
for (size_t miniBatchIndex = 0; miniBatchIndex < BATCH_SIZE;
320+
miniBatchIndex += MINI_BATCH_SIZE) {
321+
// Minibatch inference initialization of loader extensions.
322+
loader.inferInitMiniBatch(bindings, miniBatchIndex, MINI_BATCH_SIZE);
323+
ASSERT_EQ(testLoaderExtension::stage_, 2);
324+
ASSERT_EQ(testLoaderExtension::index_, miniBatchIndex);
325+
ASSERT_EQ(testLoaderExtension::loader_, &loader);
326+
ASSERT_EQ(testLoaderExtension::bindings_, &bindings);
327+
ASSERT_EQ(testLoaderExtension::tfliteloader_, &LD);
328+
ASSERT_EQ(secondTestLoaderExtension::stage_, 2);
329+
330+
// Perform the inference execution for a minibatch.
331+
loader.runInference(exContext.get(), BATCH_SIZE);
332+
333+
// Minibatch inference initialization of loader extensions.
334+
loader.inferEndMiniBatch(bindings, miniBatchIndex, MINI_BATCH_SIZE);
335+
ASSERT_EQ(testLoaderExtension::stage_, 3);
336+
ASSERT_EQ(testLoaderExtension::index_, miniBatchIndex);
337+
ASSERT_EQ(testLoaderExtension::loader_, &loader);
338+
ASSERT_EQ(testLoaderExtension::bindings_, &bindings);
339+
ASSERT_EQ(testLoaderExtension::tfliteloader_, &LD);
340+
ASSERT_EQ(secondTestLoaderExtension::stage_, 3);
341+
}
342+
343+
// Extension object not destructed yet.
344+
ASSERT_EQ(testLoaderExtension::destructed_, false);
345+
ASSERT_EQ(secondTestLoaderExtension::destructed_, false);
346+
} // End of the loader scope.
347+
348+
// Check that extensions were properly destructed by the Loader destruction.
349+
ASSERT_EQ(testLoaderExtension::destructed_, true);
350+
}
351+
352+
TEST_F(LoaderTest, LoaderExtensionONNX) {
353+
{
354+
std::unique_ptr<ExecutionContext> exContext =
355+
glow::make_unique<ExecutionContext>();
356+
PlaceholderBindings &bindings = *exContext->getPlaceholderBindings();
357+
llvm::StringMap<Placeholder *> outputMap;
358+
359+
// Create a loader object.
360+
Loader loader;
361+
362+
// Register Loader extensions.
363+
loader.registerExtension(
364+
std::unique_ptr<LoaderExtension>(new testLoaderExtension()));
365+
loader.registerExtension(
366+
std::unique_ptr<LoaderExtension>(new secondTestLoaderExtension()));
367+
368+
// Load a model
369+
std::string NetFilename(GLOW_DATA_PATH
370+
"tests/models/onnxModels/clip.onnxtxt");
371+
372+
Tensor inputData(ElemKind::FloatTy, {BATCH_SIZE, 3});
373+
ONNXModelLoader LD(NetFilename, {"x"}, {&inputData.getType()},
374+
*loader.getFunction());
375+
376+
// Check the model was loaded.
377+
EXPECT_EQ(loader.getFunction()->getNodes().size(), 2);
378+
379+
// Get bindings and call post model load extensions.
380+
ASSERT_EQ(testLoaderExtension::stage_, 0);
381+
loader.postModelLoad(bindings, LD, outputMap, &inputData.getType());
382+
ASSERT_EQ(testLoaderExtension::stage_, 1);
383+
ASSERT_EQ(testLoaderExtension::loader_, &loader);
384+
ASSERT_EQ(testLoaderExtension::bindings_, &bindings);
385+
ASSERT_EQ(testLoaderExtension::protobufLoader_, &LD);
386+
ASSERT_EQ(secondTestLoaderExtension::stage_, 1);
387+
388+
// Allocate tensors to back all inputs and outputs.
389+
bindings.allocate(loader.getModule()->getPlaceholders());
390+
391+
// Compile the model.
392+
CompilationContext cctx = loader.getCompilationContext();
393+
cctx.bindings = &bindings;
394+
loader.compile(cctx);
395+
396+
// Load data to input placeholders.
397+
updateInputPlaceholdersByName(bindings, loader.getModule(), {"x"},
398+
{&inputData});
399+
400+
// Run mini-batches.
401+
for (size_t miniBatchIndex = 0; miniBatchIndex < BATCH_SIZE;
402+
miniBatchIndex += MINI_BATCH_SIZE) {
403+
// Minibatch inference initialization of loader extensions.
404+
loader.inferInitMiniBatch(bindings, miniBatchIndex, MINI_BATCH_SIZE);
405+
ASSERT_EQ(testLoaderExtension::stage_, 2);
406+
ASSERT_EQ(testLoaderExtension::index_, miniBatchIndex);
407+
ASSERT_EQ(testLoaderExtension::loader_, &loader);
408+
ASSERT_EQ(testLoaderExtension::bindings_, &bindings);
409+
ASSERT_EQ(testLoaderExtension::protobufLoader_, &LD);
212410
ASSERT_EQ(secondTestLoaderExtension::stage_, 2);
213411

214412
// Perform the inference execution for a minibatch.
@@ -220,6 +418,7 @@ TEST_F(LoaderTest, LoaderExtension) {
220418
ASSERT_EQ(testLoaderExtension::index_, miniBatchIndex);
221419
ASSERT_EQ(testLoaderExtension::loader_, &loader);
222420
ASSERT_EQ(testLoaderExtension::bindings_, &bindings);
421+
ASSERT_EQ(testLoaderExtension::protobufLoader_, &LD);
223422
ASSERT_EQ(secondTestLoaderExtension::stage_, 3);
224423
}
225424

tools/loader/ExecutorCoreHelperFunctions.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -319,10 +319,7 @@ std::pair<llvm::StringMap<Placeholder *>, llvm::StringMap<Placeholder *>>
319319
buildAndCompileAndGetInAndOutPair(Loader &loader, PlaceholderBindings &bindings,
320320
llvm::ArrayRef<TypeRef> inputImageType) {
321321
// Load model.
322-
loader.loadModel(inputImageType);
323-
324-
// Post model loader transformation.
325-
loader.postModelLoad(bindings, inputImageType);
322+
loader.loadModel(&bindings, inputImageType);
326323

327324
// Allocate tensors to back all inputs and outputs.
328325
bindings.allocate(loader.getModule()->getPlaceholders());

tools/loader/Loader.cpp

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,8 @@ static void getModelInputs(std::vector<std::string> &inputNames,
438438
}
439439
}
440440

441-
void Loader::loadModel(llvm::ArrayRef<TypeRef> inputType) {
441+
void Loader::loadModel(PlaceholderBindings *bindings,
442+
llvm::ArrayRef<TypeRef> inputType) {
442443

443444
// Get model input names and types.
444445
std::vector<std::string> inputNames;
@@ -467,17 +468,25 @@ void Loader::loadModel(llvm::ArrayRef<TypeRef> inputType) {
467468
// Load the maps between original model names and the placeholders.
468469
inputPlaceholderByName_ = protoLoader->getInputVarsMapping();
469470
outputPlaceholderByName_ = protoLoader->getOutputVarsMapping();
471+
if (bindings) {
472+
postModelLoad(*bindings, *protoLoader.get(), outputPlaceholderByName_,
473+
inputType);
474+
}
470475
} else if (!getTFLiteModelFilename().empty()) {
471476
// For TensorFlowLite format the input placeholder names/types are not
472477
// provided since are used directly from the model.
473-
auto tfliteLoader =
474-
TFLiteModelLoader(getTFLiteModelFilename(), getFunction());
478+
auto tfliteLoader = glow::make_unique<TFLiteModelLoader>(
479+
getTFLiteModelFilename(), getFunction());
475480
// Load the maps between original model names and the placeholders.
476-
inputPlaceholderByName_ = tfliteLoader.getInputPlaceholderMap();
477-
outputPlaceholderByName_ = tfliteLoader.getOutputPlaceholderMap();
481+
inputPlaceholderByName_ = tfliteLoader->getInputPlaceholderMap();
482+
outputPlaceholderByName_ = tfliteLoader->getOutputPlaceholderMap();
478483
// Since TensorFlowLite loader currently does not have the capability to
479484
// enforce the input type (for batching) we must validate that when the
480485
// input type is explicitly given it actually matches the model input type.
486+
if (bindings) {
487+
postModelLoad(*bindings, *tfliteLoader, outputPlaceholderByName_,
488+
inputType);
489+
}
481490
if (inputType.size()) {
482491
CHECK(inputPlaceholderByName_.size() == 1)
483492
<< "Model is expected to have only 1 input!";
@@ -503,6 +512,10 @@ void Loader::loadModel(llvm::ArrayRef<TypeRef> inputType) {
503512
// Load the maps between original model names and the placeholders.
504513
inputPlaceholderByName_ = protoLoader->getInputVarsMapping();
505514
outputPlaceholderByName_ = protoLoader->getOutputVarsMapping();
515+
if (bindings) {
516+
postModelLoad(*bindings, *protoLoader.get(), outputPlaceholderByName_,
517+
inputType);
518+
}
506519
}
507520
}
508521

@@ -820,9 +833,22 @@ Loader &Loader::registerExtension(std::unique_ptr<LoaderExtension> extension) {
820833
}
821834

822835
void Loader::postModelLoad(PlaceholderBindings &bindings,
836+
ProtobufLoader &protoLoader,
837+
llvm::StringMap<Placeholder *> &placeholderMap,
838+
llvm::ArrayRef<TypeRef> inputImageType) {
839+
for (auto &&ext : loaderExtensionList_) {
840+
ext->postModelLoad(*this, bindings, protoLoader, placeholderMap,
841+
inputImageType);
842+
}
843+
}
844+
845+
void Loader::postModelLoad(PlaceholderBindings &bindings,
846+
TFLiteModelLoader &tfloader,
847+
llvm::StringMap<Placeholder *> &placeholderMap,
823848
llvm::ArrayRef<TypeRef> inputImageType) {
824849
for (auto &&ext : loaderExtensionList_) {
825-
ext->postModelLoad(*this, bindings, inputImageType);
850+
ext->postModelLoad(*this, bindings, tfloader, placeholderMap,
851+
inputImageType);
826852
}
827853
}
828854

0 commit comments

Comments
 (0)