Skip to content

Commit fe38648

Browse files
committed
Cleaning up internals of autograd::Variable
1 parent 3f832a0 commit fe38648

File tree

2 files changed

+21
-52
lines changed

2 files changed

+21
-52
lines changed

examples/FFNet.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
#include <af/nn.h>
1111

12-
using namespace af;
1312
using namespace af;
1413
using namespace af::nn;
1514

include/af/autograd/Variable.hpp

Lines changed: 21 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@ namespace af {
3333
typedef std::vector<Variable> DAG_t;
3434

3535
private:
36-
class Shared {
37-
public:
36+
struct Shared {
3837
Shared() :
3938
m_data(),
4039
m_inputs(),
@@ -56,48 +55,6 @@ namespace af {
5655
m_backward(backward)
5756
{}
5857

59-
af::array getData() const
60-
{
61-
return m_data;
62-
}
63-
64-
Variable getGrad() const
65-
{
66-
if (m_grads.size() == 0) {
67-
throw std::runtime_error("Gradient hasn't been calculated");
68-
}
69-
return m_grads[0];
70-
}
71-
72-
void addGrad(Variable grad)
73-
{
74-
m_grads.push_back(grad);
75-
}
76-
77-
std::vector<Variable> getInputs()
78-
{
79-
return m_inputs;
80-
}
81-
82-
void evalGrad()
83-
{
84-
if (m_grads.size() == 1) return;
85-
Variable grad = m_grads[0];
86-
for (int i = 1; i < (int)m_grads.size(); i++) {
87-
grad = grad + m_grads[i];
88-
}
89-
grad.getData().eval();
90-
m_grads.clear();
91-
m_grads.push_back(grad);
92-
}
93-
94-
void backward()
95-
{
96-
this->evalGrad();
97-
if (m_backward) m_backward(m_inputs, m_grads[0]);
98-
}
99-
100-
private:
10158
af::array m_data;
10259
std::vector<Variable> m_inputs;
10360
std::vector<Variable> m_grads;
@@ -123,32 +80,45 @@ namespace af {
12380

12481
af::array getData() const
12582
{
126-
return m_shared->getData();
83+
return m_shared->m_data;
12784
}
12885

12986
Variable getGrad() const
13087
{
131-
return m_shared->getGrad();
88+
if (m_shared->m_grads.size() == 0) {
89+
throw std::runtime_error("Gradient hasn't been calculated");
90+
}
91+
return m_shared->m_grads[0];
13292
}
13393

13494
void addGrad(Variable child_grad)
13595
{
136-
m_shared->addGrad(child_grad);
96+
m_shared->m_grads.push_back(child_grad);
13797
}
13898

13999
std::vector<Variable> getInputs() const
140100
{
141-
return m_shared->getInputs();
101+
return m_shared->m_inputs;
142102
}
143103

144104
void evalGrad()
145105
{
146-
m_shared->evalGrad();
106+
if (m_shared->m_grads.size() == 1) return;
107+
Variable grad = m_shared->m_grads[0];
108+
for (unsigned i = 1; i < m_shared->m_grads.size(); i++) {
109+
grad = grad + m_shared->m_grads[i];
110+
}
111+
grad.getData().eval();
112+
m_shared->m_grads.clear();
113+
m_shared->m_grads.push_back(grad);
147114
}
148115

149116
void backward()
150117
{
151-
m_shared->backward();
118+
evalGrad();
119+
if (m_shared->m_backward) {
120+
m_shared->m_backward(m_shared->m_inputs, m_shared->m_grads[0]);
121+
}
152122
}
153123

154124
DAG_t build()
@@ -165,7 +135,7 @@ namespace af {
165135
if (cache.find(id) != cache.end()) {
166136
return;
167137
}
168-
for (auto input : m_shared->getInputs()) {
138+
for (auto input : m_shared->m_inputs) {
169139
input.buildGraph(cache, dag);
170140
}
171141
cache[id] = true;

0 commit comments

Comments
 (0)