Skip to content

Commit 8bf7f1b

Browse files
committed
Convert Variable::build and Variable::buildSubGraph to static functions
1 parent 49b8917 commit 8bf7f1b

File tree

2 files changed

+28
-11
lines changed

2 files changed

+28
-11
lines changed

include/af/autograd/Variable.hpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ namespace af {
5454

5555
Variable grad() const;
5656

57+
std::ptrdiff_t id() const;
58+
5759
bool isCalcGrad() const;
5860

5961
void setCalcGrad(bool calc_grad);
@@ -64,11 +66,16 @@ namespace af {
6466

6567
void backward(const Variable &grad, bool retain_grad_graph = false);
6668

67-
void buildSubGraph(Cache_t &cache, DAG_t &dag);
69+
6870
private:
6971
void evalGrad(bool retain_grad_graph = false);
7072

71-
DAG_t build();
73+
std::vector<Variable> getInputs() const;
74+
75+
static void buildSubGraph(Cache_t &cache, DAG_t &dag, const Variable &var);
76+
77+
static DAG_t build(const Variable &var);
78+
7279
std::shared_ptr<Shared> m_shared;
7380
};
7481
}

src/autograd/Variable.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,16 @@ namespace af {
8181
return m_shared->m_grads[0];
8282
}
8383

84+
std::ptrdiff_t Variable::id() const
85+
{
86+
return (std::ptrdiff_t)m_shared.get();
87+
}
88+
89+
std::vector<Variable> Variable::getInputs() const
90+
{
91+
return m_shared->m_inputs;
92+
}
93+
8494
bool Variable::isCalcGrad() const
8595
{
8696
return m_shared->m_calc_grad;
@@ -140,31 +150,31 @@ namespace af {
140150
void Variable::backward(const Variable &grad, bool retain_grad_graph)
141151
{
142152
this->addGrad(grad);
143-
Variable::DAG_t dag = this->build();
153+
Variable::DAG_t dag = Variable::build(*this);
144154
for (auto iter = dag.rbegin(); iter != dag.rend(); iter++) {
145155
iter->calcGradInputs(retain_grad_graph);
146156
}
147157
}
148158

149-
Variable::DAG_t Variable::build()
159+
Variable::DAG_t Variable::build(const Variable &var)
150160
{
151161
Cache_t cache;
152-
Variable::DAG_t dag;
153-
this->buildSubGraph(cache, dag);
162+
Variable::DAG_t dag;
163+
Variable::buildSubGraph(cache, dag, var);
154164
return dag;
155165
}
156166

157-
void Variable::buildSubGraph(Cache_t &cache, Variable::DAG_t &dag)
167+
void Variable::buildSubGraph(Cache_t &cache, Variable::DAG_t &dag, const Variable &var)
158168
{
159-
std::ptrdiff_t id = (std::ptrdiff_t)m_shared.get();
169+
std::ptrdiff_t id = var.id();
160170
if (cache.find(id) != cache.end()) {
161171
return;
162172
}
163-
for (auto input : m_shared->m_inputs) {
164-
input.buildSubGraph(cache, dag);
173+
for (auto input : var.getInputs()) {
174+
Variable::buildSubGraph(cache, dag, input);
165175
}
166176
cache[id] = true;
167-
dag.push_back(*this);
177+
dag.push_back(var);
168178
}
169179
}
170180
}

0 commit comments

Comments
 (0)