Commit df84422: Update framework

Signed-off-by: Ryan Nett <[email protected]>
1 parent cb4bec5

29 files changed: +356 -577 lines

tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java (+3 -4)

@@ -29,7 +29,6 @@
 import org.tensorflow.framework.data.impl.TensorSliceDataset;
 import org.tensorflow.framework.data.impl.TextLineDataset;
 import org.tensorflow.ndarray.Shape;
-import org.tensorflow.op.Op;
 import org.tensorflow.op.Ops;
 import org.tensorflow.types.family.TType;
 
@@ -254,7 +253,7 @@ public DatasetIterator makeInitializeableIterator() {
    * <pre>
    * try (Session session = new Session(graph) {
    *   // Immediately run initializers
-   *   session.run(tf.init());
+   *   session.initialize();
    * }
    * </pre>
    *
@@ -264,8 +263,8 @@ public DatasetIterator makeInitializeableIterator() {
    */
  public DatasetIterator makeOneShotIterator() {
    DatasetIterator iterator = makeInitializeableIterator();
-    Op initializer = iterator.makeInitializer(this);
-    if (tf.scope().env().isGraph()) tf.initAdd(initializer);
+    // TODO should pass the scope instead
+    tf.scope().env().registerInitOp(iterator.makeInitializer(this).op());
    return iterator;
  }
 
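The Javadoc change above captures the new flow: makeOneShotIterator() registers the iterator's initializer as a graph init op, and a single Session.initialize() call runs everything registered. A minimal usage sketch, assuming the framework's Dataset.fromTensorSlices(Ops, List<Operand<?>>, List<Class<? extends TType>>) factory keeps the signature it has elsewhere in this module (the toy data and class name are illustrative, not part of this commit):

import java.util.Arrays;
import java.util.List;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.framework.data.Dataset;
import org.tensorflow.framework.data.DatasetIterator;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.family.TType;

public final class OneShotIteratorExample {
  public static void main(String[] args) {
    try (Graph graph = new Graph()) {
      Ops tf = Ops.create(graph);
      // Toy dataset: three int32 rows, sliced along axis 0.
      List<Operand<?>> tensors =
          Arrays.asList(tf.constant(new int[][] {{1, 2}, {3, 4}, {5, 6}}));
      List<Class<? extends TType>> types = Arrays.asList(TInt32.class);
      Dataset dataset = Dataset.fromTensorSlices(tf, tensors, types);
      DatasetIterator iterator = dataset.makeOneShotIterator();
      try (Session session = new Session(graph)) {
        // Runs the init ops registered via registerInitOp, including the
        // iterator initializer added by makeOneShotIterator() above.
        session.initialize();
        // ...fetch iterator.getNext() outputs with session.runner()...
      }
    }
  }
}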

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java (+1 -2)

@@ -15,6 +15,7 @@
  */
 package org.tensorflow.framework.optimizers;
 
+import java.util.List;
 import org.tensorflow.Graph;
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
@@ -23,8 +24,6 @@
 import org.tensorflow.op.train.ApplyAdadelta;
 import org.tensorflow.types.family.TType;
 
-import java.util.List;
-
 /**
  * Optimizer that implements the Adadelta algorithm.
  *

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java (+6 -6)

@@ -15,16 +15,15 @@
  */
 package org.tensorflow.framework.optimizers;
 
+import java.util.List;
 import org.tensorflow.Graph;
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
 import org.tensorflow.op.Op;
-import org.tensorflow.op.train.ApplyAdagrad;
 import org.tensorflow.op.core.Variable;
+import org.tensorflow.op.train.ApplyAdagrad;
 import org.tensorflow.types.family.TType;
 
-import java.util.List;
-
 /**
  * Optimizer that implements the Adagrad algorithm.
  *
@@ -43,8 +42,8 @@ public class AdaGrad extends Optimizer {
   public static final float LEARNING_RATE_DEFAULT = 0.001f;
   public static final float INITIAL_ACCUMULATOR_DEFAULT = 0.01f;
 
-  private static final ApplyAdagrad.Options[] opts = new ApplyAdagrad.Options[]{
-      ApplyAdagrad.updateSlots(true), ApplyAdagrad.useLocking(true)};
+  private static final ApplyAdagrad.Options[] opts =
+      new ApplyAdagrad.Options[] {ApplyAdagrad.updateSlots(true), ApplyAdagrad.useLocking(true)};
 
   private final float learningRate;
 
@@ -135,7 +134,8 @@ protected void createSlots(List<Output<? extends TType>> variables) {
    */
   private <T extends TType> void createAdaGradSlot(Output<T> v) {
     Operand<T> initializer =
-        tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue), v.type()));
+        tf.fill(
+            tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue), v.type()));
     createSlot(v.asOutput(), ACCUMULATOR, initializer);
   }
 

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java (+8 -9)

@@ -15,20 +15,18 @@
  */
 package org.tensorflow.framework.optimizers;
 
+import java.util.List;
+import java.util.Optional;
 import org.tensorflow.Graph;
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
 import org.tensorflow.ndarray.Shape;
 import org.tensorflow.op.Op;
-import org.tensorflow.op.core.Assign;
 import org.tensorflow.op.core.Variable;
 import org.tensorflow.op.train.ApplyAdagradDa;
 import org.tensorflow.types.TInt64;
 import org.tensorflow.types.family.TType;
 
-import java.util.List;
-import java.util.Optional;
-
 /**
  * Optimizer that implements the Adagrad Dual-Averaging algorithm.
  *
@@ -188,9 +186,8 @@ protected void createSlots(List<Output<? extends TType>> variables) {
     for (Output<? extends TType> v : variables) {
       createAdaGradDASlot(v);
     }
-    globalStep = tf.withName("adagrad-da-global-step").variable(Shape.scalar(), TInt64.class);
-    Assign<TInt64> globalStepInitializer = tf.assign(globalStep, tf.constant(0L));
-    graph.addInitializer(globalStepInitializer);
+    globalStep = tf.initScope().withName("adagrad-da-global-step").variable(Shape.scalar(), TInt64.class);
+    tf.initScope().assign(globalStep, tf.constant(0L));
   }
 
   /**
@@ -200,10 +197,12 @@ protected void createSlots(List<Output<? extends TType>> variables) {
    * @param <T> the datatype of the variable.
    */
   private <T extends TType> void createAdaGradDASlot(Output<T> v) {
-    Operand<T> initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type()));
+    Operand<T> initializer =
+        tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type()));
     createSlot(v.asOutput(), ACCUMULATOR, initializer);
     Operand<T> sqInitializer =
-        tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue), v.type()));
+        tf.fill(
+            tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue), v.type()));
     createSlot(v.asOutput(), SQUARED_ACCUMULATOR, sqInitializer);
   }
 
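This is the first of several identical migrations in this commit: instead of building an explicit Assign op and handing it to graph.addInitializer(...), the variable and its initial assignment are created through tf.initScope(), which registers them as init ops on the execution environment. A minimal sketch of the pattern in isolation, assuming graph mode and the Session.initialize() method shown in the Dataset.java Javadoc above (the variable name and value are illustrative):

import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.types.TInt64;

public final class InitScopeExample {
  public static void main(String[] args) {
    try (Graph graph = new Graph()) {
      Ops tf = Ops.create(graph);
      // Variable creation and its initial assignment both go through the
      // init scope, so they are registered as initialization ops.
      Variable<TInt64> step =
          tf.initScope().withName("global-step").variable(Shape.scalar(), TInt64.class);
      tf.initScope().assign(step, tf.constant(0L));
      try (Session session = new Session(graph)) {
        session.initialize(); // runs the registered init ops in one shot
      }
    }
  }
}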

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java (+6 -9)

@@ -15,6 +15,8 @@
  */
 package org.tensorflow.framework.optimizers;
 
+import java.util.List;
+import java.util.Optional;
 import org.tensorflow.Graph;
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
@@ -30,9 +32,6 @@
 import org.tensorflow.types.TFloat32;
 import org.tensorflow.types.family.TType;
 
-import java.util.List;
-import java.util.Optional;
-
 /**
  * Optimizer that implements the Adam algorithm.
  *
@@ -190,12 +189,10 @@ protected void createSlots(List<Output<? extends TType>> variables) {
     for (Output<? extends TType> v : variables) {
       createAdamSlot(v.asOutput());
     }
-    betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.class);
-    Assign<TFloat32> betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne));
-    graph.addInitializer(betaOnePowerInit);
-    betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(), TFloat32.class);
-    Assign<TFloat32> betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo));
-    graph.addInitializer(betaTwoPowerInit);
+    betaOnePower = tf.initScope().withName("beta1_power").variable(Shape.scalar(), TFloat32.class);
+    tf.initScope().assign(betaOnePower, tf.constant(betaOne));
+    betaTwoPower = tf.initScope().withName("beta2_power").variable(Shape.scalar(), TFloat32.class);
+    tf.initScope().assign(betaTwoPower, tf.constant(betaTwo));
   }
 
   /** {@inheritDoc} */

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java (+4 -6)

@@ -1,5 +1,7 @@
 package org.tensorflow.framework.optimizers;
 
+import java.util.List;
+import java.util.Optional;
 import org.tensorflow.Graph;
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
@@ -12,9 +14,6 @@
 import org.tensorflow.types.TFloat32;
 import org.tensorflow.types.family.TType;
 
-import java.util.List;
-import java.util.Optional;
-
 /**
  * Optimizer that implements the Adamax algorithm.
  *
@@ -135,9 +134,8 @@ protected void createSlots(List<Output<? extends TType>> variables) {
     for (Output<? extends TType> v : variables) {
       createAdamaxSlot(v.asOutput());
     }
-    betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.class);
-    Assign<TFloat32> betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne));
-    ((Graph) tf.scope().env()).addInitializer(betaOnePowerInit);
+    betaOnePower = tf.initScope().withName("beta1_power").variable(Shape.scalar(), TFloat32.class);
+    tf.initScope().assign(betaOnePower, tf.constant(betaOne));
   }
 
   /**

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java (+3 -3)

@@ -1,5 +1,6 @@
 package org.tensorflow.framework.optimizers;
 
+import java.util.List;
 import org.tensorflow.Graph;
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
@@ -8,8 +9,6 @@
 import org.tensorflow.op.train.ApplyFtrl;
 import org.tensorflow.types.family.TType;
 
-import java.util.List;
-
 /**
  * Optimizer that implements the FTRL algorithm.
  *
@@ -230,7 +229,8 @@ protected void createSlots(List<Output<? extends TType>> variables) {
    */
   private <T extends TType> void createFtrlSlot(Output<T> v) {
     Operand<T> initializer =
-        tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue), v.type()));
+        tf.fill(
+            tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue), v.type()));
     createSlot(v.asOutput(), ACCUMULATOR, initializer);
     Operand<T> linearInitializer =
         tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type()));

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java (+3 -3)

@@ -15,6 +15,7 @@
  */
 package org.tensorflow.framework.optimizers;
 
+import java.util.List;
 import org.tensorflow.Graph;
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
@@ -23,8 +24,6 @@
 import org.tensorflow.op.train.ApplyMomentum;
 import org.tensorflow.types.family.TType;
 
-import java.util.List;
-
 /**
  * Stochastic gradient descent plus momentum, either nesterov or traditional.
 *
@@ -125,7 +124,8 @@ protected void createSlots(List<Output<? extends TType>> variables) {
    * @param <T> the data type of the variable
    */
   private <T extends TType> void createMomentumSlot(Output<T> v) {
-    Operand<T> initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type()));
+    Operand<T> initializer =
+        tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type()));
     createSlot(v.asOutput(), MOMENTUM, initializer);
   }
 

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java (+8 -12)

@@ -1,5 +1,7 @@
 package org.tensorflow.framework.optimizers;
 
+import java.util.List;
+import java.util.Optional;
 import org.tensorflow.Graph;
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
@@ -12,9 +14,6 @@
 import org.tensorflow.types.TInt64;
 import org.tensorflow.types.family.TType;
 
-import java.util.List;
-import java.util.Optional;
-
 /**
  * Nadam Optimizer that implements the NAdam algorithm.
  *
@@ -140,17 +139,14 @@ protected void createSlots(List<Output<? extends TType>> variables) {
     for (Output<? extends TType> v : variables) {
      createNadamSlot(v.asOutput());
     }
-    betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.class);
-    Assign<TFloat32> betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne));
-    ((Graph) tf.scope().env()).addInitializer(betaOnePowerInit);
+    betaOnePower = tf.initScope().withName("beta1_power").variable(Shape.scalar(), TFloat32.class);
+    tf.initScope().assign(betaOnePower, tf.constant(betaOne));
 
-    betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(), TFloat32.class);
-    Assign<TFloat32> betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo));
-    ((Graph) tf.scope().env()).addInitializer(betaTwoPowerInit);
+    betaTwoPower = tf.initScope().withName("beta2_power").variable(Shape.scalar(), TFloat32.class);
+    tf.initScope().assign(betaTwoPower, tf.constant(betaTwo));
 
-    momentum = tf.withName("momentum").variable(Shape.scalar(), TFloat32.class);
-    Assign<TFloat32> momentumInit = tf.assign(momentum, tf.constant(1.0F));
-    ((Graph) tf.scope().env()).addInitializer(momentumInit);
+    momentum = tf.initScope().withName("momentum").variable(Shape.scalar(), TFloat32.class);
+    tf.initScope().assign(momentum, tf.constant(1.0F));
   }
 
   /**

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java (+7 -7)

@@ -15,21 +15,19 @@
  */
 package org.tensorflow.framework.optimizers;
 
+import java.util.*;
+import java.util.stream.Collectors;
 import org.tensorflow.Graph;
 import org.tensorflow.Operand;
 import org.tensorflow.Operation;
 import org.tensorflow.Output;
 import org.tensorflow.op.Op;
 import org.tensorflow.op.Ops;
 import org.tensorflow.op.Scope;
-import org.tensorflow.op.core.Assign;
 import org.tensorflow.op.core.NoOp;
 import org.tensorflow.op.core.Variable;
 import org.tensorflow.types.family.TType;
 
-import java.util.*;
-import java.util.stream.Collectors;
-
 /** Base class for gradient optimizers. */
 public abstract class Optimizer {
 
@@ -41,6 +39,7 @@ public abstract class Optimizer {
   protected final Graph graph;
   /** The ops builder for the graph. */
   protected final Ops tf;
+
   /** Top level map key is the variable name, lower level map key is the slot name. */
   private final Map<String, Map<String, Variable<?>>> slots;
 
@@ -221,9 +220,10 @@ private <T extends TType> Optional<Variable<T>> getSlot(String varName, String s
   protected <T extends TType> void createSlot(
       Output<T> variable, String slotName, Operand<T> initializer) {
     Variable<T> slot =
-        tf.withName(createName(variable, slotName)).variable(variable.shape(), variable.type());
-    Assign<T> slotInit = tf.assign(slot, initializer);
-    graph.addInitializer(slotInit);
+        tf.initScope()
+            .withName(createName(variable, slotName))
+            .variable(variable.shape(), variable.type());
+    tf.initScope().assign(slot, initializer);
     String varName = variable.op().name();
     Map<String, Variable<? extends TType>> variables =
         slots.computeIfAbsent(slotName, (k) -> new HashMap<>());
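Since createSlot is the common path, every optimizer's slot variables now get their initializers registered automatically; no caller has to collect Assign ops by hand. A sketch of what end-to-end use might look like under these changes, assuming the GradientDescent optimizer from this package and Optimizer.minimize returning an Op (the loss, learning rate, and constants are illustrative, not part of this commit):

import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.framework.optimizers.GradientDescent;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.types.TFloat32;

public final class MinimizeExample {
  public static void main(String[] args) {
    try (Graph graph = new Graph()) {
      Ops tf = Ops.create(graph);
      // Create the trainable variable through the init scope, mirroring
      // how the optimizers above now create their internal state.
      Variable<TFloat32> w =
          tf.initScope().withName("w").variable(Shape.scalar(), TFloat32.class);
      tf.initScope().assign(w, tf.constant(3.0f));
      // Toy loss: (w - 1)^2, minimized at w = 1.
      Operand<TFloat32> loss = tf.math.square(tf.math.sub(w, tf.constant(1.0f)));
      Op minimize = new GradientDescent(graph, 0.1f).minimize(loss);
      try (Session session = new Session(graph)) {
        session.initialize(); // variable init plus any optimizer slot inits
        for (int i = 0; i < 100; i++) {
          session.runner().addTarget(minimize).run();
        }
      }
    }
  }
}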

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java (+5 -4)

@@ -15,6 +15,7 @@
  */
 package org.tensorflow.framework.optimizers;
 
+import java.util.List;
 import org.tensorflow.Graph;
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
@@ -24,8 +25,6 @@
 import org.tensorflow.op.train.ApplyRmsProp;
 import org.tensorflow.types.family.TType;
 
-import java.util.List;
-
 /**
  * Optimizer that implements the RMSProp algorithm.
 *
@@ -177,13 +176,15 @@ protected void createSlots(List<Output<? extends TType>> variables) {
    * @param <T> the datatype of the variable.
    */
   private <T extends TType> void createRMSPropSlot(Output<T> v) {
-    Operand<T> rmsInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.type()));
+    Operand<T> rmsInitializer =
+        tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.type()));
     createSlot(v.asOutput(), RMS, rmsInitializer);
     Operand<T> momentumInitializer =
         tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type()));
     createSlot(v.asOutput(), MOMENTUM, momentumInitializer);
     if (centered) {
-      Operand<T> mgInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type()));
+      Operand<T> mgInitializer =
+          tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type()));
       createSlot(v.asOutput(), MG, mgInitializer);
     }
   }

0 commit comments