[Placeholder] Allow the differentiation of Placeholder nodes. #1612

Merged: 3 commits, Sep 11, 2018

7 changes: 4 additions & 3 deletions include/glow/Graph/Graph.h
@@ -127,9 +127,10 @@ class Module final {
///@{

Placeholder *createPlaceholder(ElemKind T, llvm::ArrayRef<size_t> dims,
llvm::StringRef name);
llvm::StringRef name, bool isTrainable);

Placeholder *createPlaceholder(TypeRef T, llvm::StringRef name);
Placeholder *createPlaceholder(TypeRef T, llvm::StringRef name,
bool isTrainable);

Variable *createVariable(TypeRef T, llvm::StringRef name,
VisibilityKind visibility = VisibilityKind::Private,
@@ -626,7 +627,7 @@ class Function final : public Named {

struct TrainingConfig;

using VariableGradientsList = std::list<std::pair<Variable *, Variable *>>;
using VariableGradientsList = std::list<std::pair<Storage *, Storage *>>;

/// Create a new Function that 'trains' the input Function. We differentiate the
/// nodes and insert code to update the weights based on the \p config
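
For review context, a minimal sketch (not part of the diff) of what call sites look like once the trainability flag is explicit; the module, names, and shapes below are arbitrary:

  Module mod;

  // Trainable placeholder, e.g. a weight that glow::differentiate() may update.
  Placeholder *W = mod.createPlaceholder(ElemKind::FloatTy, {10, 10}, "weights",
                                         /*isTrainable=*/true);

  // Non-trainable placeholder, e.g. an input that only feeds data in.
  Placeholder *in = mod.createPlaceholder(ElemKind::FloatTy, {1, 10}, "input",
                                          /*isTrainable=*/false);
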
23 changes: 13 additions & 10 deletions include/glow/Graph/Nodes.h
@@ -31,8 +31,12 @@ namespace glow {
// Storage is the base class for Variables, which are bound to tensors, and
// Placeholder nodes which are unbound.
class Storage : public Node {
/// Specifies if the variable or placeholder is trainable.
bool isTrainable_;

public:
Storage(Kinded::Kind k, llvm::StringRef name) : Node(k, name) {}
Storage(Kinded::Kind k, llvm::StringRef name, bool isTrainable)
: Node(k, name), isTrainable_(isTrainable) {}

/// \return the single output value of the node.
NodeValue getOutput() { return getNthResult(0); }
@@ -50,6 +54,10 @@ class Storage : public Node {
Node *clone() const;
/// @}

/// \returns True if the Variable or Placeholder is trainable during
/// differentiation.
bool isTraining() const { return isTrainable_; }

/// \returns result type of the variable.
TypeRef getType() const { return Node::getType(0); }

@@ -66,8 +74,6 @@
};

class Variable : public Storage {
/// Specifies if the variable is trainable.
bool isTrainable_;
/// Specifies the visibility of the variable.
VisibilityKind visibility_;
/// The tensor payload that the variable holds.
@@ -77,21 +83,18 @@ class Variable : public Storage {
/// Create a new variable and initialize its payload.
Variable(llvm::StringRef name, TypeRef Ty, VisibilityKind visibility,
bool isTrainable)
: Storage(Kinded::Kind::VariableKind, name), isTrainable_(isTrainable),
: Storage(Kinded::Kind::VariableKind, name, isTrainable),
visibility_(visibility) {
addResult(Ty);
payload_.reset(*Ty);
}

Variable(llvm::StringRef name, VisibilityKind visibility, Tensor &&payload)
: Storage(Kinded::Kind::VariableKind, name), isTrainable_(false),
: Storage(Kinded::Kind::VariableKind, name, false),
visibility_(visibility), payload_(std::move(payload)) {
addResult(&payload_.getType());
}

/// \returns True if the Variable is initialized to be in training mode.
bool isTraining() const { return isTrainable_; }

/// \returns True if the Variable is private.
bool isPrivate() const { return visibility_ == VisibilityKind::Private; }

@@ -123,8 +126,8 @@ class Variable : public Storage {
class Placeholder : public Storage {
public:
/// Create a new placeholder variable.
Placeholder(llvm::StringRef name, TypeRef Ty)
: Storage(Kinded::Kind::PlaceholderKind, name) {
Placeholder(llvm::StringRef name, TypeRef Ty, bool isTrainable)
: Storage(Kinded::Kind::PlaceholderKind, name, isTrainable) {
addResult(Ty);
}

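
A hypothetical helper (not in this PR) showing what hoisting the flag into Storage buys: passes can query trainability through the common base class instead of special-casing Variable:

  // Sketch only: Variables and Placeholders are handled uniformly when
  // deciding which storage nodes should receive weight updates.
  static bool needsGradient(const Node *N) {
    if (const auto *S = llvm::dyn_cast<Storage>(N)) {
      return S->isTraining();
    }
    return false; // Non-storage nodes are never updated directly.
  }
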
8 changes: 4 additions & 4 deletions lib/Graph/Grad.cpp
@@ -83,7 +83,7 @@ Function *glow::differentiate(Function *F, const TrainingConfig &conf,

for (auto it = nodes.rbegin(), e = nodes.rend(); it != e; it++) {
Node *N = *it;
if (isa<Variable>(N)) {
if (isa<Storage>(N)) {
continue;
}

@@ -228,9 +228,9 @@ Function *glow::differentiate(Function *F, const TrainingConfig &conf,
} // End of the for-each instr loop.

for (auto N : nodes) {
// Iterate only through Variables used by the Function.
// Iterate only through Variables/Placeholders used by the Function.
// These are inserted during the post-order walk.
Variable *V = llvm::dyn_cast<Variable>(N);
Storage *V = llvm::dyn_cast<Storage>(N);
if (!V)
continue;

@@ -241,7 +241,7 @@ Function *glow::differentiate(Function *F, const TrainingConfig &conf,
std::string nodeName = "_grad_" + V->getName().str();
// Save the gradient and return the destination variable.
auto *saveNode = G->createSave(nodeName, map.getGradient(V));
auto *GradV = llvm::dyn_cast<Variable>(saveNode->getOutput().getNode());
auto *GradV = llvm::dyn_cast<Storage>(saveNode->getOutput().getNode());
varGrads->push_back({V, GradV});
}
continue;
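
A rough sketch of the consumer side, assuming differentiate() keeps its optional gradient-list parameter (as the gradCheckTest.cpp change below suggests); each entry now pairs a trainable Variable or Placeholder with the Storage node that captures its gradient:

  // Hypothetical usage; F and TC are a Function and TrainingConfig as in the
  // tests in this PR.
  VariableGradientsList varGrads;
  Function *TF = glow::differentiate(F, TC, "grad", &varGrads);
  for (auto &p : varGrads) {
    Storage *weight = p.first; // Trainable Variable or Placeholder.
    Storage *grad = p.second;  // Storage node holding its gradient.
    llvm::outs() << weight->getName() << " -> " << grad->getName() << "\n";
  }
  (void)TF;
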
11 changes: 6 additions & 5 deletions lib/Graph/Graph.cpp
@@ -332,15 +332,16 @@ static ShapeVector getNewShapeWithoutAxis(llvm::ArrayRef<size_t> dims,
// Node builders
//===----------------------------------------------------------------------===//

Placeholder *Module::createPlaceholder(TypeRef T, llvm::StringRef name) {
Placeholder *Module::createPlaceholder(TypeRef T, llvm::StringRef name,
bool isTrainable) {
auto FT = uniqueType(*T);
return addPlaceholder(new Placeholder(name, FT));
return addPlaceholder(new Placeholder(name, FT, isTrainable));
}

Placeholder *Module::createPlaceholder(ElemKind T, llvm::ArrayRef<size_t> dims,
llvm::StringRef name) {
llvm::StringRef name, bool isTrainable) {
auto FT = uniqueType(T, dims);
return createPlaceholder(FT, name);
return createPlaceholder(FT, name, isTrainable);
}

Variable *Module::createVariable(TypeRef T, llvm::StringRef name,
@@ -2068,7 +2069,7 @@ Function *Function::clone(llvm::StringRef newName,

auto it = currToNew.find(input.getNode());
if (it == currToNew.end()) {
assert(isa<Variable>(input.getNode()) &&
assert(isa<Storage>(input.getNode()) &&
"Could not find a mapping for some node!");
continue;
}
3 changes: 2 additions & 1 deletion lib/Graph/Nodes.cpp
@@ -100,7 +100,8 @@ std::string Placeholder::getDebugDesc() const {
DescriptionBuilder db(getKindName());
db.addParam("name", quote(getName()))
.addParam("output", *getType())
.addParam("users", getNumUsers());
.addParam("users", getNumUsers())
.addParam("trainable", isTraining());
return db;
}

11 changes: 7 additions & 4 deletions lib/IR/GraphScheduler.cpp
@@ -106,7 +106,7 @@ class ChildMemSizeBasedScheduler : public Scheduler {
const auto &input = N->getNthInput(idx);
// Skip operands that do not require memory allocations for storing
// their results.
if (isa<Variable>(input))
if (isa<Storage>(input))
continue;
assert(resultMemSize_.count(input) > 0);
assert(maxMemSize_.count(input) > 0);
@@ -128,7 +128,7 @@
if (isScheduled(N))
return;
// Do not explicitly schedule variables.
if (isa<Variable>(N))
if (isa<Storage>(N))
return;
// A set of node's sorted children.
llvm::SmallVector<Node *, 8> orderedChildren;
@@ -144,8 +144,8 @@
// We don't model memory dependencies, but we still need to honor them.
// Make sure the SaveNode happens after the last use of the output variable.
if (auto *save = dyn_cast<SaveNode>(N)) {
Variable *output = save->getVariable();

for (NodeUse &use : output->getUsers()) {
auto *destination = save->getOutput().getNode();
for (NodeUse &use : destination->getUsers()) {
Node *user = use.getUser();
if (user == save) {
continue;
@@ -220,6 +220,9 @@ void IRFunction::scheduleGraph(NodesPtrList &Schedule) {
for (auto &N : G_->getParent()->getVars()) {
Schedule.push_back(N);
}
for (auto &N : G_->getParent()->getPlaceholders()) {
Schedule.push_back(N);
}
ChildMemSizeBasedScheduler CMSBScheduler(*G_, Schedule);
CMSBScheduler.schedule();
auto numVars = G_->getParent()->getVars().size();
8 changes: 4 additions & 4 deletions tests/unittests/BackendTest.cpp
@@ -188,7 +188,7 @@ TEST_P(BackendTest, simplePlaceholderValue) {
Tensor data{99.0, 35.0, 2.0, 3.0};
auto &mod = EE_.getModule();
Function *F = mod.createFunction("main");
auto *input = mod.createPlaceholder(ElemKind::FloatTy, {4}, "input");
auto *input = mod.createPlaceholder(ElemKind::FloatTy, {4}, "input", false);
SaveNode *S = F->createSave("ret", input);
Context ctx({input}, {&data});

@@ -207,9 +207,9 @@ TEST(Context, basicContextTest) {

// Create a simple graph, just to have a few placeholders.
Function *F = mod.createFunction("main");
auto *input1 = mod.createPlaceholder(ty, "input1");
auto *input2 = mod.createPlaceholder(ty, "input2");
auto *input3 = mod.createPlaceholder(ty, "input3");
auto *input1 = mod.createPlaceholder(ty, "input1", false);
auto *input2 = mod.createPlaceholder(ty, "input2", false);
auto *input3 = mod.createPlaceholder(ty, "input3", false);
auto *add = F->createAdd("add", input1, input2);
F->createSave("ret", add);

36 changes: 36 additions & 0 deletions tests/unittests/MLTest.cpp
@@ -36,6 +36,42 @@ class TestRunnerBase : public ::testing::TestWithParam<BackendKind> {
class InterpreterAndCPU : public TestRunnerBase {};
class MLTest : public TestRunnerBase {};

/// Use placeholders (and not variables) to learn the square root of two.
TEST_P(MLTest, learnSqrt2Placeholder) {
TrainingConfig TC;
Context ctx;

TC.learningRate = 0.03;

auto &mod = EE_.getModule();
Function *F = mod.createFunction("Square root of 2");

auto *A = mod.createPlaceholder(ElemKind::FloatTy, {1}, "A", true);
auto *inputTensor = ctx.allocate(A);
inputTensor->init(Tensor::InitKind::Broadcast, 1, mod.getPRNG());

auto *E = mod.createPlaceholder(ElemKind::FloatTy, {1}, "Ex", false);
ctx.allocate(E)->getHandle() = {2};

auto *O = mod.createPlaceholder(ElemKind::FloatTy, {1}, "output", false);
ctx.allocate(O);

Node *M = F->createMul("Mult", A, A);
M = F->createRegression("reg", M, E);
F->createSave("ret", M);

Function *TF = glow::differentiate(F, TC);
EE_.compile(CompilationMode::Train, TF, ctx);

// Train the network:
for (int i = 0; i < 100; i++) {
EE_.run();
}

float res = inputTensor->getHandle().at({0});
EXPECT_NEAR(res, 1.4142, 0.01);
}

TEST_P(MLTest, trainASimpleNetwork) {
TrainingConfig TC;
Context ctx;
3 changes: 2 additions & 1 deletion tests/unittests/gradCheckTest.cpp
@@ -25,6 +25,7 @@
#include "gtest/gtest.h"

using namespace glow;
using llvm::cast;

class GradCheckBase : public ::testing::TestWithParam<BackendKind> {
public:
@@ -57,7 +58,7 @@ float gradDiff(float G1, float G2) {
Variable *getGrad(const VariableGradientsList &grads, Variable *V) {
for (auto &p : grads) {
if (p.first == V) {
return p.second;
return cast<Variable>(p.second);
}
}
return nullptr;
29 changes: 29 additions & 0 deletions tests/unittests/graphGradTest.cpp
@@ -157,3 +157,32 @@ TEST(GraphAutoGrad, cloneAndDiff) {
EXPECT_EQ(nbSGDA, 1);
EXPECT_EQ(nbSGDB, 1);
}

/// Check that we can differentiate Functions that use Placeholder nodes.
TEST(GraphAutoGrad, checkPlaceholderGradTest) {
ExecutionEngine EE;
TrainingConfig TC;
Context ctx;

// Construct the network:
TC.learningRate = 0.001;

auto &mod = EE.getModule();
Function *F = mod.createFunction("main");

Placeholder *A =
mod.createPlaceholder(ElemKind::FloatTy, {10, 28, 28, 1}, "input", true);
auto *RL = F->createRELU("relu", A);
F->createSave("return", RL);

// Expect a single user of the trainable input placeholder.
EXPECT_EQ(A->getNumUsers(), 1);

Function *TF = glow::differentiate(F, TC);
EE.compile(CompilationMode::Train, TF, ctx);
EE.compile(CompilationMode::Infer, F, ctx);

// Check that the Placeholder has multiple users, because at least one write
// node will be added.
EXPECT_GE(A->getNumUsers(), 1);

}
5 changes: 3 additions & 2 deletions tests/unittests/graphTest.cpp
@@ -903,8 +903,9 @@ TEST(Graph, placeholder) {
Module MD;
Function *F = MD.createFunction("F");
IRFunction M(F);
Node *K = MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input");
Node *S = MD.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select");
Node *K =
MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", false);
Node *S = MD.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select", false);

K = F->createFullyConnected("FC", K, 10);
K = F->createRELU("Relu", K);