Skip to content

Commit db57a85

Browse files
committed
Placeholder: add the graph creators and a simple unit test that checks that we can construct the graph.
This commit is the first few steps in the implementation of the Placeholder node. The plan is discussed in #1334. Placeholder are unbound variables where the content of the tensor is provided by the execution engine or the AOT compiler at runtime. This commit introduces Storage as a superclass for Variable (bound) and Placeholder (unbound). This PR also adds a simple unit test that checks that we can create new Placeholder nodes. This PR does not implement any of the infrastructure that's required to do anything useful with Placeholder variables, such as loading content into them, deleting them, pretty dotty-printer, etc. Next steps (not necessarily in any specific order): 1. Teach the execution engine to bind tensors to the Placeholder nodes. 2. Verify that dotty printing, dump() and debugging work well. 3. Cleanup the APIs that are related to Variable and Placeholder and make them consistent. 4. Change (some of) the unit tests to use the new Placeholder API. 5. Make sure that our optimizations are correct when placeholder are used.
1 parent 8f5b236 commit db57a85

File tree

4 files changed

+64
-9
lines changed

4 files changed

+64
-9
lines changed

include/glow/Graph/Graph.h

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ using NodesList = llvm::iplist<glow::Node>;
3838
using NodesPtrList = std::list<glow::Node *>;
3939
/// List of Functions.
4040
using FunctionList = std::list<Function *>;
41-
/// List of Variables.
4241
using VariablesList = std::list<Variable *>;
42+
using PlaceholderList = std::list<Placeholder *>;
4343
using UnsignedArrayRef = llvm::ArrayRef<size_t>;
4444

4545
class Module final {
@@ -53,6 +53,8 @@ class Module final {
5353
llvm::StringSet<> uniqueVariableNames_{};
5454
/// A list of variables that the Module owns.
5555
VariablesList vars_;
56+
/// A list of placeholder nodes that the Module owns.
57+
PlaceholderList placeholders_;
5658
/// Deterministic PRNG used to initialize weights in this module.
5759
PseudoRNG PRNG_;
5860

@@ -67,11 +69,10 @@ class Module final {
6769
llvm::StringSet<> &stringTable);
6870

6971
/// Inserts the variable \p V to the list of variables.
70-
Variable *addVar(Variable *V) {
71-
V->setName(uniqueName(V->getName(), uniqueVariableNames_));
72-
vars_.push_back(V);
73-
return V;
74-
}
72+
Variable *addVar(Variable *V);
73+
74+
/// Inserts the placeholder node \p ph to the list of variables.
75+
Placeholder *addPlaceholder(Placeholder *ph);
7576

7677
/// Return a pointer to a uniqued type \p T.
7778
TypeRef uniqueType(const Type &T);
@@ -117,9 +118,19 @@ class Module final {
117118

118119
const VariablesList &getVars() const { return vars_; }
119120

121+
/// \returns the list of placeholders that the Module owns.
122+
PlaceholderList &getPlaceholders() { return placeholders_; }
123+
124+
const PlaceholderList &getPlaceholders() const { return placeholders_; }
125+
120126
/// @name High-level Variable builders.
121127
///@{
122128

129+
Placeholder *createPlaceholder(ElemKind T, llvm::ArrayRef<size_t> dims,
130+
llvm::StringRef name);
131+
132+
Placeholder *createPlaceholder(TypeRef T, llvm::StringRef name);
133+
123134
Variable *createVariable(TypeRef T, llvm::StringRef name,
124135
VisibilityKind visibility = VisibilityKind::Private,
125136
bool isTrainable = true);

include/glow/Graph/Nodes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class Storage : public Node {
3434
public:
3535
Storage(Kinded::Kind k, llvm::StringRef name) : Node(k, name) {}
3636

37+
/// \return the single output value of the node.
3738
NodeValue getOutput() { return getNthResult(0); }
3839

3940
/// Declare the standard Node methods.

lib/Graph/Graph.cpp

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,17 @@ Function *Module::createFunction(llvm::StringRef name) {
5454
Module::~Module() {
5555
eraseFunctions();
5656

57-
for (auto it = vars_.begin(), e = vars_.end(); it != e;) {
58-
auto cur = it++;
59-
eraseVariable(*cur);
57+
for (auto it = vars_.begin(), e = vars_.end(); it != e; it++) {
58+
Variable *v = *it;
59+
delete v;
60+
}
61+
for (auto it = placeholders_.begin(), e = placeholders_.end(); it != e;
62+
it++) {
63+
Placeholder *p = *it;
64+
delete p;
6065
}
66+
vars_.clear();
67+
placeholders_.clear();
6168
}
6269

6370
void Module::verify() const {
@@ -325,6 +332,17 @@ static ShapeVector getNewShapeWithoutAxis(llvm::ArrayRef<size_t> dims,
325332
// Node builders
326333
//===----------------------------------------------------------------------===//
327334

335+
Placeholder *Module::createPlaceholder(TypeRef T, llvm::StringRef name) {
336+
auto FT = uniqueType(*T);
337+
return addPlaceholder(new Placeholder(name, FT));
338+
}
339+
340+
Placeholder *Module::createPlaceholder(ElemKind T, llvm::ArrayRef<size_t> dims,
341+
llvm::StringRef name) {
342+
auto FT = uniqueType(T, dims);
343+
return createPlaceholder(FT, name);
344+
}
345+
328346
Variable *Module::createVariable(TypeRef T, llvm::StringRef name,
329347
VisibilityKind visibility, bool isTrainable) {
330348
auto FT = uniqueType(*T);
@@ -389,6 +407,18 @@ llvm::StringRef Module::uniqueName(llvm::StringRef name,
389407
llvm_unreachable("Unable to find a unique a name.");
390408
}
391409

410+
Variable *Module::addVar(Variable *V) {
411+
V->setName(uniqueName(V->getName(), uniqueVariableNames_));
412+
vars_.push_back(V);
413+
return V;
414+
}
415+
416+
Placeholder *Module::addPlaceholder(Placeholder *ph) {
417+
ph->setName(uniqueName(ph->getName(), uniqueVariableNames_));
418+
placeholders_.push_back(ph);
419+
return ph;
420+
}
421+
392422
ConvolutionNode *Function::createConv(llvm::StringRef name, NodeValue input,
393423
size_t depth,
394424
llvm::ArrayRef<unsigned_t> kernels,

tests/unittests/graphTest.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -886,3 +886,16 @@ TEST(Graph, PostOrderTest) {
886886
EXPECT_EQ(order[12], ret2->getOutput());
887887
EXPECT_EQ(order[13], ret2);
888888
}
889+
890+
TEST(Graph, placeholder) {
891+
Module MD;
892+
Function *F = MD.createFunction("F");
893+
IRFunction M(F);
894+
Node *K = MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input");
895+
Node *S = MD.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select");
896+
897+
K = F->createFullyConnected("FC", K, 10);
898+
K = F->createRELU("Relu", K);
899+
K = F->createSoftMax("SoftMax", K, S);
900+
F->createSave("Save", K);
901+
}

0 commit comments

Comments
 (0)