Skip to content

Commit 2a39b3c

Browse files
committed
[PH] Port the C2/ONNX loader to using Placeholders.
1 parent 356c17f commit 2a39b3c

20 files changed

+260
-190
lines changed

include/glow/ExecutionEngine/ExecutionEngine.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,13 @@ void updateVariables(llvm::ArrayRef<Variable *> vars,
9797
void updateVariables(Context &ctx, llvm::ArrayRef<Placeholder *> ph,
9898
llvm::ArrayRef<Tensor *> inputs);
9999

100+
/// This method updates the placeholders in the module. The placeholders are
101+
/// found by name
102+
/// in \p ph with the tensor content values \p inputs.
103+
void updateInputsByName(Context &ctx, Module *mod,
104+
llvm::ArrayRef<llvm::StringRef> ph,
105+
llvm::ArrayRef<Tensor *> inputs);
106+
100107
/// Update the content of the tensors \p vars with some slices that from \p
101108
/// inputs. The data starts at slice \p sampleIdx and wraps around until the
102109
/// data in \p v is filled. All dimensions, except for the first (batch)

include/glow/Importer/Caffe2ModelLoader.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,12 @@ class Caffe2ModelLoader
6363
/// Loads the caffe2 model that's represented by a network description file,
6464
/// serialized in \p netDescFilename, and weights file, serialized in
6565
/// \p netWeightFilename, and populates the network in \p F.
66-
/// The tensors in \p tensors are stored with the names in the list of names
67-
/// \p names and used as inputs to the network.
66+
/// The list \p types and \p names are used to initialized the inputs and
67+
/// outputs with specific names and types.
6868
Caffe2ModelLoader(const std::string &netDescFilename,
6969
const std::string &netWeightFilename,
7070
llvm::ArrayRef<const char *> names,
71-
llvm::ArrayRef<Tensor *> tensors, Function &F);
71+
llvm::ArrayRef<TypeRef> types, Function &F);
7272
};
7373

7474
} // namespace glow

include/glow/Importer/CommonOperatorLoader.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class CommonOperatorLoader : public ProtobufLoader {
8080
T->template getHandle<int64_t>() =
8181
std::vector<int64_t>(in.dims().begin(), in.dims().end());
8282

83-
createAndRememberVariable(opName, *T);
83+
createAndRegisterConstant(opName, *T);
8484
}
8585

8686
/// Loads Sqrt operator, given its protobuf representation and parsed args.

include/glow/Importer/ONNXIFIModelLoader.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,16 @@ class ONNXIFIModelLoader : public ONNXModelLoader {
3838
const onnxTensorDescriptorV1 *weightDescriptors);
3939

4040
/// Mapping between ONNX names for inputs and actual Glow input vars.
41-
llvm::StringMap<Variable *> onnxNameToInputVars_;
41+
llvm::StringMap<Placeholder *> onnxNameToInputVars_;
4242

4343
public:
4444
/// \returns mapping between ONNX names and actual Glow input vars.
45-
const llvm::StringMap<Variable *> &getInputVarsMapping() const {
45+
const llvm::StringMap<Placeholder *> &getInputVarsMapping() const {
4646
return onnxNameToInputVars_;
4747
}
4848

4949
/// \returns mapping between ONNX names and actual Glow output nodes.
50-
const llvm::StringMap<Variable *> &getOutputVarsMapping() const {
50+
const llvm::StringMap<Placeholder *> &getOutputVarsMapping() const {
5151
return outputVarsByName_;
5252
}
5353

include/glow/Importer/ONNXModelLoader.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,19 +90,19 @@ class ONNXModelLoader
9090
size_t onnxModelSize);
9191

9292
/// Checks that the inputs tensors are compatible with the inputs declared in
93-
/// the ONNX model. The input tensors in \p tensors are stored with the names
94-
/// in the list of names \p tensorNames.
93+
/// the ONNX model. The input types in \p types match the list of names
94+
/// \p tensorNames.
9595
void checkInputs(ONNX_NAMESPACE::GraphProto &net,
9696
llvm::ArrayRef<const char *> tensorNames,
97-
llvm::ArrayRef<Tensor *> tensors);
97+
llvm::ArrayRef<TypeRef> types);
9898

