@@ -33,8 +33,7 @@ namespace af {
         typedef std::vector<Variable> DAG_t;
 
     private:
-        class Shared {
-        public:
+        struct Shared {
             Shared () :
                 m_data (),
                 m_inputs (),
@@ -56,48 +55,6 @@ namespace af {
                 m_backward (backward)
             {}
 
-            af::array getData () const
-            {
-                return m_data;
-            }
-
-            Variable getGrad () const
-            {
-                if (m_grads.size () == 0) {
-                    throw std::runtime_error ("Gradient hasn't been calculated");
-                }
-                return m_grads[0];
-            }
-
-            void addGrad (Variable grad)
-            {
-                m_grads.push_back (grad);
-            }
-
-            std::vector<Variable> getInputs ()
-            {
-                return m_inputs;
-            }
-
-            void evalGrad ()
-            {
-                if (m_grads.size () == 1) return;
-                Variable grad = m_grads[0];
-                for (int i = 1; i < (int)m_grads.size (); i++) {
-                    grad = grad + m_grads[i];
-                }
-                grad.getData ().eval ();
-                m_grads.clear ();
-                m_grads.push_back (grad);
-            }
-
-            void backward ()
-            {
-                this->evalGrad ();
-                if (m_backward) m_backward (m_inputs, m_grads[0]);
-            }
-
-        private:
             af::array m_data;
             std::vector<Variable> m_inputs;
             std::vector<Variable> m_grads;
@@ -123,32 +80,45 @@ namespace af {
 
         af::array getData () const
         {
-            return m_shared->getData ();
+            return m_shared->m_data;
        }
 
         Variable getGrad () const
         {
-            return m_shared->getGrad ();
+            if (m_shared->m_grads.size () == 0) {
+                throw std::runtime_error ("Gradient hasn't been calculated");
+            }
+            return m_shared->m_grads[0];
         }
 
         void addGrad (Variable child_grad)
         {
-            m_shared->addGrad (child_grad);
+            m_shared->m_grads.push_back (child_grad);
         }
 
         std::vector<Variable> getInputs () const
         {
-            return m_shared->getInputs ();
+            return m_shared->m_inputs;
         }
 
         void evalGrad ()
         {
-            m_shared->evalGrad ();
+            if (m_shared->m_grads.size () == 1) return;
+            Variable grad = m_shared->m_grads[0];
+            for (unsigned i = 1; i < m_shared->m_grads.size (); i++) {
+                grad = grad + m_shared->m_grads[i];
+            }
+            grad.getData ().eval ();
+            m_shared->m_grads.clear ();
+            m_shared->m_grads.push_back (grad);
         }
 
         void backward ()
         {
-            m_shared->backward ();
+            evalGrad ();
+            if (m_shared->m_backward) {
+                m_shared->m_backward (m_shared->m_inputs, m_shared->m_grads[0]);
+            }
         }
 
         DAG_t build ()
@@ -165,7 +135,7 @@ namespace af {
             if (cache.find (id) != cache.end ()) {
                 return;
             }
-            for (auto input : m_shared->getInputs ()) {
+            for (auto input : m_shared->m_inputs) {
                 input.buildGraph (cache, dag);
             }
             cache[id] = true;
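For readers skimming the diff, the following is a minimal, self-contained sketch of the pattern the refactor settles on: shared state kept in a plain struct behind a std::shared_ptr, gradients accumulated per consumer and summed in evalGrad() before the stored backward function runs. Plain doubles stand in for af::array, and the names here (Var, BackwardFn_t, the guard in evalGrad) are illustrative assumptions, not the library's API.

// Standalone sketch of the shared-state / gradient-accumulation pattern above.
// Not the library code: doubles replace af::array, names are hypothetical.
#include <functional>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <vector>

class Var {
public:
    typedef std::function<void(std::vector<Var>&, Var&)> BackwardFn_t;

    // Plain struct, as in the diff: the outer class touches members directly.
    struct Shared {
        double m_data = 0;
        std::vector<Var> m_inputs;
        std::vector<Var> m_grads;
        BackwardFn_t m_backward;
    };

    Var (double data, std::vector<Var> inputs = {}, BackwardFn_t backward = nullptr)
        : m_shared (std::make_shared<Shared> ())
    {
        m_shared->m_data = data;
        m_shared->m_inputs = std::move (inputs);
        m_shared->m_backward = std::move (backward);
    }

    double getData () const { return m_shared->m_data; }

    Var getGrad () const
    {
        if (m_shared->m_grads.empty ()) {
            throw std::runtime_error ("Gradient hasn't been calculated");
        }
        return m_shared->m_grads[0];
    }

    // Each consumer of this variable contributes one gradient.
    void addGrad (Var grad) { m_shared->m_grads.push_back (grad); }

    // Sum the accumulated gradients into a single one (guard added for the sketch).
    void evalGrad ()
    {
        if (m_shared->m_grads.empty ()) {
            throw std::runtime_error ("Gradient hasn't been calculated");
        }
        if (m_shared->m_grads.size () == 1) return;
        double sum = 0;
        for (const Var &g : m_shared->m_grads) sum += g.getData ();
        m_shared->m_grads.clear ();
        m_shared->m_grads.push_back (Var (sum));
    }

    // Reduce the grads, then hand the inputs and the reduced grad to m_backward.
    void backward ()
    {
        evalGrad ();
        if (m_shared->m_backward) {
            m_shared->m_backward (m_shared->m_inputs, m_shared->m_grads[0]);
        }
    }

private:
    std::shared_ptr<Shared> m_shared;
};

int main ()
{
    Var a (2.0), b (3.0);
    // c = a * b, with a backward function mirroring the product rule.
    Var c (a.getData () * b.getData (), {a, b},
           [](std::vector<Var> &in, Var &grad) {
               in[0].addGrad (Var (grad.getData () * in[1].getData ())); // dc/da = b
               in[1].addGrad (Var (grad.getData () * in[0].getData ())); // dc/db = a
           });
    c.addGrad (Var (1.0)); // seed gradient dL/dc = 1
    c.backward ();
    a.evalGrad ();
    b.evalGrad ();
    std::cout << "dc/da = " << a.getGrad ().getData ()
              << ", dc/db = " << b.getGrad ().getData () << "\n"; // 3 and 2
    return 0;
}

Because copies of the handle share one Shared block through the shared_ptr, gradients added to the input copies inside the backward callback are visible on the original variables; that is also why the commit can drop Shared's accessor methods and let the outer class read and write the members directly.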