Skip to content

Commit 3f2593a

Browse files
committed
Rebase fixes
Signed-off-by: Ryan Nett <[email protected]>
1 parent 27ce339 commit 3f2593a

File tree

21 files changed

+98
-157
lines changed

21 files changed

+98
-157
lines changed

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

+47-94
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@
106106
import org.tensorflow.op.core.IdentityN;
107107
import org.tensorflow.op.core.If;
108108
import org.tensorflow.op.core.ImmutableConst;
109-
import org.tensorflow.op.core.Init;
110109
import org.tensorflow.op.core.InitializeTable;
111110
import org.tensorflow.op.core.InitializeTableFromTextFile;
112111
import org.tensorflow.op.core.InplaceAdd;
@@ -295,6 +294,12 @@
295294
import org.tensorflow.op.core.VariableShape;
296295
import org.tensorflow.op.core.Where;
297296
import org.tensorflow.op.core.While;
297+
import org.tensorflow.op.core.XlaConvV2;
298+
import org.tensorflow.op.core.XlaDotV2;
299+
import org.tensorflow.op.core.XlaSetDynamicDimensionSize;
300+
import org.tensorflow.op.core.XlaSpmdFullToShardShape;
301+
import org.tensorflow.op.core.XlaSpmdShardToFullShape;
302+
import org.tensorflow.op.core.XlaVariadicSort;
298303
import org.tensorflow.op.core.Zeros;
299304
import org.tensorflow.op.core.ZerosLike;
300305
import org.tensorflow.types.TBool;
@@ -366,20 +371,20 @@ public final class Ops {
366371

367372
public final SparseOps sparse;
368373

369-
public final TpuOps tpu;
370-
371374
public final BitwiseOps bitwise;
372375

376+
public final TpuOps tpu;
377+
373378
public final MathOps math;
374379

375380
public final AudioOps audio;
376381

377382
public final SignalOps signal;
378383

379-
public final QuantizationOps quantization;
380-
381384
public final TrainOps train;
382385

386+
public final QuantizationOps quantization;
387+
383388
private final Scope scope;
384389

385390
private Ops(Scope scope) {
@@ -397,13 +402,13 @@ private Ops(Scope scope) {
397402
random = new RandomOps(this);
398403
strings = new StringsOps(this);
399404
sparse = new SparseOps(this);
400-
tpu = new TpuOps(this);
401405
bitwise = new BitwiseOps(this);
406+
tpu = new TpuOps(this);
402407
math = new MathOps(this);
403408
audio = new AudioOps(this);
404409
signal = new SignalOps(this);
405-
quantization = new QuantizationOps(this);
406410
train = new TrainOps(this);
411+
quantization = new QuantizationOps(this);
407412
}
408413

409414
/**
@@ -1952,14 +1957,15 @@ public Constant<TFloat32> constant(Shape shape, FloatDataBuffer data) {
19521957
}
19531958

19541959
/**
1955-
* Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be truncated if it does not
1956-
* fit in the target type.
1960+
* Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be
1961+
* truncated if it does not fit in the target type.
19571962
*
1958-
* @param type the type of tensor to create. Must be concrete (i.e. not {@link org.tensorflow.types.family.TFloating})
1963+
* @param type the type of tensor to create. Must be concrete (i.e. not {@link
1964+
* org.tensorflow.types.family.TFloating})
19591965
* @param number the value of the tensor
19601966
* @return a constant of the passed type
1961-
* @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or
1962-
* unknown.
1967+
* @throws IllegalArgumentException if the type is abstract (i.e. {@link
1968+
* org.tensorflow.types.family.TFloating}) or unknown.
19631969
*/
19641970
public <T extends TNumber> Constant<T> constant(Class<T> type, Number number) {
19651971
return Constant.tensorOf(scope, type, number);
@@ -1994,11 +2000,12 @@ public <T extends TType> Constant<T> constant(Class<T> type, Shape shape, ByteDa
19942000
}
19952001

19962002
/**
1997-
* Create a constant by making an immutable copy of {@code tensor}. {@code tensor} may be closed afterwards without
1998-
* issue.
2003+
* Create a constant by making an immutable copy of {@code tensor}. {@code tensor} may be closed
2004+
* afterwards without issue.
19992005
*
20002006
* <p>Note: this endpoint cannot be simply called {@code constant} since it will conflict with
2001-
* other endpoints accepting an NdArray in parameter {e.g. {@link #tensorOf(Scope, FloatNdArray)}}.
2007+
* other endpoints accepting an NdArray in parameter {e.g. {@link #tensorOf(Scope,
2008+
* FloatNdArray)}}.
20022009
*
20032010
* @param tensor a Tensor holding the constant value
20042011
* @return a constant of the same data type as `tensor`
@@ -2008,8 +2015,8 @@ public <T extends TType> Constant<T> constantOf(T tensor) {
20082015
}
20092016

20102017
/**
2011-
* Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code number} may be
2012-
* truncated if it does not fit in the target type.
2018+
* Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code
2019+
* number} may be truncated if it does not fit in the target type.
20132020
*
20142021
* @param toMatch the operand providing the target type
20152022
* @param number the value of the tensor
@@ -2993,80 +3000,6 @@ public <T extends TType> ImmutableConst<T> immutableConst(Class<T> dtype, Shape
29933000
return ImmutableConst.create(scope, dtype, shape, memoryRegionName);
29943001
}
29953002

2996-
/**
2997-
* Factory method to create an operation executing all initializers of a graph.
2998-
*
2999-
* <p>All initializers added to a graph via
3000-
* {@link org.tensorflow.op.core.Init#add(Scope, Op) tf.initAdd} are grouped together as a single
3001-
* unit of computation in the graph. This operation must then be added to any graph using one or
3002-
* more {@link Variable variables} and executed once before running the graph so the variable
3003-
* states are initialized properly.</p>
3004-
*
3005-
* <p>When the graph is built by the same process that is running the session, the initializers
3006-
* can be invoked by executing this single endpoint. For example:</p>
3007-
* <pre>{@code
3008-
* try (Graph g = new Graph()) {
3009-
* Variable<TInt32> x = tf.variable(tf.constant(10)); // initAdd is called implicitly
3010-
* Variable<TInt32> y = tf.variable(tf.constant(20)); // idem
3011-
* Add<TInt32> z = tf.math.add(x, y);
3012-
*
3013-
* try (Session s = new Session(g)) {
3014-
* s.run(tf.init()); // initialize all variables
3015-
*
3016-
* try (TInt32 t = (TInt32)s.runner().fetch(z).run().get(0)) {
3017-
* assertEquals(30, t.data().getInt());
3018-
* }
3019-
* }
3020-
* }
3021-
* }</pre>
3022-
*
3023-
* <p>When the graph is built by a separate process, the initializers can be invoked by running
3024-
* the init op by its name, which defaults to {@link org.tensorflow.op.core.Init#DEFAULT_NAME}.
3025-
* For example:</p>
3026-
* <pre>{@code
3027-
* // Building the model
3028-
* try (Graph g = new Graph()) {
3029-
* Variable<TInt32> x = tf.variable(tf.constant(10)); // initAdd is called implicitly
3030-
* Variable<TInt32> y = tf.variable(tf.constant(20)); // idem
3031-
* Add<TInt32> z = tf.withName("z").math.add(x, y);
3032-
*
3033-
* tf.init(); // add variables initializers to the graph, as Init.DEFAULT_NAME
3034-
* // ...exporting graph as a saved model...
3035-
* }
3036-
*
3037-
* ...
3038-
*
3039-
* // Running the model
3040-
* try (SavedModelBundle model = SavedModelBundle.load("/path/to/model", "train")) {
3041-
* model.session().run(Init.DEFAULT_NAME);
3042-
*
3043-
* try (TInt32 t = (TInt32)s.runner().fetch("z").run().get(0)) {
3044-
* assertEquals(30, t.data().getInt());
3045-
* }
3046-
* }
3047-
* }</pre>
3048-
*
3049-
* @return an op grouping all initializers added to the graph
3050-
* @throws IllegalArgumentException if the execution environment in scope is not a graph
3051-
*/
3052-
public Init init() {
3053-
return Init.create(scope);
3054-
}
3055-
3056-
/**
3057-
* Register an op as an initializer of the graph.
3058-
*
3059-
* <p>Registered initializers are then grouped as a single unit of computation by adding
3060-
* and executing an {@link org.tensorflow.op.core.Init#create(Scope) init} operation from a graph
3061-
* session. This is a no-op if executed in an eager session.
3062-
*
3063-
* @param initializer
3064-
* @see org.tensorflow.op.core.Init#create(Scope) init
3065-
*/
3066-
public void initAdd(Op initializer) {
3067-
Init.add(scope, initializer);
3068-
}
3069-
30703003
/**
30713004
* Table initializer that takes two tensors for keys and values respectively.
30723005
*
@@ -7948,9 +7881,9 @@ public VarIsInitializedOp varIsInitializedOp(Operand<? extends TType> resource)
79487881

79497882
/**
79507883
* Factory method to create a new Variable with it's initializer.
7951-
* <p>
7952-
* Only supported on Graph sessions as the {@link org.tensorflow.op.core.Assign} op
7953-
* does not work in an EagerSession.
7884+
*
7885+
* <p>Only supported on Graph sessions as the {@link org.tensorflow.op.core.Assign} op does not
7886+
* work in an EagerSession.
79547887
*
79557888
* @param init The op to use to initialise this variable.
79567889
* @param options carries optional attributes values
@@ -8143,6 +8076,26 @@ public Ops withSubScope(String childScopeName) {
81438076
return new Ops(scope.withSubScope(childScopeName));
81448077
}
81458078

8079+
/**
8080+
* Returns an API that builds init operations.
8081+
* <p>Init operations will be initialized at session creation, will have their inputs (and control inputs) made init ops as well, and are never used as control dependencies.
8082+
* Additionally, this scope drops all of its control dependencies. If an input can not be made an init op (i.e. a Placeholder), will error on op creation.
8083+
*/
8084+
public Ops initScope() {
8085+
return new Ops(scope.initScope());
8086+
}
8087+
8088+
/**
8089+
* Make {@code op} an init operation, doing the same for all of it's inputs (and control inputs).
8090+
* <p>Init operations will be initialized at session creation, will have their inputs (and control inputs) made init ops as well, and are never used as control dependencies.
8091+
* Additionally, this scope drops all of its control dependencies. If an input can not be made an init op (i.e. a Placeholder), will error on op creation.
8092+
* @throws IllegalArgumentException if the op or one of its inputs can't be made an init op.
8093+
*/
8094+
public <T extends Operand> T liftToInitScope(T op) {
8095+
scope.env().registerInitOp(op.op());
8096+
return op;
8097+
}
8098+
81468099
/**
81478100
* Returns an API that uses the provided name for an op.
81488101
*

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java

+1
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ default ExecutionEnvironment initEnv() {
115115
*
116116
* <p><b>Should generally only be used internally, prefer {@link
117117
* org.tensorflow.op.Ops#initScope()}.</b>
118+
*
118119
* @throws IllegalArgumentException if the op or one of its inputs can't be made an init op.
119120
*/
120121
void registerInitOp(Operation op);

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java

+1
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ public void close() {
145145
* Execute the graph's initializers.
146146
*
147147
* <p>This runs any ops that have been created with an init scope.
148+
*
148149
* @throws IllegalStateException if the session has already been initialized
149150
*/
150151
public synchronized void initialize() {

tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java

+30-29
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,10 @@ void buildClass() {
259259

260260
if (iterable) {
261261
mode = RenderMode.LIST_OPERAND;
262-
if (!isStateSubclass) {builder.addSuperinterface(
263-
ParameterizedTypeName.get(ClassName.get(Iterable.class), operandType));}
262+
if (!isStateSubclass) {
263+
builder.addSuperinterface(
264+
ParameterizedTypeName.get(ClassName.get(Iterable.class), operandType));
265+
}
264266
} else {
265267
mode = RenderMode.OPERAND;
266268
if (!isStateSubclass) {
@@ -315,26 +317,27 @@ void buildClass() {
315317
buildInterfaceImpl();
316318
}
317319

318-
if (!isStateSelector) {// add op name field
319-
builder.addField(
320-
FieldSpec.builder(
321-
TypeResolver.STRING, OP_NAME_FIELD,
320+
if (!isStateSelector) { // add op name field
321+
builder.addField(
322+
FieldSpec.builder(
323+
TypeResolver.STRING,
324+
OP_NAME_FIELD,
322325
Modifier.PUBLIC,
323326
Modifier.STATIC,
324327
Modifier.FINAL)
325328
.addJavadoc("$L", "The name of this op, as known by TensorFlow core engine")
326329
.initializer("$S", op.getName())
327330
.build());
328331

329-
// add output fields
330-
if (op.getOutputArgCount() > 0) {
331-
for (ArgDef output : op.getOutputArgList()) {
332-
builder.addField(
333-
resolver.typeOf(output).listIfIterable().javaType,
334-
getJavaName(output),
335-
Modifier.PRIVATE);
332+
// add output fields
333+
if (op.getOutputArgCount() > 0) {
334+
for (ArgDef output : op.getOutputArgList()) {
335+
builder.addField(
336+
resolver.typeOf(output).listIfIterable().javaType,
337+
getJavaName(output),
338+
Modifier.PRIVATE);
339+
}
336340
}
337-
}
338341

339342
buildConstructor();
340343
}
@@ -411,8 +414,7 @@ private void buildOptionsClass() {
411414
}
412415

413416
// add the field
414-
optionsBuilder.addField(
415-
field.build());
417+
optionsBuilder.addField(field.build());
416418
}
417419

418420
// add a private constructor
@@ -503,8 +505,7 @@ private void buildFactoryMethods() {
503505
Set<TypeVariableName> typeVars = new LinkedHashSet<>(typeParams);
504506

505507
body.addStatement(
506-
"$T opBuilder = scope.opBuilder($L, $S)", Names.OperationBuilder,
507-
OP_NAME_FIELD, className);
508+
"$T opBuilder = scope.opBuilder($L, $S)", Names.OperationBuilder, OP_NAME_FIELD, className);
508509

509510
List<String> functionArgs = new ArrayList<>();
510511
List<String> iterableFunctionArgs = new ArrayList<>();
@@ -599,9 +600,7 @@ private void buildFactoryMethods() {
599600
}
600601

601602
factoryBuilder.addParameter(
602-
ParameterSpec.builder(
603-
ArrayTypeName.of(optionsClassName), "options")
604-
.build());
603+
ParameterSpec.builder(ArrayTypeName.of(optionsClassName), "options").build());
605604
paramTags.put("options", CodeBlock.of("$L", "carries optional attribute values"));
606605
factoryBuilder.varargs();
607606

@@ -619,7 +618,7 @@ private void buildFactoryMethods() {
619618
body.endControlFlow();
620619

621620
body.endControlFlow();
622-
}
621+
}
623622

624623
body.addStatement(
625624
"return new $L(opBuilder.build())", typeParams.isEmpty() ? className : (className + "<>"));
@@ -826,14 +825,16 @@ private void buildInterfaceImpl() {
826825

827826
if (isStateSelector) {
828827
asOutput.addModifiers(Modifier.ABSTRACT);
829-
} else {if (uncheckedCast) {
830-
asOutput.addAnnotation(
831-
AnnotationSpec.builder(SuppressWarnings.class)
832-
.addMember("value", "$S", "unchecked")
833-
.build());
834-
asOutput.addCode("return ($T) $L;", outputType, getJavaName(output));
835828
} else {
836-
asOutput.addCode("return $L;", getJavaName(output));}
829+
if (uncheckedCast) {
830+
asOutput.addAnnotation(
831+
AnnotationSpec.builder(SuppressWarnings.class)
832+
.addMember("value", "$S", "unchecked")
833+
.build());
834+
asOutput.addCode("return ($T) $L;", outputType, getJavaName(output));
835+
} else {
836+
asOutput.addCode("return $L;", getJavaName(output));
837+
}
837838
}
838839

839840
builder.addMethod(asOutput.build());

tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java

+4-5
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import com.squareup.javapoet.JavaFile;
2727
import com.squareup.javapoet.MethodSpec;
2828
import com.squareup.javapoet.ParameterSpec;
29-
import com.squareup.javapoet.ParameterizedTypeName;
3029
import com.squareup.javapoet.TypeName;
3130
import com.squareup.javapoet.TypeSpec;
3231
import com.squareup.javapoet.TypeVariableName;
@@ -40,8 +39,6 @@
4039
import java.util.List;
4140
import java.util.Map;
4241
import java.util.Set;
43-
import java.util.function.Consumer;
44-
import java.util.function.Function;
4542
import java.util.regex.Pattern;
4643
import javax.annotation.processing.AbstractProcessor;
4744
import javax.annotation.processing.Filer;
@@ -575,8 +572,10 @@ private static TypeSpec buildTopClass(OpsSpec spec) {
575572
.returns(T)
576573
.addStatement("scope.env().registerInitOp(op.op())")
577574
.addStatement("return op")
578-
.addJavadoc("Make {@code op} an init operation, doing the same for all of it's inputs (and control inputs).\n" + initScopeComment +
579-
"\n@throws IllegalArgumentException if the op or one of its inputs can't be made an init op.")
575+
.addJavadoc(
576+
"Make {@code op} an init operation, doing the same for all of it's inputs (and control inputs).\n"
577+
+ initScopeComment
578+
+ "\n@throws IllegalArgumentException if the op or one of its inputs can't be made an init op.")
580579
.build());
581580

582581
opsBuilder.addMethod(

0 commit comments

Comments
 (0)