@@ -228,36 +228,36 @@ Function *glow::differentiate(Function *F, const TrainingConfig &conf,
228
228
} // End of the for-each instr loop.
229
229
230
230
for (auto N : nodes) {
231
- // Iterate only through Variables/ Placeholders used by the Function.
232
- // These are inserted during the post-order walk.
233
- Storage *V = llvm::dyn_cast<Storage >(N);
234
- if (!V )
231
+ // Iterate only through Placeholders used by the Function. These are
232
+ // inserted during the post-order walk.
233
+ Placeholder *PH = llvm::dyn_cast<Placeholder >(N);
234
+ if (!PH )
235
235
continue ;
236
236
237
237
// In this special differentiation mode we record the last gradient value
238
238
// without performing the SGD update. This mode is used by the unit tests.
239
239
if (varGrads) {
240
- if (map.hasGradient (V )) {
241
- std::string nodeName = " _grad_" + V ->getName ().str ();
240
+ if (map.hasGradient (PH )) {
241
+ std::string nodeName = " _grad_" + PH ->getName ().str ();
242
242
// Save the gradient and return the destination variable.
243
- auto *saveNode = G->createSavePH (nodeName, map.getGradient (V ));
244
- auto *GradV = llvm::dyn_cast<Storage>( saveNode->getPlaceholder () );
245
- varGrads->push_back ({V , GradV});
243
+ auto *saveNode = G->createSavePH (nodeName, map.getGradient (PH ));
244
+ Placeholder *GradV = saveNode->getPlaceholder ();
245
+ varGrads->push_back ({PH , GradV});
246
246
}
247
247
continue ;
248
248
}
249
249
250
250
// Don't update nodes that are not marked as trainable.
251
- if (!V ->isTraining ()) {
251
+ if (!PH ->isTraining ()) {
252
252
continue ;
253
253
}
254
254
255
- auto X = new SGDNode (V ->getName (), map.getGradient (V ), V , conf.L1Decay ,
255
+ auto X = new SGDNode (PH ->getName (), map.getGradient (PH ), PH , conf.L1Decay ,
256
256
conf.L2Decay , conf.learningRate , conf.momentum ,
257
257
conf.batchSize );
258
258
toAppend.push_back (X);
259
259
// Now update the weight with the value computed by SGD.
260
- auto *save = new SaveNode (V ->getName ().str () + " .saveGrad" , {X, 0 }, V );
260
+ auto *save = new SaveNode (PH ->getName ().str () + " .saveGrad" , {X, 0 }, PH );
261
261
toAppend.push_back (save);
262
262
}
263
263
0 commit comments