diff --git a/docs/IR.md b/docs/IR.md index 9b582899ca..729a0c240e 100644 --- a/docs/IR.md +++ b/docs/IR.md @@ -263,3 +263,24 @@ usage. 8. Low-level IR optimizations are performed. 9. Backend-specific optimizations and code generation are performed. + +### Placeholders + +We are in the process of adding a new kind of variable: Placeholder. The +motivation and plan for Placeholder variables are described in issue #1334. + +The work on Placeholder variables is ongoing and the following tasks are still +open: + +1. Teach the execution engine to bind tensors to the Placeholder nodes. + +2. Verify that dotty printing, dump() and debugging work well. + +3. Clean up the APIs that are related to Variable and Placeholder and make them +consistent. + +4. Change (some of) the unit tests to use the new Placeholder API. + +5. Make sure that our optimizations are correct when placeholders are used. + + diff --git a/include/glow/Graph/Graph.h b/include/glow/Graph/Graph.h index c52744c368..9d3d5e0276 100644 --- a/include/glow/Graph/Graph.h +++ b/include/glow/Graph/Graph.h @@ -38,8 +38,8 @@ using NodesList = llvm::iplist; using NodesPtrList = std::list; /// List of Functions. using FunctionList = std::list; -/// List of Variables. using VariablesList = std::list; +using PlaceholderList = std::list; using UnsignedArrayRef = llvm::ArrayRef; class Module final { @@ -53,6 +53,8 @@ class Module final { llvm::StringSet<> uniqueVariableNames_{}; /// A list of variables that the Module owns. VariablesList vars_; + /// A list of placeholder nodes that the Module owns. + PlaceholderList placeholders_; /// Deterministic PRNG used to initialize weights in this module. PseudoRNG PRNG_; @@ -67,11 +69,10 @@ class Module final { llvm::StringSet<> &stringTable); /// Inserts the variable \p V to the list of variables. 
- Variable *addVar(Variable *V) { - V->setName(uniqueName(V->getName(), uniqueVariableNames_)); - vars_.push_back(V); - return V; - } + Variable *addVar(Variable *V); + + /// Inserts the placeholder node \p ph to the list of placeholders. + Placeholder *addPlaceholder(Placeholder *ph); /// Return a pointer to a uniqued type \p T. TypeRef uniqueType(const Type &T); @@ -117,9 +118,19 @@ class Module final { const VariablesList &getVars() const { return vars_; } + /// \returns the list of placeholders that the Module owns. + PlaceholderList &getPlaceholders() { return placeholders_; } + + const PlaceholderList &getPlaceholders() const { return placeholders_; } + /// @name High-level Variable builders. ///@{ + Placeholder *createPlaceholder(ElemKind T, llvm::ArrayRef dims, + llvm::StringRef name); + + Placeholder *createPlaceholder(TypeRef T, llvm::StringRef name); + Variable *createVariable(TypeRef T, llvm::StringRef name, VisibilityKind visibility = VisibilityKind::Private, bool isTrainable = true); diff --git a/include/glow/Graph/Nodes.h b/include/glow/Graph/Nodes.h index b796af512e..4023f9e320 100644 --- a/include/glow/Graph/Nodes.h +++ b/include/glow/Graph/Nodes.h @@ -28,7 +28,44 @@ namespace glow { -class Variable : public Node { +/// Storage is the base class for Variables, which are bound to tensors, and +/// Placeholder nodes which are unbound. +class Storage : public Node { +public: + Storage(Kinded::Kind k, llvm::StringRef name) : Node(k, name) {} + + /// \returns the single output value of the node. + NodeValue getOutput() { return getNthResult(0); } + + /// Declare the standard Node methods. 
+ /// @{ + void visit(Node *parent, NodeWalker *visitor); + void visit(const Node *parent, NodeWalker *visitor) const; + bool isEqual(const Storage &other) const; + unsigned getNumInputs() const; + std::string getInputName(unsigned idx) const; + NodeValue getNthInput(unsigned idx); + llvm::StringRef getOutputName(unsigned idx) const; + bool hasSideEffects() const; + Node *clone() const; + /// @} + + /// \returns the result type of this storage node. + TypeRef getType() const { return Node::getType(0); } + + /// Methods that forward to the result type (that must be valid): + /// @{ + ElemKind getElementType() const { return getType()->getElementType(); }; + llvm::ArrayRef dims() const { return getType()->dims(); }; + /// @} + + static bool classof(const Kinded *k) { + return k->getKind() == Kinded::Kind::VariableKind || + k->getKind() == Kinded::Kind::PlaceholderKind; + } +}; + +class Variable : public Storage { /// Specifies if the variable is trainable. bool isTrainable_; /// Specifies the visibility of the variable. @@ -40,14 +77,14 @@ class Variable : public Node { /// Create a new variable and initialize its payload. Variable(llvm::StringRef name, TypeRef Ty, VisibilityKind visibility, bool isTrainable) - : Node(Kinded::Kind::VariableKind, name), isTrainable_(isTrainable), + : Storage(Kinded::Kind::VariableKind, name), isTrainable_(isTrainable), visibility_(visibility) { addResult(Ty); payload_.reset(*Ty); } Variable(llvm::StringRef name, VisibilityKind visibility, Tensor &&payload) - : Node(Kinded::Kind::VariableKind, name), isTrainable_(false), + : Storage(Kinded::Kind::VariableKind, name), isTrainable_(false), visibility_(visibility), payload_(std::move(payload)) { addResult(&payload_.getType()); } @@ -62,15 +99,6 @@ class Variable : public Node { return k->getKind() == Kinded::Kind::VariableKind; } - /// \returns result type of the variable. 
- TypeRef getType() const { return Node::getType(0); } - - /// Methods that forward to the result type (that must be valid): - /// @{ - ElemKind getElementType() const { return getType()->getElementType(); }; - llvm::ArrayRef dims() const { return getType()->dims(); }; - /// @} - /// \returns the visibility of the variable. VisibilityKind getVisibilityKind() const { return visibility_; } @@ -84,22 +112,27 @@ class Variable : public Node { void assign(const Tensor *t) { payload_.assign(t); } - /// \returns the output NodeValue from the Variable. Variables only have a - /// single output. - NodeValue getOutput() { return getNthResult(0); } - - unsigned getNumInputs() const; - std::string getInputName(unsigned idx) const; - NodeValue getNthInput(unsigned idx); - llvm::StringRef getOutputName(unsigned idx) const; - bool hasSideEffects() const; std::string getDebugDesc() const; - Node *clone() const; - void visit(Node *parent, NodeWalker *visitor); - void visit(const Node *parent, NodeWalker *visitor) const; + llvm::hash_code getHash() const; +}; + +/// Placeholder nodes are unbound storage. The content tensors are attached to +/// this node at runtime. Placeholders are used as the input and output nodes +/// of the network. +class Placeholder : public Storage { +public: + /// Create a new placeholder variable. 
+ Placeholder(llvm::StringRef name, TypeRef Ty) + : Storage(Kinded::Kind::PlaceholderKind, name) { + addResult(Ty); + } - bool isEqual(const Variable &other) const; + static bool classof(const Kinded *k) { + return k->getKind() == Kinded::Kind::PlaceholderKind; + } + + std::string getDebugDesc() const; llvm::hash_code getHash() const; }; diff --git a/lib/Graph/Graph.cpp b/lib/Graph/Graph.cpp index 7ef85de092..964733d49a 100644 --- a/lib/Graph/Graph.cpp +++ b/lib/Graph/Graph.cpp @@ -54,10 +54,17 @@ Function *Module::createFunction(llvm::StringRef name) { Module::~Module() { eraseFunctions(); - for (auto it = vars_.begin(), e = vars_.end(); it != e;) { - auto cur = it++; - eraseVariable(*cur); + for (auto it = vars_.begin(), e = vars_.end(); it != e; it++) { + Variable *v = *it; + delete v; + } + for (auto it = placeholders_.begin(), e = placeholders_.end(); it != e; + it++) { + Placeholder *p = *it; + delete p; } + vars_.clear(); + placeholders_.clear(); } void Module::verify() const { @@ -325,6 +332,17 @@ static ShapeVector getNewShapeWithoutAxis(llvm::ArrayRef dims, // Node builders //===----------------------------------------------------------------------===// +Placeholder *Module::createPlaceholder(TypeRef T, llvm::StringRef name) { + auto FT = uniqueType(*T); + return addPlaceholder(new Placeholder(name, FT)); +} + +Placeholder *Module::createPlaceholder(ElemKind T, llvm::ArrayRef dims, + llvm::StringRef name) { + auto FT = uniqueType(T, dims); + return createPlaceholder(FT, name); +} + Variable *Module::createVariable(TypeRef T, llvm::StringRef name, VisibilityKind visibility, bool isTrainable) { auto FT = uniqueType(*T); @@ -389,6 +407,18 @@ llvm::StringRef Module::uniqueName(llvm::StringRef name, llvm_unreachable("Unable to find a unique a name."); } +Variable *Module::addVar(Variable *V) { + V->setName(uniqueName(V->getName(), uniqueVariableNames_)); + vars_.push_back(V); + return V; +} + +Placeholder *Module::addPlaceholder(Placeholder *ph) { + 
ph->setName(uniqueName(ph->getName(), uniqueVariableNames_)); + placeholders_.push_back(ph); + return ph; +} + ConvolutionNode *Function::createConv(llvm::StringRef name, NodeValue input, size_t depth, llvm::ArrayRef kernels, diff --git a/lib/Graph/Nodes.cpp b/lib/Graph/Nodes.cpp index 0c98681a43..25b3618058 100644 --- a/lib/Graph/Nodes.cpp +++ b/lib/Graph/Nodes.cpp @@ -21,20 +21,23 @@ using namespace glow; -/// Equality predicate for variables. -bool Variable::isEqual(const Variable &other) const { - /// A variable should be equal only to itself! +bool Storage::isEqual(const Storage &other) const { + /// A storage should be equal only to itself! return this == &other; } llvm::hash_code Variable::getHash() const { return llvm::hash_combine(getName(), isTraining(), getType()); } + +llvm::hash_code Placeholder::getHash() const { + return llvm::hash_combine(getName()); +} //===----------------------------------------------------------------------===// // Visitor methods //===----------------------------------------------------------------------===// -void Variable::visit(Node *parent, NodeWalker *visitor) { +void Storage::visit(Node *parent, NodeWalker *visitor) { if (!visitor->shouldVisit(parent, this)) { return; } @@ -42,7 +45,7 @@ void Variable::visit(Node *parent, NodeWalker *visitor) { visitor->post(parent, this); } -void Variable::visit(const Node *parent, NodeWalker *visitor) const { +void Storage::visit(const Node *parent, NodeWalker *visitor) const { if (!visitor->shouldVisit(parent, this)) { return; } @@ -53,28 +56,26 @@ void Variable::visit(const Node *parent, NodeWalker *visitor) const { //===----------------------------------------------------------------------===// // Edge getters methods //===----------------------------------------------------------------------===// -unsigned Variable::getNumInputs() const { return 0; } +unsigned Storage::getNumInputs() const { return 0; } -std::string Variable::getInputName(unsigned idx) const { +std::string 
Storage::getInputName(unsigned idx) const { llvm_unreachable("Invalid index"); } -NodeValue Variable::getNthInput(unsigned idx) { +NodeValue Storage::getNthInput(unsigned idx) { llvm_unreachable("Invalid index"); } -llvm::StringRef Variable::getOutputName(unsigned idx) const { +llvm::StringRef Storage::getOutputName(unsigned idx) const { if (idx == 0) { return "Output"; } llvm_unreachable("Invalid index"); } -bool Variable::hasSideEffects() const { return false; } +bool Storage::hasSideEffects() const { return false; } -Node *Variable::clone() const { - llvm_unreachable("variables can't be cloned."); -} +Node *Storage::clone() const { llvm_unreachable("variables can't be cloned."); } //===----------------------------------------------------------------------===// // Debug description methods @@ -95,6 +96,14 @@ std::string Variable::getDebugDesc() const { return db; } +std::string Placeholder::getDebugDesc() const { + DescriptionBuilder db(getKindName()); + db.addParam("name", quote(getName())) + .addParam("output", *getType()) + .addParam("users", getNumUsers()); + return db; +} + //===----------------------------------------------------------------------===// // Nodes verification //===----------------------------------------------------------------------===// diff --git a/tests/unittests/graphTest.cpp b/tests/unittests/graphTest.cpp index 31f898df52..b8230e9a90 100644 --- a/tests/unittests/graphTest.cpp +++ b/tests/unittests/graphTest.cpp @@ -886,3 +886,16 @@ TEST(Graph, PostOrderTest) { EXPECT_EQ(order[12], ret2->getOutput()); EXPECT_EQ(order[13], ret2); } + +TEST(Graph, placeholder) { + Module MD; + Function *F = MD.createFunction("F"); + IRFunction M(F); + Node *K = MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input"); + Node *S = MD.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select"); + + K = F->createFullyConnected("FC", K, 10); + K = F->createRELU("Relu", K); + K = F->createSoftMax("SoftMax", K, S); + F->createSave("Save", K); +} diff 
--git a/tools/ClassGen/NodeGen.cpp b/tools/ClassGen/NodeGen.cpp index 46cb80f0de..45f11ecd42 100644 --- a/tools/ClassGen/NodeGen.cpp +++ b/tools/ClassGen/NodeGen.cpp @@ -37,7 +37,9 @@ int main(int argc, char **argv) { // Input/Output nodes //===--------------------------------------------------------------------===// + BB.declareNode("Storage"); BB.declareNode("Variable"); + BB.declareNode("Placeholder"); BB.newNode("Save") .addInput("Input")