15
15
*/
16
16
package org .tensorflow .framework .optimizers ;
17
17
18
+ import java .util .List ;
19
+ import java .util .Optional ;
18
20
import org .tensorflow .Graph ;
19
21
import org .tensorflow .Operand ;
20
22
import org .tensorflow .Output ;
21
23
import org .tensorflow .ndarray .Shape ;
22
24
import org .tensorflow .op .Op ;
23
- import org .tensorflow .op .core .Assign ;
24
25
import org .tensorflow .op .core .Variable ;
25
26
import org .tensorflow .op .train .ApplyAdagradDa ;
26
27
import org .tensorflow .types .TInt64 ;
27
28
import org .tensorflow .types .family .TType ;
28
29
29
- import java .util .List ;
30
- import java .util .Optional ;
31
-
32
30
/**
33
31
* Optimizer that implements the Adagrad Dual-Averaging algorithm.
34
32
*
@@ -188,9 +186,8 @@ protected void createSlots(List<Output<? extends TType>> variables) {
188
186
for (Output <? extends TType > v : variables ) {
189
187
createAdaGradDASlot (v );
190
188
}
191
- globalStep = tf .withName ("adagrad-da-global-step" ).variable (Shape .scalar (), TInt64 .class );
192
- Assign <TInt64 > globalStepInitializer = tf .assign (globalStep , tf .constant (0L ));
193
- graph .addInitializer (globalStepInitializer );
189
+ globalStep = tf .initScope ().withName ("adagrad-da-global-step" ).variable (Shape .scalar (), TInt64 .class );
190
+ tf .initScope ().assign (globalStep , tf .constant (0L ));
194
191
}
195
192
196
193
/**
@@ -200,10 +197,12 @@ protected void createSlots(List<Output<? extends TType>> variables) {
200
197
* @param <T> the datatype of the variable.
201
198
*/
202
199
private <T extends TType > void createAdaGradDASlot (Output <T > v ) {
203
- Operand <T > initializer = tf .fill (tf .shape (v ), tf .dtypes .cast (tf .constant (0.0f ), v .type ()));
200
+ Operand <T > initializer =
201
+ tf .fill (tf .shape (v ), tf .dtypes .cast (tf .constant (0.0f ), v .type ()));
204
202
createSlot (v .asOutput (), ACCUMULATOR , initializer );
205
203
Operand <T > sqInitializer =
206
- tf .fill (tf .shape (v ), tf .dtypes .cast (tf .constant (initialAccumulatorValue ), v .type ()));
204
+ tf .fill (
205
+ tf .shape (v ), tf .dtypes .cast (tf .constant (initialAccumulatorValue ), v .type ()));
207
206
createSlot (v .asOutput (), SQUARED_ACCUMULATOR , sqInitializer );
208
207
}
209
208
0 commit comments