14
14
#include < memory>
15
15
#include < vector>
16
16
#include < unordered_map>
17
+ #include < stdexcept>
17
18
18
19
#include < arrayfire.h>
19
20
20
21
namespace af {
21
22
namespace autograd {
23
+
24
+ // Forward declare the function
25
+ class Variable ;
26
+ Variable operator +(const Variable lhs, const Variable rhs);
27
+
22
28
class Variable
23
29
{
24
30
public:
@@ -31,25 +37,22 @@ namespace af {
31
37
public:
32
38
Shared () :
33
39
m_data (),
34
- m_grad (),
35
40
m_inputs (),
36
- m_grad_parts (),
41
+ m_grads (),
37
42
m_backward (nullptr )
38
43
{}
39
44
40
45
Shared (af::array data) :
41
46
m_data (data),
42
- m_grad (af::constant(0 , data.dims(), data.type())),
43
47
m_inputs (),
44
- m_grad_parts (),
48
+ m_grads (),
45
49
m_backward (nullptr )
46
50
{}
47
51
48
52
Shared (af::array data, std::vector<Variable> inputs, BackwardFunc_t backward) :
49
53
m_data (data),
50
- m_grad (af::constant(0 , data.dims(), data.type())),
51
54
m_inputs (inputs.begin(), inputs.end()),
52
- m_grad_parts (),
55
+ m_grads (),
53
56
m_backward (backward)
54
57
{}
55
58
@@ -58,19 +61,17 @@ namespace af {
58
61
return m_data;
59
62
}
60
63
61
- af::array getGrad () const
64
+ Variable getGrad () const
62
65
{
63
- return m_grad;
66
+ if (m_grads.size () == 0 ) {
67
+ throw std::runtime_error (" Gradient hasn't been calculated" );
68
+ }
69
+ return m_grads[0 ];
64
70
}
65
71
66
72
void addGrad (Variable grad)
67
73
{
68
- m_grad_parts.push_back (grad);
69
- }
70
-
71
- std::vector<Variable> getGradParts ()
72
- {
73
- return m_grad_parts;
74
+ m_grads.push_back (grad);
74
75
}
75
76
76
77
std::vector<Variable> getInputs ()
@@ -80,24 +81,26 @@ namespace af {
80
81
81
82
void evalGrad ()
82
83
{
83
- m_grad = m_grad_parts[0 ].getData ();
84
- for (int i = 1 ; i < (int )m_grad_parts.size (); i++) {
85
- m_grad += m_grad_parts[i].getData ();
84
+ if (m_grads.size () == 1 ) return ;
85
+ Variable grad = m_grads[0 ];
86
+ for (int i = 1 ; i < (int )m_grads.size (); i++) {
87
+ grad = grad + m_grads[i];
86
88
}
87
- af::eval (m_grad);
89
+ grad.getData ().eval ();
90
+ m_grads.clear ();
91
+ m_grads.push_back (grad);
88
92
}
89
93
90
94
void backward ()
91
95
{
92
96
this ->evalGrad ();
93
- if (m_backward) m_backward (m_inputs, m_grad );
97
+ if (m_backward) m_backward (m_inputs, m_grads[ 0 ] );
94
98
}
95
99
96
100
private:
97
101
af::array m_data;
98
- af::array m_grad;
99
102
std::vector<Variable> m_inputs;
100
- std::vector<Variable> m_grad_parts ;
103
+ std::vector<Variable> m_grads ;
101
104
BackwardFunc_t m_backward;
102
105
};
103
106
@@ -123,7 +126,7 @@ namespace af {
123
126
return m_shared->getData ();
124
127
}
125
128
126
- af::array getGrad () const
129
+ Variable getGrad () const
127
130
{
128
131
return m_shared->getGrad ();
129
132
}
0 commit comments