9999
/// Loads the ONNX model that's represented by a model description file,
100100
/// serialized in \p modelDescFilename and populates the network into \p F.
101-
/// The tensors in \p tensors are stored with the names in the list of names
102-
/// \p tensorNames and used as inputs to the network.
101+
/// The types in \p types match the list of names \p tensorNames and used as
102+
/// inputs to the network.
103103
ONNXModelLoader(const std::string &modelDescFilename,
104104
llvm::ArrayRef<const char *> tensorNames,
105-
llvm::ArrayRef<Tensor *> tensors, Function &F);
105+
llvm::ArrayRef<TypeRef> types, Function &F);
106106
};
107107

108108
} // namespace glow

include/glow/Importer/ProtobufLoader.h

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -99,17 +99,19 @@ class ProtobufLoader {
9999
/// A list of weight tensors indexed by name.
100100
llvm::StringMap<Tensor *> tensors_;
101101
/// A map from names of the external outputs of the network to Variables.
102-
llvm::StringMap<Variable *> outputVarsByName_;
102+
llvm::StringMap<Placeholder *> outputVarsByName_;
103103

104104
/// \returns the tensor that was registered under the name \p name.
105105
Tensor *getTensorByName(llvm::StringRef name);
106106

107-
/// Create a new variable \p name initialized with \p tensor.
108-
/// \returns The newly created variable.
109-
/// \pre !hasNodeByName(name)
110-
Variable *createAndRememberVariable(
111-
llvm::StringRef name, const Tensor &tensor,
112-
VisibilityKind visibilityKind = VisibilityKind::Private);
107+
/// Create a new constant that's initialized with \p tensor, and register it
108+
/// under the name \p name. \returns The newly created constant.
109+
Variable *createAndRegisterConstant(llvm::StringRef name,
110+
const Tensor &tensor);
111+
112+
/// Create a new Placeholder of type \p T, and register it
113+
/// under the name \p name. \returns The newly created placeholder.
114+
Placeholder *createAndRegisterPlaceholder(llvm::StringRef name, TypeRef T);
113115

114116
/// \returns the NodeValue that was registered with the name \p name or
115117
/// a nullptr wrapped in a NodeValue if no node has been registered with this
@@ -130,26 +132,25 @@ class ProtobufLoader {
130132
bool hasNodeByName(llvm::StringRef name) const;
131133

132134
/// Constructs new ProtobufLoader object. It will populate the network into \p
133-
/// F. The tensors in \p tensors are stored with the names in the list of
134-
/// names \p tensorNames and used as inputs to the network.
135+
/// F. The list \p types and \p names are used to initialized the inputs and
136+
/// outputs with specific names and types.
135137
ProtobufLoader(llvm::ArrayRef<const char *> tensorNames,
136-
llvm::ArrayRef<Tensor *> tensors, Function &F);
138+
llvm::ArrayRef<TypeRef> types, Function &F);
137139

138140
virtual ~ProtobufLoader();
139141

140-
/// \returns the single final output Variable of the network. The function
141-
/// assumes there is only one output, verified via assertion. For image
142+
/// \returns the single final output of the network. The function assumes that
143+
/// there is only one output, verified via assertion. For image
142144
/// classification, this single final output is usually the result of the last
143145
/// softmax or regression layer.
144-
/// \pre outputVarsByName_.size() == 1
145-
Variable *getSingleOutput() {
146+
Placeholder *getSingleOutput() {
146147
assert(outputVarsByName_.size() == 1);
147148
return outputVarsByName_.begin()->second;
148149
}
149150

150-
/// \returns the Variable for the external output with \p name.
151+
/// \returns the Placeholder for the external output with \p name.
151152
/// \pre outputVarsByName_.find(name) != outputVarsByName_.end()
152-
Variable *getOutputByName(llvm::StringRef name) const;
153+
Placeholder *getOutputByName(llvm::StringRef name) const;
153154
};
154155

155156
} // namespace glow

lib/ExecutionEngine/ExecutionEngine.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,21 @@ void glow::updateVariables(Context &ctx, llvm::ArrayRef<Placeholder *> ph,
7979
}
8080
}
8181

