Skip to content

Commit fbd5ad0

Browse files
committed
Add forceInitialize to reinitialize session
Signed-off-by: Ryan Nett <[email protected]>
1 parent d6469d1 commit fbd5ad0

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@
5959
import org.tensorflow.internal.c_api.TF_Tensor;
6060
import org.tensorflow.ndarray.Shape;
6161
import org.tensorflow.op.Scope;
62-
import org.tensorflow.op.core.Constant;
6362
import org.tensorflow.proto.framework.AttrValue;
6463
import org.tensorflow.proto.framework.AttrValue.ListValue;
6564
import org.tensorflow.proto.framework.DataType;

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ public void close() {
145145
* Execute the graph's initializers.
146146
*
147147
* <p>This runs any ops that have been created with an init scope.
148+
* @throws IllegalStateException if the session has already been initialized
148149
*/
149150
public synchronized void initialize() {
150151
if (hasInitialized) {
@@ -165,6 +166,26 @@ public synchronized void initialize() {
165166
hasInitialized = true;
166167
}
167168

169+
/**
170+
* Execute the graph's initializers, regardless of whether the session has been initialized.
171+
*
172+
* <p>This runs any ops that have been created with an init scope.
173+
*/
174+
public synchronized void forceInitialize() {
175+
if (!graph.hasInitializers()) {
176+
hasInitialized = true;
177+
return;
178+
}
179+
180+
List<Operation> initializers = graph.initializers();
181+
if (!initializers.isEmpty()) {
182+
Runner runner = runner();
183+
initializers.forEach(runner::addTarget);
184+
runner.runNoInit();
185+
}
186+
hasInitialized = true;
187+
}
188+
168189
/** Create a session and initialize it. */
169190
public static Session initialized(Graph graph) {
170191
Session s = new Session(graph);

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -323,11 +323,9 @@ public void pythonTfFunction() {
323323
private static Signature buildGraphWithVariables(Ops tf, Shape xShape) {
324324
Placeholder<TFloat32> x = tf.placeholder(TFloat32.class, Placeholder.shape(xShape));
325325
Variable<TFloat32> y =
326-
tf.initScope().withName("variable")
327-
.variable(
328-
tf
329-
.random
330-
.randomUniform(tf.constant(xShape), TFloat32.class));
326+
tf.initScope()
327+
.withName("variable")
328+
.variable(tf.random.randomUniform(tf.constant(xShape), TFloat32.class));
331329
ReduceSum<TFloat32> z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1));
332330
return Signature.builder().input("input", x).output("reducedSum", z).build();
333331
}

0 commit comments

Comments
 (0)