Skip to content

Commit 3f832a0

Browse files
committed
Store gradients as autograd::Variable instead of af::array
1 parent 562b860 commit 3f832a0

File tree

2 files changed

+28
-25
lines changed

2 files changed

+28
-25
lines changed

examples/autograd.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ void test1()
1919
af_print(y.getData());
2020
auto dy = Variable(af::constant(1.0, 5));
2121
backward(y, dy);
22-
af_print(x.getGrad() - 2 * x.getData());
22+
af_print(x.getGrad().getData() - 2 * x.getData());
2323
}
2424

2525
void test2()
@@ -31,8 +31,8 @@ void test2()
3131
auto z = x * x + x * y + y * y;
3232
auto dz = Variable(af::constant(1.0, 5));
3333
backward(z, dz);
34-
af_print(x.getGrad() - 2 * x.getData() - y.getData());
35-
af_print(y.getGrad() - 2 * y.getData() - x.getData());
34+
af_print(x.getGrad().getData() - 2 * x.getData() - y.getData());
35+
af_print(y.getGrad().getData() - 2 * y.getData() - x.getData());
3636
}
3737

3838
int main()

include/af/autograd/Variable.hpp

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,17 @@
1414
#include <memory>
1515
#include <vector>
1616
#include <unordered_map>
17+
#include <stdexcept>
1718

1819
#include <arrayfire.h>
1920

2021
namespace af {
2122
namespace autograd {
23+
24+
// Forward declare Variable and its operator+ so the class body below can use them
25+
class Variable;
26+
Variable operator +(const Variable lhs, const Variable rhs);
27+
2228
class Variable
2329
{
2430
public:
@@ -31,25 +37,22 @@ namespace af {
3137
public:
3238
Shared() :
3339
m_data(),
34-
m_grad(),
3540
m_inputs(),
36-
m_grad_parts(),
41+
m_grads(),
3742
m_backward(nullptr)
3843
{}
3944

4045
Shared(af::array data) :
4146
m_data(data),
42-
m_grad(af::constant(0, data.dims(), data.type())),
4347
m_inputs(),
44-
m_grad_parts(),
48+
m_grads(),
4549
m_backward(nullptr)
4650
{}
4751

4852
Shared(af::array data, std::vector<Variable> inputs, BackwardFunc_t backward) :
4953
m_data(data),
50-
m_grad(af::constant(0, data.dims(), data.type())),
5154
m_inputs(inputs.begin(), inputs.end()),
52-
m_grad_parts(),
55+
m_grads(),
5356
m_backward(backward)
5457
{}
5558

@@ -58,19 +61,17 @@ namespace af {
5861
return m_data;
5962
}
6063

61-
af::array getGrad() const
64+
Variable getGrad() const
6265
{
63-
return m_grad;
66+
if (m_grads.size() == 0) {
67+
throw std::runtime_error("Gradient hasn't been calculated");
68+
}
69+
return m_grads[0];
6470
}
6571

6672
void addGrad(Variable grad)
6773
{
68-
m_grad_parts.push_back(grad);
69-
}
70-
71-
std::vector<Variable> getGradParts()
72-
{
73-
return m_grad_parts;
74+
m_grads.push_back(grad);
7475
}
7576

7677
std::vector<Variable> getInputs()
@@ -80,24 +81,26 @@ namespace af {
8081

8182
void evalGrad()
8283
{
83-
m_grad = m_grad_parts[0].getData();
84-
for (int i = 1; i < (int)m_grad_parts.size(); i++) {
85-
m_grad += m_grad_parts[i].getData();
84+
if (m_grads.size() == 1) return;
85+
Variable grad = m_grads[0];
86+
for (int i = 1; i < (int)m_grads.size(); i++) {
87+
grad = grad + m_grads[i];
8688
}
87-
af::eval(m_grad);
89+
grad.getData().eval();
90+
m_grads.clear();
91+
m_grads.push_back(grad);
8892
}
8993

9094
void backward()
9195
{
9296
this->evalGrad();
93-
if (m_backward) m_backward(m_inputs, m_grad);
97+
if (m_backward) m_backward(m_inputs, m_grads[0]);
9498
}
9599

96100
private:
97101
af::array m_data;
98-
af::array m_grad;
99102
std::vector<Variable> m_inputs;
100-
std::vector<Variable> m_grad_parts;
103+
std::vector<Variable> m_grads;
101104
BackwardFunc_t m_backward;
102105
};
103106

@@ -123,7 +126,7 @@ namespace af {
123126
return m_shared->getData();
124127
}
125128

126-
af::array getGrad() const
129+
Variable getGrad() const
127130
{
128131
return m_shared->getGrad();
129132
}

0 commit comments

Comments
 (0)