82+
void glow::updateInputsByName(Context &ctx, Module *mod,
83+
llvm::ArrayRef<llvm::StringRef> ph,
84+
llvm::ArrayRef<Tensor *> inputs) {
85+
assert(inputs.size() == ph.size() &&
86+
"The number of inputs does not match the number of Placeholders");
87+
88+
for (int i = 0, e = ph.size(); i < e; i++) {
89+
Placeholder *p = mod->getPlaceholderByName(ph[i]);
90+
Tensor *t = inputs[i];
91+
assert(t && "Invalid tensor.");
92+
assert(p && "Invalid placeholder.");
93+
updateVariables(ctx, {p}, {t});
94+
}
95+
}
96+
8297
void ExecutionEngine::run() {
8398
assert(function_ && "No function has been compiled");
8499
function_->execute();

lib/Importer/Caffe2ModelLoader.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -619,8 +619,8 @@ void Caffe2ModelLoader::loadNetwork(caffe2::NetDef &net) {
619619
for (int i = 0; i < net.external_output_size(); i++) {
620620
auto &outputName = net.external_output(i);
621621
auto r = getNodeValueByName(outputName);
622-
auto *SN = G_.createSave("save_" + outputName, r);
623-
outputVarsByName_[outputName] = SN->getVariable();
622+
auto *SN = G_.createSavePH("save_" + outputName, r);
623+
outputVarsByName_[outputName] = SN->getPlaceholder();
624624
}
625625
}
626626

@@ -796,9 +796,9 @@ void Caffe2ModelLoader::loadWeights(caffe2::NetDef &net) {
796796
Caffe2ModelLoader::Caffe2ModelLoader(const std::string &netDescFilename,
797797
const std::string &netWeightFilename,
798798
llvm::ArrayRef<const char *> names,
799-
llvm::ArrayRef<Tensor *> tensors,
799+
llvm::ArrayRef<TypeRef> types,
800800
Function &F)
801-
: CommonOperatorLoader(names, tensors, F) {
801+
: CommonOperatorLoader(names, types, F) {
802802
// The caffe2 weights that we are deserializing.
803803
caffe2::NetDef weightsDef;
804804
// The caffe2 network descriptor that we are deserializing.

lib/Importer/ONNXIFIModelLoader.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@ void ONNXIFIModelLoader::loadInputs(ONNX_NAMESPACE::GraphProto &net) {
4848

4949
Tensor *T = new Tensor();
5050
setTensorType(in.type(), T);
51-
auto *var =
52-
createAndRememberVariable(in.name(), *T, VisibilityKind::Public);
51+
Placeholder *var = createAndRegisterPlaceholder(in.name(), &T->getType());
5352
onnxNameToInputVars_.try_emplace(in.name(), var);
5453
}
5554
}

lib/Importer/ONNXModelLoader.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -503,8 +503,8 @@ bool ONNXModelLoader::setOutputNodes(ONNX_NAMESPACE::GraphProto &net) {
503503
for (int i = 0; i < net.output_size(); i++) {
504504
const auto &outputName = net.output(i).name();
505505
auto r = getNodeValueByName(outputName);
506-
SaveNode *SN = G_.createSave("save_" + outputName, r);
507-
outputVarsByName_[outputName] = SN->getVariable();
506+
SaveNode *SN = G_.createSavePH("save_" + outputName, r);
507+
outputVarsByName_[outputName] = SN->getPlaceholder();
508508
}
509509

510510
return true;
@@ -528,7 +528,7 @@ ONNXModelLoader::ONNXModelLoader(Function &F)
528528

529529
void ONNXModelLoader::checkInputs(ONNX_NAMESPACE::GraphProto &net,
530530
llvm::ArrayRef<const char *> tensorNames,
531-
llvm::ArrayRef<Tensor *> tensors) {
531+
llvm::ArrayRef<TypeRef> types) {
532532
for (size_t i = 0; i < tensorNames.size(); i++) {
533533
// Look if a corresponding input exists.
534534
for (int j = 0; j < net.input_size(); j++) {
@@ -539,7 +539,7 @@ void ONNXModelLoader::checkInputs(ONNX_NAMESPACE::GraphProto &net,
539539
continue;
540540
}
541541

542-
llvm::ArrayRef<size_t> dims = tensors[i]->dims();
542+
llvm::ArrayRef<size_t> dims = types[i]->dims();
543543
const ONNX_NAMESPACE::TensorShapeProto &shape =
544544
valueInfo.type().tensor_type().shape();
545545
(void)shape;
@@ -558,8 +558,8 @@ void ONNXModelLoader::checkInputs(ONNX_NAMESPACE::GraphProto &net,
558558

559559
ONNXModelLoader::ONNXModelLoader(const std::string &modelDescFilename,
560560
llvm::ArrayRef<const char *> tensorNames,
561-
llvm::ArrayRef<Tensor *> tensors, Function &F)
562-
: CommonOperatorLoader(tensorNames, tensors, F) {
561+
llvm::ArrayRef<TypeRef> types, Function &F)
562+
: CommonOperatorLoader(tensorNames, types, F) {
563563
// The ONNX model that we are deserializing.
564564
ONNX_NAMESPACE::ModelProto modelDef;
565565
if (!loadProto(modelDef, modelDescFilename)) {
@@ -568,7 +568,7 @@ ONNXModelLoader::ONNXModelLoader(const std::string &modelDescFilename,
568568
setVersion(modelDef);
569569

570570
ONNX_NAMESPACE::GraphProto graphDef = modelDef.graph();
571-
checkInputs(graphDef, tensorNames, tensors);
571+
checkInputs(graphDef, tensorNames, types);
572572

573573
loadInitializers(graphDef);
574574
if (!loadNetwork(graphDef)) {

lib/Importer/ProtobufLoader.cpp

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ Tensor *ProtobufLoader::getTensorByName(llvm::StringRef name) {
3434
return tensors_[name];
3535
}
3636

37-
Variable *ProtobufLoader::getOutputByName(llvm::StringRef name) const {
37+
Placeholder *ProtobufLoader::getOutputByName(llvm::StringRef name) const {
3838
assert(outputVarsByName_.count(name) &&
3939
"There is no Variable registered with this name.");
4040
auto it = outputVarsByName_.find(name);
@@ -60,15 +60,23 @@ NodeValue ProtobufLoader::getNodeValueByName(llvm::StringRef name) const {
6060
return node;
6161
}
6262

63-
Variable *ProtobufLoader::createAndRememberVariable(
64-
llvm::StringRef name, const Tensor &tensor, VisibilityKind visibilityKind) {
65-
assert(!hasNodeByName(name) && "Creating an already existing node?!");
63+
Variable *ProtobufLoader::createAndRegisterConstant(llvm::StringRef name,
64+
const Tensor &tensor) {
65+
assert(!hasNodeByName(name) && "Creating an already existing node");
6666
// Note: We do not support training from models loaded from protos, so
6767
// trainable is always set to false here.
68-
Variable *node = G_.getParent()->createVariable(name, tensor, visibilityKind,
69-
/* trainable */ false);
68+
Variable *node =
69+
G_.getParent()->createVariable(name, tensor, VisibilityKind::Private,
70+
/* trainable */ false);
7071
nodeValueByName_[name] = NodeValue(node, 0);
72+
return node;
73+
}
7174

75+
Placeholder *ProtobufLoader::createAndRegisterPlaceholder(llvm::StringRef name,
76+
TypeRef T) {
77+
assert(!hasNodeByName(name) && "Creating an already existing node");
78+
Placeholder *node = G_.getParent()->createPlaceholder(T, name, false);
79+
nodeValueByName_[name] = NodeValue(node, 0);
7280
return node;
7381
}
7482

@@ -80,25 +88,24 @@ ProtobufLoader::getNodeValueOrCreateVariableByName(llvm::StringRef name) {
8088
}
8189

8290
Tensor *T = getTensorByName(name);
83-
return NodeValue(createAndRememberVariable(name, *T), 0);
91+
return NodeValue(createAndRegisterConstant(name, *T), 0);
8492
}
8593

8694
bool ProtobufLoader::hasNodeByName(llvm::StringRef name) const {
8795
return getNodeValueByNameOrNullNodeValue(name).getNode() != nullptr;
8896
}
8997

9098
ProtobufLoader::ProtobufLoader(llvm::ArrayRef<const char *> tensorNames,
91-
llvm::ArrayRef<Tensor *> tensors, Function &F)
99+
llvm::ArrayRef<TypeRef> types, Function &F)
92100
: G_(F) {
93101
// Verify that the version of the library that we linked against is
94102
// compatible with the version of the headers we compiled against.
95103
GOOGLE_PROTOBUF_VERIFY_VERSION;
96104

97-
assert(tensorNames.size() == tensors.size() && "Invalid initialization list");
105+
assert(tensorNames.size() == types.size() && "Invalid initialization list");
98106
for (unsigned i = 0; i < tensorNames.size(); i++) {
99107
assert(!hasNodeByName(tensorNames[i]) && "Input names have duplicate");
100-
createAndRememberVariable(tensorNames[i], *tensors[i],
101-
VisibilityKind::Public);
108+
createAndRegisterPlaceholder(tensorNames[i], types[i]);
102109
}
103110
}
104111

lib/Onnxifi/Base.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,29 +68,28 @@ onnxStatus Graph::initGraph(const void *onnxModel, size_t onnxModelSize,
6868
onnxStatus Graph::run() {
6969
// Copy tensors from the input addresses to the Glow tensors.
7070
llvm::SmallVector<Tensor *, 4> tensors;
71-
llvm::SmallVector<Variable *, 4> vars;
71+
llvm::SmallVector<Placeholder *, 4> phs;
7272
for (auto inputVar : inputVarToBuffer_) {
7373
auto *var = inputVar.first;
7474
auto *type = var->getType();
7575
void *inputBuffer = reinterpret_cast<void *>(inputVar.second);
7676
tensors.push_back(new Tensor(inputBuffer, type));
77-
vars.push_back(var);
77+
phs.push_back(var);
7878
}
7979

8080
// Run inference.
8181
auto &EE = backendPtr_->getEE();
82-
updateVariables(vars, tensors);
82+
updateVariables(ctx_, phs, tensors);
8383
EE.run();
8484

8585
// Copy outputs to the addresses specified in the outputNodeToBuffer_.
8686
for (auto outputVar : outputNodeToBuffer_) {
8787
void *outputAddress = reinterpret_cast<void *>(outputVar.second);
88-
const Tensor &res = outputVar.first->getPayload();
88+
const Tensor *res = ctx_.get(outputVar.first);
8989

90-
memcpy(outputAddress, res.getUnsafePtr(),
91-
res.size() * res.getType().getElementSize());
90+
memcpy(outputAddress, res->getUnsafePtr(),
91+
res->size() * res->getType().getElementSize());
9292
}
93-
9493
return ONNXIFI_STATUS_SUCCESS;
9594
}
9695

lib/Onnxifi/Base.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,19 +119,19 @@ class Graph {
119119
Context ctx_;
120120

121121
/// Mapping between ONNX name for the input variable and Glow variable.
122-
llvm::StringMap<Variable *> onnxNameToInputVar_;
122+
llvm::StringMap<Placeholder *> onnxNameToInputVar_;
123123

124124
/// Mapping between ONNX name for the output variable and Glow output
125125
/// node.
126-
llvm::StringMap<Variable *> onnxNameToOutputNode_;
126+
llvm::StringMap<Placeholder *> onnxNameToOutputNode_;
127127

128128
/// Mapping between input var and the actual memory address.
129129
/// Inputs will be read from these addresses.
130-
llvm::DenseMap<Variable *, onnxPointer> inputVarToBuffer_;
130+
llvm::DenseMap<Placeholder *, onnxPointer> inputVarToBuffer_;
131131

132132
/// Mapping between output var and the actual memory address.
133133
/// Results must be written to these addresses.
134-
llvm::DenseMap<Variable *, onnxPointer> outputNodeToBuffer_;
134+
llvm::DenseMap<Placeholder *, onnxPointer> outputNodeToBuffer_;
135135
};
136136

137137
typedef Graph *GraphPtr;

tests/unittests/ImporterTestUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ unsigned countNodeKind(Function *F, Kinded::Kind kind) {
5050

5151
/// Helper function to get the save node from a Variable \p var.
5252
/// \pre (var->getUsers().size() == 1)
53-
SaveNode *getSaveNodeFromVariable(Variable *var) {
53+
SaveNode *getSaveNodeFromDest(Storage *var) {
5454
auto &varUsers = var->getUsers();
5555
assert(varUsers.size() == 1);
5656
auto *saveNode = llvm::dyn_cast<SaveNode>(varUsers.begin()->getUser());

0 commit comments

Comments
 (0)