diff --git a/include/glow/Graph/Graph.h b/include/glow/Graph/Graph.h
index 771a1ef4f6..07320fd468 100644
--- a/include/glow/Graph/Graph.h
+++ b/include/glow/Graph/Graph.h
@@ -674,7 +674,8 @@ class Function final : public Named {
 
 struct TrainingConfig;
 
-using VariableGradientsList = std::list<std::pair<Storage *, Storage *>>;
+using VariableGradientsList =
+    std::list<std::pair<Placeholder *, Placeholder *>>;
 
 /// Create a new Function that 'trains' the input Function. We differentiate the
 /// nodes and insert code to update the weights based on the \p config
diff --git a/lib/Graph/Grad.cpp b/lib/Graph/Grad.cpp
index 2ba20d6ef8..e1bcf61679 100644
--- a/lib/Graph/Grad.cpp
+++ b/lib/Graph/Grad.cpp
@@ -228,36 +228,36 @@ Function *glow::differentiate(Function *F, const TrainingConfig &conf,
   } // End of the for-each instr loop.
 
   for (auto N : nodes) {
-    // Iterate only through Variables/Placeholders used by the Function.
-    // These are inserted during the post-order walk.
-    Storage *V = llvm::dyn_cast<Storage>(N);
-    if (!V)
+    // Iterate only through Placeholders used by the Function. These are
+    // inserted during the post-order walk.
+    Placeholder *PH = llvm::dyn_cast<Placeholder>(N);
+    if (!PH)
       continue;
 
     // In this special differentiation mode we record the last gradient value
     // without performing the SGD update. This mode is used by the unit tests.
     if (varGrads) {
-      if (map.hasGradient(V)) {
-        std::string nodeName = "_grad_" + V->getName().str();
+      if (map.hasGradient(PH)) {
+        std::string nodeName = "_grad_" + PH->getName().str();
         // Save the gradient and return the destination variable.
-        auto *saveNode = G->createSavePH(nodeName, map.getGradient(V));
-        auto *GradV = llvm::dyn_cast<Storage>(saveNode->getPlaceholder());
-        varGrads->push_back({V, GradV});
+        auto *saveNode = G->createSavePH(nodeName, map.getGradient(PH));
+        Placeholder *GradV = saveNode->getPlaceholder();
+        varGrads->push_back({PH, GradV});
       }
       continue;
     }
 
     // Don't update nodes that are not marked as trainable.
-    if (!V->isTraining()) {
+    if (!PH->isTraining()) {
       continue;
     }
 
-    auto X = new SGDNode(V->getName(), map.getGradient(V), V, conf.L1Decay,
+    auto X = new SGDNode(PH->getName(), map.getGradient(PH), PH, conf.L1Decay,
                          conf.L2Decay, conf.learningRate, conf.momentum,
                          conf.batchSize);
     toAppend.push_back(X);
     // Now update the weight with the value computed by SGD.
-    auto *save = new SaveNode(V->getName().str() + ".saveGrad", {X, 0}, V);
+    auto *save = new SaveNode(PH->getName().str() + ".saveGrad", {X, 0}, PH);
     toAppend.push_back(save);
   }
 
diff --git a/tests/unittests/gradCheckTest.cpp b/tests/unittests/gradCheckTest.cpp
index b31683a938..c17ee5341e 100644
--- a/tests/unittests/gradCheckTest.cpp
+++ b/tests/unittests/gradCheckTest.cpp
@@ -58,7 +58,7 @@ float gradDiff(float G1, float G2) {
 Placeholder *getGrad(const VariableGradientsList &grads, Placeholder *V) {
   for (auto &p : grads) {
     if (p.first == V) {
-      return cast<Placeholder>(p.second);
+      return p.second;
     }
   }
   return nullptr;
@@ -66,7 +66,7 @@ Placeholder *getGrad(const VariableGradientsList &grads, Placeholder *V) {
 
 void allocateGrads(Context &ctx, const VariableGradientsList &grads) {
   for (auto &p : grads) {
-    auto grad = cast<Placeholder>(p.second);
+    auto grad = p.second;
     ctx.allocate(grad);
   }
 }
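
For context, a minimal sketch of how the Placeholder-typed VariableGradientsList is consumed after this change, mirroring the getGrad/allocateGrads helpers in gradCheckTest.cpp. Only differentiate(), the pair element type, and Context::allocate() come from this patch; the include paths, the recordGradients wrapper, and the "record" function name are illustrative assumptions, not part of the diff.

// Sketch only: the include paths and the surrounding wrapper are assumptions,
// not taken from this patch.
#include "glow/Graph/Context.h"
#include "glow/Graph/Graph.h"

using namespace glow;

// Differentiate F in recording mode so the last gradient of every Placeholder
// is saved instead of being consumed by an SGD update, then allocate backing
// tensors for the gradient Placeholders in the Context.
void recordGradients(Function *F, const TrainingConfig &conf, Context &ctx) {
  VariableGradientsList varGrads;
  Function *recordFn = differentiate(F, conf, "record", &varGrads);
  (void)recordFn; // A test would compile and run this function.

  for (auto &p : varGrads) {
    // After this change both pair members are Placeholder *, so no
    // cast<Placeholder>() is required (see the gradCheckTest.cpp hunks above).
    Placeholder *grad = p.second;
    ctx.allocate(grad);
  }
}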