
Commit 92a85e3

[Placeholder] Remove the ability to differentiate variables.
This commit removes the ability to differentiate variables. From this point on, we can only differentiate Placeholders. It looks like we've reached this point because all of the tests are passing.
1 parent 2813dd7 commit 92a85e3

3 files changed, 16 insertions(+), 15 deletions(-)


include/glow/Graph/Graph.h

Lines changed: 2 additions & 1 deletion
@@ -677,7 +677,8 @@ class Function final : public Named {
 
 struct TrainingConfig;
 
-using VariableGradientsList = std::list<std::pair<Storage *, Storage *>>;
+using VariableGradientsList =
+    std::list<std::pair<Placeholder *, Placeholder *>>;
 
 /// Create a new Function that 'trains' the input Function. We differentiate the
 /// nodes and insert code to update the weights based on the \p config
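With this change the gradient list pairs Placeholders directly. A minimal caller-side sketch of how the new alias might be consumed is below; only part of the differentiate signature is visible in this hunk, so the trailing name/varGrads arguments and the variables F and conf are assumptions for illustration.

// Hypothetical sketch; assumes differentiate() accepts an optional
// VariableGradientsList pointer, as suggested by its use in Grad.cpp below.
TrainingConfig conf;             // learning rate, momentum, batch size, ...
VariableGradientsList varGrads;  // pairs of (weight, recorded gradient)
Function *TF = glow::differentiate(F, conf, "train", &varGrads);
(void)TF;

for (const auto &p : varGrads) {
  Placeholder *weight = p.first;  // the trainable placeholder
  Placeholder *grad = p.second;   // placeholder holding its recorded gradient
  (void)weight;
  (void)grad;
}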

lib/Graph/Grad.cpp

Lines changed: 12 additions & 12 deletions
@@ -228,36 +228,36 @@ Function *glow::differentiate(Function *F, const TrainingConfig &conf,
   } // End of the for-each instr loop.
 
   for (auto N : nodes) {
-    // Iterate only through Variables/Placeholders used by the Function.
-    // These are inserted during the post-order walk.
-    Storage *V = llvm::dyn_cast<Storage>(N);
-    if (!V)
+    // Iterate only through Placeholders used by the Function. These are
+    // inserted during the post-order walk.
+    Placeholder *PH = llvm::dyn_cast<Placeholder>(N);
+    if (!PH)
       continue;
 
     // In this special differentiation mode we record the last gradient value
     // without performing the SGD update. This mode is used by the unit tests.
     if (varGrads) {
-      if (map.hasGradient(V)) {
-        std::string nodeName = "_grad_" + V->getName().str();
+      if (map.hasGradient(PH)) {
+        std::string nodeName = "_grad_" + PH->getName().str();
         // Save the gradient and return the destination variable.
-        auto *saveNode = G->createSavePH(nodeName, map.getGradient(V));
-        auto *GradV = llvm::dyn_cast<Storage>(saveNode->getPlaceholder());
-        varGrads->push_back({V, GradV});
+        auto *saveNode = G->createSavePH(nodeName, map.getGradient(PH));
+        Placeholder *GradV = saveNode->getPlaceholder();
+        varGrads->push_back({PH, GradV});
       }
       continue;
     }
 
     // Don't update nodes that are not marked as trainable.
-    if (!V->isTraining()) {
+    if (!PH->isTraining()) {
       continue;
     }
 
-    auto X = new SGDNode(V->getName(), map.getGradient(V), V, conf.L1Decay,
+    auto X = new SGDNode(PH->getName(), map.getGradient(PH), PH, conf.L1Decay,
                          conf.L2Decay, conf.learningRate, conf.momentum,
                          conf.batchSize);
     toAppend.push_back(X);
     // Now update the weight with the value computed by SGD.
-    auto *save = new SaveNode(V->getName().str() + ".saveGrad", {X, 0}, V);
+    auto *save = new SaveNode(PH->getName().str() + ".saveGrad", {X, 0}, PH);
     toAppend.push_back(save);
   }
 
tests/unittests/gradCheckTest.cpp

Lines changed: 2 additions & 2 deletions
@@ -58,15 +58,15 @@ float gradDiff(float G1, float G2) {
 Placeholder *getGrad(const VariableGradientsList &grads, Placeholder *V) {
   for (auto &p : grads) {
     if (p.first == V) {
-      return cast<Placeholder>(p.second);
+      return p.second;
     }
   }
   return nullptr;
 }
 
 void allocateGrads(Context &ctx, const VariableGradientsList &grads) {
   for (auto &p : grads) {
-    auto grad = cast<Placeholder>(p.second);
+    auto grad = p.second;
     ctx.allocate(grad);
   }
 }
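For context, a rough sketch of how a gradient-check test might combine these helpers after the change follows; the placeholder A, the variables F and conf, and the Context::get accessor are assumptions for illustration, not part of this diff.

// Hedged sketch: record gradients without applying SGD, allocate their
// backing tensors, then look up the gradient of one placeholder.
Context ctx;
VariableGradientsList varGrads;
Function *TF = glow::differentiate(F, conf, "record_grads", &varGrads);
(void)TF;

allocateGrads(ctx, varGrads);              // one tensor per gradient placeholder
if (Placeholder *gradOfA = getGrad(varGrads, A)) {
  Tensor *gradTensor = ctx.get(gradOfA);   // assumed Context accessor
  // ... compare gradTensor against a numerically estimated gradient ...
  (void)gradTensor;
}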
