Skip to content

Placeholder: add Placeholder as a new kind of input/output to the graph. #1409

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions docs/IR.md
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,24 @@ usage.
8. Low-level IR optimizations are performed.

9. Backend-specific optimizations and code generation are performed.

### Placeholders

We are in the process of adding a new kind of variable: Placeholder. The
motivation and plan for Placeholder variables are described in the issue #1334.

The work on Placeholder variables is ongoing and the following tasks are still
open:

1. Teach the execution engine to bind tensors to the Placeholder nodes.

2. Verify that dotty printing, dump() and debugging work well.

3. Cleanup the APIs that are related to Variable and Placeholder and make them
consistent.

4. Change (some of) the unit tests to use the new Placeholder API.

5. Make sure that our optimizations are correct when placeholder are used.


23 changes: 17 additions & 6 deletions include/glow/Graph/Graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ using NodesList = llvm::iplist<glow::Node>;
using NodesPtrList = std::list<glow::Node *>;
/// List of Functions.
using FunctionList = std::list<Function *>;
/// List of Variables.
using VariablesList = std::list<Variable *>;
using PlaceholderList = std::list<Placeholder *>;
using UnsignedArrayRef = llvm::ArrayRef<size_t>;

class Module final {
Expand All @@ -53,6 +53,8 @@ class Module final {
llvm::StringSet<> uniqueVariableNames_{};
/// A list of variables that the Module owns.
VariablesList vars_;
/// A list of placeholder nodes that the Module owns.
PlaceholderList placeholders_;
/// Deterministic PRNG used to initialize weights in this module.
PseudoRNG PRNG_;

Expand All @@ -67,11 +69,10 @@ class Module final {
llvm::StringSet<> &stringTable);

/// Inserts the variable \p V to the list of variables.
Variable *addVar(Variable *V) {
V->setName(uniqueName(V->getName(), uniqueVariableNames_));
vars_.push_back(V);
return V;
}
Variable *addVar(Variable *V);

/// Inserts the placeholder node \p ph to the list of variables.
Placeholder *addPlaceholder(Placeholder *ph);

/// Return a pointer to a uniqued type \p T.
TypeRef uniqueType(const Type &T);
Expand Down Expand Up @@ -117,9 +118,19 @@ class Module final {

const VariablesList &getVars() const { return vars_; }

/// \returns the list of placeholders that the Module owns.
PlaceholderList &getPlaceholders() { return placeholders_; }

const PlaceholderList &getPlaceholders() const { return placeholders_; }

/// @name High-level Variable builders.
///@{

Placeholder *createPlaceholder(ElemKind T, llvm::ArrayRef<size_t> dims,
llvm::StringRef name);

Placeholder *createPlaceholder(TypeRef T, llvm::StringRef name);

Variable *createVariable(TypeRef T, llvm::StringRef name,
VisibilityKind visibility = VisibilityKind::Private,
bool isTrainable = true);
Expand Down
83 changes: 58 additions & 25 deletions include/glow/Graph/Nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,44 @@

namespace glow {

class Variable : public Node {
// Storage is the base class for Variables, which are bound to tensors, and
// Placeholder nodes which are unbound.
class Storage : public Node {
public:
Storage(Kinded::Kind k, llvm::StringRef name) : Node(k, name) {}

/// \return the single output value of the node.
NodeValue getOutput() { return getNthResult(0); }

/// Declare the standard Node methods.
/// @{
void visit(Node *parent, NodeWalker *visitor);
void visit(const Node *parent, NodeWalker *visitor) const;
bool isEqual(const Storage &other) const;
unsigned getNumInputs() const;
std::string getInputName(unsigned idx) const;
NodeValue getNthInput(unsigned idx);
llvm::StringRef getOutputName(unsigned idx) const;
bool hasSideEffects() const;
Node *clone() const;
/// @}

/// \returns result type of the variable.
TypeRef getType() const { return Node::getType(0); }

/// Methods that forward to the result type (that must be valid):
/// @{
ElemKind getElementType() const { return getType()->getElementType(); };
llvm::ArrayRef<size_t> dims() const { return getType()->dims(); };
/// @}

static bool classof(const Kinded *k) {
return k->getKind() == Kinded::Kind::VariableKind ||
k->getKind() == Kinded::Kind::PlaceholderKind;
}
};

class Variable : public Storage {
/// Specifies if the variable is trainable.
bool isTrainable_;
/// Specifies the visibility of the variable.
Expand All @@ -40,14 +77,14 @@ class Variable : public Node {
/// Create a new variable and initialize its payload.
Variable(llvm::StringRef name, TypeRef Ty, VisibilityKind visibility,
bool isTrainable)
: Node(Kinded::Kind::VariableKind, name), isTrainable_(isTrainable),
: Storage(Kinded::Kind::VariableKind, name), isTrainable_(isTrainable),
visibility_(visibility) {
addResult(Ty);
payload_.reset(*Ty);
}

Variable(llvm::StringRef name, VisibilityKind visibility, Tensor &&payload)
: Node(Kinded::Kind::VariableKind, name), isTrainable_(false),
: Storage(Kinded::Kind::VariableKind, name), isTrainable_(false),
visibility_(visibility), payload_(std::move(payload)) {
addResult(&payload_.getType());
}
Expand All @@ -62,15 +99,6 @@ class Variable : public Node {
return k->getKind() == Kinded::Kind::VariableKind;
}

/// \returns result type of the variable.
TypeRef getType() const { return Node::getType(0); }

/// Methods that forward to the result type (that must be valid):
/// @{
ElemKind getElementType() const { return getType()->getElementType(); };
llvm::ArrayRef<size_t> dims() const { return getType()->dims(); };
/// @}

/// \returns the visibility of the variable.
VisibilityKind getVisibilityKind() const { return visibility_; }

Expand All @@ -84,22 +112,27 @@ class Variable : public Node {

void assign(const Tensor *t) { payload_.assign(t); }

/// \returns the output NodeValue from the Variable. Variables only have a
/// single output.
NodeValue getOutput() { return getNthResult(0); }

unsigned getNumInputs() const;
std::string getInputName(unsigned idx) const;
NodeValue getNthInput(unsigned idx);
llvm::StringRef getOutputName(unsigned idx) const;
bool hasSideEffects() const;
std::string getDebugDesc() const;
Node *clone() const;

void visit(Node *parent, NodeWalker *visitor);
void visit(const Node *parent, NodeWalker *visitor) const;
llvm::hash_code getHash() const;
};

/// Placeholder nodes are unbound-storage. The content tensors are attached to
/// this node at runtime. Placeholders are used as inputs and output nodes to
/// the network.
class Placeholder : public Storage {
public:
/// Create a new placeholder variable.
Placeholder(llvm::StringRef name, TypeRef Ty)
: Storage(Kinded::Kind::PlaceholderKind, name) {
addResult(Ty);
}

bool isEqual(const Variable &other) const;
static bool classof(const Kinded *k) {
return k->getKind() == Kinded::Kind::PlaceholderKind;
}

std::string getDebugDesc() const;

llvm::hash_code getHash() const;
};
Expand Down
36 changes: 33 additions & 3 deletions lib/Graph/Graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,17 @@ Function *Module::createFunction(llvm::StringRef name) {
Module::~Module() {
eraseFunctions();

for (auto it = vars_.begin(), e = vars_.end(); it != e;) {
auto cur = it++;
eraseVariable(*cur);
for (auto it = vars_.begin(), e = vars_.end(); it != e; it++) {
Variable *v = *it;
delete v;
}
for (auto it = placeholders_.begin(), e = placeholders_.end(); it != e;
it++) {
Placeholder *p = *it;
delete p;
}
vars_.clear();
placeholders_.clear();
}

void Module::verify() const {
Expand Down Expand Up @@ -325,6 +332,17 @@ static ShapeVector getNewShapeWithoutAxis(llvm::ArrayRef<size_t> dims,
// Node builders
//===----------------------------------------------------------------------===//

Placeholder *Module::createPlaceholder(TypeRef T, llvm::StringRef name) {
auto FT = uniqueType(*T);
return addPlaceholder(new Placeholder(name, FT));
}

Placeholder *Module::createPlaceholder(ElemKind T, llvm::ArrayRef<size_t> dims,
llvm::StringRef name) {
auto FT = uniqueType(T, dims);
return createPlaceholder(FT, name);
}

Variable *Module::createVariable(TypeRef T, llvm::StringRef name,
VisibilityKind visibility, bool isTrainable) {
auto FT = uniqueType(*T);
Expand Down Expand Up @@ -389,6 +407,18 @@ llvm::StringRef Module::uniqueName(llvm::StringRef name,
llvm_unreachable("Unable to find a unique a name.");
}

Variable *Module::addVar(Variable *V) {
V->setName(uniqueName(V->getName(), uniqueVariableNames_));
vars_.push_back(V);
return V;
}

Placeholder *Module::addPlaceholder(Placeholder *ph) {
ph->setName(uniqueName(ph->getName(), uniqueVariableNames_));
placeholders_.push_back(ph);
return ph;
}

ConvolutionNode *Function::createConv(llvm::StringRef name, NodeValue input,
size_t depth,
llvm::ArrayRef<unsigned_t> kernels,
Expand Down
35 changes: 22 additions & 13 deletions lib/Graph/Nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,31 @@

using namespace glow;

/// Equality predicate for variables.
bool Variable::isEqual(const Variable &other) const {
/// A variable should be equal only to itself!
bool Storage::isEqual(const Storage &other) const {
/// A storage should be equal only to itself!
return this == &other;
}

llvm::hash_code Variable::getHash() const {
return llvm::hash_combine(getName(), isTraining(), getType());
}

llvm::hash_code Placeholder::getHash() const {
return llvm::hash_combine(getName());
}
//===----------------------------------------------------------------------===//
// Visitor methods
//===----------------------------------------------------------------------===//

void Variable::visit(Node *parent, NodeWalker *visitor) {
void Storage::visit(Node *parent, NodeWalker *visitor) {
if (!visitor->shouldVisit(parent, this)) {
return;
}
visitor->pre(parent, this);
visitor->post(parent, this);
}

void Variable::visit(const Node *parent, NodeWalker *visitor) const {
void Storage::visit(const Node *parent, NodeWalker *visitor) const {
if (!visitor->shouldVisit(parent, this)) {
return;
}
Expand All @@ -53,28 +56,26 @@ void Variable::visit(const Node *parent, NodeWalker *visitor) const {
//===----------------------------------------------------------------------===//
// Edge getters methods
//===----------------------------------------------------------------------===//
unsigned Variable::getNumInputs() const { return 0; }
unsigned Storage::getNumInputs() const { return 0; }

std::string Variable::getInputName(unsigned idx) const {
std::string Storage::getInputName(unsigned idx) const {
llvm_unreachable("Invalid index");
}

NodeValue Variable::getNthInput(unsigned idx) {
NodeValue Storage::getNthInput(unsigned idx) {
llvm_unreachable("Invalid index");
}

llvm::StringRef Variable::getOutputName(unsigned idx) const {
llvm::StringRef Storage::getOutputName(unsigned idx) const {
if (idx == 0) {
return "Output";
}
llvm_unreachable("Invalid index");
}

bool Variable::hasSideEffects() const { return false; }
bool Storage::hasSideEffects() const { return false; }

Node *Variable::clone() const {
llvm_unreachable("variables can't be cloned.");
}
Node *Storage::clone() const { llvm_unreachable("variables can't be cloned."); }

//===----------------------------------------------------------------------===//
// Debug description methods
Expand All @@ -95,6 +96,14 @@ std::string Variable::getDebugDesc() const {
return db;
}

std::string Placeholder::getDebugDesc() const {
DescriptionBuilder db(getKindName());
db.addParam("name", quote(getName()))
.addParam("output", *getType())
.addParam("users", getNumUsers());
return db;
}

//===----------------------------------------------------------------------===//
// Nodes verification
//===----------------------------------------------------------------------===//
Expand Down
13 changes: 13 additions & 0 deletions tests/unittests/graphTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -886,3 +886,16 @@ TEST(Graph, PostOrderTest) {
EXPECT_EQ(order[12], ret2->getOutput());
EXPECT_EQ(order[13], ret2);
}

TEST(Graph, placeholder) {
Module MD;
Function *F = MD.createFunction("F");
IRFunction M(F);
Node *K = MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input");
Node *S = MD.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select");

K = F->createFullyConnected("FC", K, 10);
K = F->createRELU("Relu", K);
K = F->createSoftMax("SoftMax", K, S);
F->createSave("Save", K);
}
2 changes: 2 additions & 0 deletions tools/ClassGen/NodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ int main(int argc, char **argv) {
// Input/Output nodes
//===--------------------------------------------------------------------===//

BB.declareNode("Storage");
BB.declareNode("Variable");
BB.declareNode("Placeholder");

BB.newNode("Save")
.addInput("Input")
Expand Down