@@ -81,6 +81,16 @@ namespace af {
81
81
return m_shared->m_grads [0 ];
82
82
}
83
83
84
+ std::ptrdiff_t Variable::id () const
85
+ {
86
+ return (std::ptrdiff_t )m_shared.get ();
87
+ }
88
+
89
+ std::vector<Variable> Variable::getInputs () const
90
+ {
91
+ return m_shared->m_inputs ;
92
+ }
93
+
84
94
bool Variable::isCalcGrad () const
85
95
{
86
96
return m_shared->m_calc_grad ;
@@ -140,31 +150,31 @@ namespace af {
140
150
void Variable::backward (const Variable &grad, bool retain_grad_graph)
141
151
{
142
152
this ->addGrad (grad);
143
- Variable::DAG_t dag = this -> build ();
153
+ Variable::DAG_t dag = Variable:: build (* this );
144
154
for (auto iter = dag.rbegin (); iter != dag.rend (); iter++) {
145
155
iter->calcGradInputs (retain_grad_graph);
146
156
}
147
157
}
148
158
149
- Variable::DAG_t Variable::build ()
159
+ Variable::DAG_t Variable::build (const Variable &var )
150
160
{
151
161
Cache_t cache;
152
- Variable::DAG_t dag;
153
- this -> buildSubGraph (cache, dag);
162
+ Variable::DAG_t dag;
163
+ Variable:: buildSubGraph (cache, dag, var );
154
164
return dag;
155
165
}
156
166
157
- void Variable::buildSubGraph (Cache_t &cache, Variable::DAG_t &dag)
167
+ void Variable::buildSubGraph (Cache_t &cache, Variable::DAG_t &dag, const Variable &var )
158
168
{
159
- std::ptrdiff_t id = (std:: ptrdiff_t )m_shared. get ();
169
+ std::ptrdiff_t id = var. id ();
160
170
if (cache.find (id) != cache.end ()) {
161
171
return ;
162
172
}
163
- for (auto input : m_shared-> m_inputs ) {
164
- input. buildSubGraph (cache, dag);
173
+ for (auto input : var. getInputs () ) {
174
+ Variable:: buildSubGraph (cache, dag, input );
165
175
}
166
176
cache[id] = true ;
167
- dag.push_back (* this );
177
+ dag.push_back (var );
168
178
}
169
179
}
170
180
}
0 commit comments