Skip to content

Commit 242931c

Browse files
authored
Init Scope (#338)
* Start of init scope Signed-off-by: Ryan Nett <[email protected]> * Fix tests Signed-off-by: Ryan Nett <[email protected]> * Javadoc updates Signed-off-by: Ryan Nett <[email protected]> * Session init helpers Signed-off-by: Ryan Nett <[email protected]> * Format fixes Signed-off-by: Ryan Nett <[email protected]> * Make initEnv default to this. Signed-off-by: Ryan Nett <[email protected]> * More formatting fixes Signed-off-by: Ryan Nett <[email protected]> * Small fixes, add native pointer based equals and hashCode to EagerOperation Signed-off-by: Ryan Nett <[email protected]> * Export init ops to GraphDefs and import from them Signed-off-by: Ryan Nett <[email protected]> * Test adding init ops after import Signed-off-by: Ryan Nett <[email protected]> * Automatically lift constants to init if required Signed-off-by: Ryan Nett <[email protected]> * Add withInitScope Signed-off-by: Ryan Nett <[email protected]> * Allow init ops to depend on other init ops Signed-off-by: Ryan Nett <[email protected]> * Add void withInitScope Signed-off-by: Ryan Nett <[email protected]> * Lift init inputs to init as well Signed-off-by: Ryan Nett <[email protected]> * Formatting Signed-off-by: Ryan Nett <[email protected]> * Allow use of init input lifting Signed-off-by: Ryan Nett <[email protected]> * Add forceInitialize to reinitialize session Signed-off-by: Ryan Nett <[email protected]> * Replace withInitScope with liftToInitScope Signed-off-by: Ryan Nett <[email protected]> * Update framework Signed-off-by: Ryan Nett <[email protected]> * Rebase fixes Signed-off-by: Ryan Nett <[email protected]> * Remove Java 10 API Signed-off-by: Ryan Nett <[email protected]> * Check control dependencies up front Signed-off-by: Ryan Nett <[email protected]> * Reorder methods Signed-off-by: Ryan Nett <[email protected]> * Session constructors Signed-off-by: Ryan Nett <[email protected]> * Variable with init value uses passed scope Signed-off-by: Ryan Nett <[email protected]> * Change initScope to withInitScope Signed-off-by: Ryan Nett <[email protected]> * Fix initScope usages Signed-off-by: Ryan Nett <[email protected]> * Variable fixes Signed-off-by: Ryan Nett <[email protected]> * Don't put reset ops in init scope Signed-off-by: Ryan Nett <[email protected]> * New session initialization, formatting Signed-off-by: Ryan Nett <[email protected]> * Fix a test Signed-off-by: Ryan Nett <[email protected]> * Fix generated names Signed-off-by: Ryan Nett <[email protected]> * Comment fixes Signed-off-by: Ryan Nett <[email protected]> * More comment fixes Signed-off-by: Ryan Nett <[email protected]> * Track init ops and topmost init ops separately Signed-off-by: Ryan Nett <[email protected]> * Don't track topmost init ops Signed-off-by: Ryan Nett <[email protected]> * Unsynchronize Session Signed-off-by: Ryan Nett <[email protected]> * Fix format Signed-off-by: Ryan Nett <[email protected]> * Fix not building init op Signed-off-by: Ryan Nett <[email protected]> * Use init scope in variable-with-init Signed-off-by: Ryan Nett <[email protected]> * Fix comments, make ranInits final Signed-off-by: Ryan Nett <[email protected]> * Update toGraphDef comment Signed-off-by: Ryan Nett <[email protected]> * Update Sessions comment Signed-off-by: Ryan Nett <[email protected]> * Remove wildcard import Signed-off-by: Ryan Nett <[email protected]> * Change init op name to Init Signed-off-by: Ryan Nett <[email protected]> * Don't include NoOps in initializer list Signed-off-by: Ryan Nett <[email protected]> * Fix format Signed-off-by: Ryan Nett <[email protected]>
1 parent 4c1c271 commit 242931c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+1363
-1349
lines changed

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

+60-95
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@
106106
import org.tensorflow.op.core.IdentityN;
107107
import org.tensorflow.op.core.If;
108108
import org.tensorflow.op.core.ImmutableConst;
109-
import org.tensorflow.op.core.Init;
110109
import org.tensorflow.op.core.InitializeTable;
111110
import org.tensorflow.op.core.InitializeTableFromTextFile;
112111
import org.tensorflow.op.core.InplaceAdd;
@@ -295,6 +294,12 @@
295294
import org.tensorflow.op.core.VariableShape;
296295
import org.tensorflow.op.core.Where;
297296
import org.tensorflow.op.core.While;
297+
import org.tensorflow.op.core.XlaConvV2;
298+
import org.tensorflow.op.core.XlaDotV2;
299+
import org.tensorflow.op.core.XlaSetDynamicDimensionSize;
300+
import org.tensorflow.op.core.XlaSpmdFullToShardShape;
301+
import org.tensorflow.op.core.XlaSpmdShardToFullShape;
302+
import org.tensorflow.op.core.XlaVariadicSort;
298303
import org.tensorflow.op.core.Zeros;
299304
import org.tensorflow.op.core.ZerosLike;
300305
import org.tensorflow.types.TBool;
@@ -366,20 +371,20 @@ public final class Ops {
366371

367372
public final SparseOps sparse;
368373

369-
public final TpuOps tpu;
370-
371374
public final BitwiseOps bitwise;
372375

376+
public final TpuOps tpu;
377+
373378
public final MathOps math;
374379

375380
public final AudioOps audio;
376381

377382
public final SignalOps signal;
378383

379-
public final QuantizationOps quantization;
380-
381384
public final TrainOps train;
382385

386+
public final QuantizationOps quantization;
387+
383388
private final Scope scope;
384389

385390
private Ops(Scope scope) {
@@ -397,13 +402,13 @@ private Ops(Scope scope) {
397402
random = new RandomOps(this);
398403
strings = new StringsOps(this);
399404
sparse = new SparseOps(this);
400-
tpu = new TpuOps(this);
401405
bitwise = new BitwiseOps(this);
406+
tpu = new TpuOps(this);
402407
math = new MathOps(this);
403408
audio = new AudioOps(this);
404409
signal = new SignalOps(this);
405-
quantization = new QuantizationOps(this);
406410
train = new TrainOps(this);
411+
quantization = new QuantizationOps(this);
407412
}
408413

409414
/**
@@ -1952,14 +1957,15 @@ public Constant<TFloat32> constant(Shape shape, FloatDataBuffer data) {
19521957
}
19531958

19541959
/**
1955-
* Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be truncated if it does not
1956-
* fit in the target type.
1960+
* Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be
1961+
* truncated if it does not fit in the target type.
19571962
*
1958-
* @param type the type of tensor to create. Must be concrete (i.e. not {@link org.tensorflow.types.family.TFloating})
1963+
* @param type the type of tensor to create. Must be concrete (i.e. not {@link
1964+
* org.tensorflow.types.family.TFloating})
19591965
* @param number the value of the tensor
19601966
* @return a constant of the passed type
1961-
* @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or
1962-
* unknown.
1967+
* @throws IllegalArgumentException if the type is abstract (i.e. {@link
1968+
* org.tensorflow.types.family.TFloating}) or unknown.
19631969
*/
19641970
public <T extends TNumber> Constant<T> constant(Class<T> type, Number number) {
19651971
return Constant.tensorOf(scope, type, number);
@@ -1994,11 +2000,12 @@ public <T extends TType> Constant<T> constant(Class<T> type, Shape shape, ByteDa
19942000
}
19952001

19962002
/**
1997-
* Create a constant by making an immutable copy of {@code tensor}. {@code tensor} may be closed afterwards without
1998-
* issue.
2003+
* Create a constant by making an immutable copy of {@code tensor}. {@code tensor} may be closed
2004+
* afterwards without issue.
19992005
*
20002006
* <p>Note: this endpoint cannot be simply called {@code constant} since it will conflict with
2001-
* other endpoints accepting an NdArray in parameter {e.g. {@link #tensorOf(Scope, FloatNdArray)}}.
2007+
* other endpoints accepting an NdArray in parameter {e.g. {@link #tensorOf(Scope,
2008+
* FloatNdArray)}}.
20022009
*
20032010
* @param tensor a Tensor holding the constant value
20042011
* @return a constant of the same data type as `tensor`
@@ -2008,8 +2015,8 @@ public <T extends TType> Constant<T> constantOf(T tensor) {
20082015
}
20092016

20102017
/**
2011-
* Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code number} may be
2012-
* truncated if it does not fit in the target type.
2018+
* Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code
2019+
* number} may be truncated if it does not fit in the target type.
20132020
*
20142021
* @param toMatch the operand providing the target type
20152022
* @param number the value of the tensor
@@ -2993,80 +3000,6 @@ public <T extends TType> ImmutableConst<T> immutableConst(Class<T> dtype, Shape
29933000
return ImmutableConst.create(scope, dtype, shape, memoryRegionName);
29943001
}
29953002

2996-
/**
2997-
* Factory method to create an operation executing all initializers of a graph.
2998-
*
2999-
* <p>All initializers added to a graph via
3000-
* {@link org.tensorflow.op.core.Init#add(Scope, Op) tf.initAdd} are grouped together as a single
3001-
* unit of computation in the graph. This operation must then be added to any graph using one or
3002-
* more {@link Variable variables} and executed once before running the graph so the variable
3003-
* states are initialized properly.</p>
3004-
*
3005-
* <p>When the graph is built by the same process that is running the session, the initializers
3006-
* can be invoked by executing this single endpoint. For example:</p>
3007-
* <pre>{@code
3008-
* try (Graph g = new Graph()) {
3009-
* Variable<TInt32> x = tf.variable(tf.constant(10)); // initAdd is called implicitly
3010-
* Variable<TInt32> y = tf.variable(tf.constant(20)); // idem
3011-
* Add<TInt32> z = tf.math.add(x, y);
3012-
*
3013-
* try (Session s = new Session(g)) {
3014-
* s.run(tf.init()); // initialize all variables
3015-
*
3016-
* try (TInt32 t = (TInt32)s.runner().fetch(z).run().get(0)) {
3017-
* assertEquals(30, t.data().getInt());
3018-
* }
3019-
* }
3020-
* }
3021-
* }</pre>
3022-
*
3023-
* <p>When the graph is built by a separate process, the initializers can be invoked by running
3024-
* the init op by its name, which defaults to {@link org.tensorflow.op.core.Init#DEFAULT_NAME}.
3025-
* For example:</p>
3026-
* <pre>{@code
3027-
* // Building the model
3028-
* try (Graph g = new Graph()) {
3029-
* Variable<TInt32> x = tf.variable(tf.constant(10)); // initAdd is called implicitly
3030-
* Variable<TInt32> y = tf.variable(tf.constant(20)); // idem
3031-
* Add<TInt32> z = tf.withName("z").math.add(x, y);
3032-
*
3033-
* tf.init(); // add variables initializers to the graph, as Init.DEFAULT_NAME
3034-
* // ...exporting graph as a saved model...
3035-
* }
3036-
*
3037-
* ...
3038-
*
3039-
* // Running the model
3040-
* try (SavedModelBundle model = SavedModelBundle.load("/path/to/model", "train")) {
3041-
* model.session().run(Init.DEFAULT_NAME);
3042-
*
3043-
* try (TInt32 t = (TInt32)s.runner().fetch("z").run().get(0)) {
3044-
* assertEquals(30, t.data().getInt());
3045-
* }
3046-
* }
3047-
* }</pre>
3048-
*
3049-
* @return an op grouping all initializers added to the graph
3050-
* @throws IllegalArgumentException if the execution environment in scope is not a graph
3051-
*/
3052-
public Init init() {
3053-
return Init.create(scope);
3054-
}
3055-
3056-
/**
3057-
* Register an op as an initializer of the graph.
3058-
*
3059-
* <p>Registered initializers are then grouped as a single unit of computation by adding
3060-
* and executing an {@link org.tensorflow.op.core.Init#create(Scope) init} operation from a graph
3061-
* session. This is a no-op if executed in an eager session.
3062-
*
3063-
* @param initializer
3064-
* @see org.tensorflow.op.core.Init#create(Scope) init
3065-
*/
3066-
public void initAdd(Op initializer) {
3067-
Init.add(scope, initializer);
3068-
}
3069-
30703003
/**
30713004
* Table initializer that takes two tensors for keys and values respectively.
30723005
*
@@ -7947,10 +7880,11 @@ public VarIsInitializedOp varIsInitializedOp(Operand<? extends TType> resource)
79477880
}
79487881

79497882
/**
7950-
* Factory method to create a new Variable with it's initializer.
7951-
* <p>
7952-
* Only supported on Graph sessions as the {@link org.tensorflow.op.core.Assign} op
7953-
* does not work in an EagerSession.
7883+
* Factory method to create a new Variable with its initializer. Both the creation and assignment
7884+
* are done in the init scope.
7885+
*
7886+
* <p>Only supported on Graph sessions as the {@link org.tensorflow.op.core.Assign} op does not
7887+
* work in an EagerSession.
79547888
*
79557889
* @param init The op to use to initialise this variable.
79567890
* @param options carries optional attributes values
@@ -8143,6 +8077,37 @@ public Ops withSubScope(String childScopeName) {
81438077
return new Ops(scope.withSubScope(childScopeName));
81448078
}
81458079

8080+
/**
8081+
* Returns an API that builds init operations. {@link #liftToInitScope(Operand)} will be called for all created operations.
8082+
* <p>
8083+
* Init operations will be initialized at session creation, will have their inputs (and control inputs) made init ops as well,
8084+
* and are ignored when used as control dependencies.
8085+
* Additionally, this scope ignores any control dependencies.
8086+
* <p>
8087+
* If an input can not be made an init op (i.e. a Placeholder), will throw an {@link IllegalStateException} on op creation.
8088+
* @see #liftToInitScope(Operand)
8089+
*/
8090+
public Ops withInitScope() {
8091+
return new Ops(scope.withInitScope());
8092+
}
8093+
8094+
/**
8095+
* Make {@code op} an init operation, doing the same for all of it's inputs (and control inputs).
8096+
* <p>
8097+
* Init operations will be initialized at session creation, will have their inputs (and control inputs) made init ops as well,
8098+
* and are ignored when used as control dependencies.
8099+
* Additionally, this scope ignores any control dependencies.
8100+
* <p>
8101+
* If an input can not be made an init op (i.e. a Placeholder), will throw an {@link IllegalStateException} on op creation.
8102+
* @see ExecutionEnvironment#registerInitOp(Operation)
8103+
*
8104+
* @throws IllegalStateException if the op or one of its inputs can't be made an init op.
8105+
*/
8106+
public <T extends Operand> T liftToInitScope(T op) {
8107+
scope.env().registerInitOp(op.op());
8108+
return op;
8109+
}
8110+
81468111
/**
81478112
* Returns an API that uses the provided name for an op.
81488113
*

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java

+25
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,31 @@ TFE_TensorHandle getUnsafeNativeHandle(int outputIndex) {
9191
return outputHandles[outputIndex];
9292
}
9393

94+
@Override
95+
public int hashCode() {
96+
return Long.valueOf(opHandle.address()).hashCode();
97+
}
98+
99+
@Override
100+
public boolean equals(Object o) {
101+
if (o == this) {
102+
return true;
103+
}
104+
if (!(o instanceof EagerOperation)) {
105+
return false;
106+
}
107+
EagerOperation that = (EagerOperation) o;
108+
if (session != that.session) {
109+
return false;
110+
}
111+
112+
if (opHandle == null || that.opHandle == null || opHandle.isNull() || that.opHandle.isNull()) {
113+
// if they are the same object, we will already have returned
114+
return false;
115+
}
116+
return opHandle.equals(that.opHandle);
117+
}
118+
94119
@Override
95120
Shape shape(int outputIndex) {
96121
// If the tensor of this output has already been resolved, return its shape.

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java

+21-15
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved.
22
3-
Licensed under the Apache License, Version 2.0 (the "License");
4-
you may not use this file except in compliance with the License.
5-
You may obtain a copy of the License at
6-
7-
http://www.apache.org/licenses/LICENSE-2.0
8-
9-
Unless required by applicable law or agreed to in writing, software
10-
distributed under the License is distributed on an "AS IS" BASIS,
11-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
See the License for the specific language governing permissions and
13-
limitations under the License.
14-
=======================================================================
15-
*/
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================
15+
*/
1616
package org.tensorflow;
1717

1818
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_Execute;
@@ -53,24 +53,29 @@
5353
import org.tensorflow.internal.c_api.TF_Status;
5454
import org.tensorflow.internal.c_api.TF_Tensor;
5555
import org.tensorflow.ndarray.Shape;
56+
import org.tensorflow.op.Scope;
5657
import org.tensorflow.proto.framework.DataType;
5758

5859
/**
5960
* An {@link OperationBuilder} for building {@link Operation Operations} that are executed eagerly.
6061
*/
6162
final class EagerOperationBuilder implements OperationBuilder {
6263

63-
EagerOperationBuilder(EagerSession session, String type, String name) {
64+
EagerOperationBuilder(EagerSession session, String type, String name, Scope scope) {
6465
this.session = session;
6566
this.type = type;
6667
this.name = name;
68+
this.scope = scope;
6769
this.opHandle = allocate(session, type);
6870
}
6971

7072
@Override
7173
public EagerOperation build() {
74+
scope.apply(this);
7275
TFE_TensorHandle[] tensorHandles = execute(opHandle, session);
73-
return new EagerOperation(session, opHandle, tensorHandles, type, name);
76+
EagerOperation op = new EagerOperation(session, opHandle, tensorHandles, type, name);
77+
scope.onOpCreated(op);
78+
return op;
7479
}
7580

7681
@Override
@@ -250,6 +255,7 @@ public OperationBuilder setAttr(String name, ConcreteFunction[] value) {
250255
private final EagerSession session;
251256
private final String type;
252257
private final String name;
258+
private final Scope scope;
253259

254260
/** This value should be >= to the maximum number of outputs in any op */
255261
private static final int MAX_OUTPUTS_PER_OP = 1000;

0 commit comments

Comments
 (0)