Skip to content

Commit 67a76fd

Browse files
committed
Add base scope, test
Signed-off-by: Ryan Nett <[email protected]>
1 parent e4d08c0 commit 67a76fd

File tree

7 files changed

+83
-10
lines changed

7 files changed

+83
-10
lines changed

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

+7-7
Original file line numberDiff line numberDiff line change
@@ -354,14 +354,14 @@ public final class Ops {
354354

355355
public final SparseOps sparse;
356356

357-
public final BitwiseOps bitwise;
358-
359357
public final TpuOps tpu;
360358

361-
public final AudioOps audio;
359+
public final BitwiseOps bitwise;
362360

363361
public final MathOps math;
364362

363+
public final AudioOps audio;
364+
365365
public final SignalOps signal;
366366

367367
public final TrainOps train;
@@ -385,10 +385,10 @@ private Ops(Scope scope) {
385385
random = new RandomOps(this);
386386
strings = new StringsOps(this);
387387
sparse = new SparseOps(this);
388-
bitwise = new BitwiseOps(this);
389388
tpu = new TpuOps(this);
390-
audio = new AudioOps(this);
389+
bitwise = new BitwiseOps(this);
391390
math = new MathOps(this);
391+
audio = new AudioOps(this);
392392
signal = new SignalOps(this);
393393
train = new TrainOps(this);
394394
quantization = new QuantizationOps(this);
@@ -8185,7 +8185,7 @@ public final Scope scope() {
81858185
* Creates an API for building operations in the provided execution environment
81868186
*/
81878187
public static Ops create(ExecutionEnvironment env) {
8188-
return new Ops(new Scope(env));
8188+
return new Ops(env.baseScope());
81898189
}
81908190

81918191
/**
@@ -8194,6 +8194,6 @@ public static Ops create(ExecutionEnvironment env) {
81948194
* <p>Invoking this method is equivalent to {@code Ops.create(EagerSession.getDefault())}.
81958195
*/
81968196
public static Ops create() {
8197-
return new Ops(new Scope(EagerSession.getDefault()));
8197+
return create(EagerSession.getDefault());
81988198
}
81998199
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java

+13
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
import org.tensorflow.internal.c_api.TFE_ContextOptions;
3030
import org.tensorflow.internal.c_api.TF_Status;
3131
import org.tensorflow.op.Op;
32+
import org.tensorflow.op.Ops;
33+
import org.tensorflow.op.Scope;
3234
import org.tensorflow.op.core.Assign;
3335
import org.tensorflow.op.core.Placeholder;
3436
import org.tensorflow.op.core.Variable;
@@ -306,6 +308,15 @@ public void checkInput(Op input) {
306308
}
307309
}
308310

311+
@Override
312+
public synchronized Scope baseScope() {
313+
if(baseScope == null){
314+
baseScope = new Scope(this);
315+
}
316+
317+
return baseScope;
318+
}
319+
309320
TFE_Context nativeHandle() {
310321
checkSession();
311322
return nativeHandle;
@@ -362,6 +373,8 @@ void detach(Pointer... resources) {
362373
private final WeakPointerScope nativeResources;
363374
private TFE_Context nativeHandle;
364375

376+
private Scope baseScope = null;
377+
365378
private EagerSession(Options options) {
366379
this.nativeResources = new WeakPointerScope();
367380
this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config);

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java

+6
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package org.tensorflow;
1717

1818
import org.tensorflow.op.Op;
19+
import org.tensorflow.op.Scope;
1920

2021
/**
2122
* Defines an environment for creating and executing TensorFlow {@link Operation}s.
@@ -71,4 +72,9 @@ default boolean isEager() {
7172
default boolean isGraph() {
7273
return environmentType() == Types.GRAPH;
7374
}
75+
76+
/**
77+
* Get the top level scope for this execution environment. Is cached, which is necessary to prevent name collisions.
78+
*/
79+
Scope baseScope();
7480
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

+11
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import org.tensorflow.ndarray.StdArrays;
5353
import org.tensorflow.op.Op;
5454
import org.tensorflow.op.Ops;
55+
import org.tensorflow.op.Scope;
5556
import org.tensorflow.op.core.Constant;
5657
import org.tensorflow.op.core.Identity;
5758
import org.tensorflow.op.core.NoOp;
@@ -390,6 +391,15 @@ public void checkInput(Op input) {
390391
}
391392
}
392393

394+
@Override
395+
public synchronized Scope baseScope() {
396+
if(baseScope == null){
397+
baseScope = new Scope(this);
398+
}
399+
400+
return baseScope;
401+
}
402+
393403
/**
394404
* Import a representation of a TensorFlow graph.
395405
*
@@ -692,6 +702,7 @@ synchronized SaverDef saverDef() {
692702
private TF_Graph nativeHandle;
693703
private int refcount = 0;
694704
private SaverDef saverDef;
705+
private Scope baseScope = null;
695706

696707
private final List<Op> initializers = new ArrayList<>();
697708

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java

+2
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ public final class Scope {
7979

8080
/**
8181
* Create a new top-level scope.
82+
* <p>
83+
* <b>For internal use only</b>, use {@link ExecutionEnvironment#baseScope()} if you need a base level scope.
8284
*
8385
* @param env The execution environment used by the scope.
8486
*/
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
Copyright 2021 The TensorFlow Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================
16+
*/
17+
package org.tensorflow;
18+
19+
import org.junit.jupiter.api.Test;
20+
import org.tensorflow.op.Ops;
21+
22+
public class ScopeTest {
23+
@Test
24+
public void testSeparateOps(){
25+
try(Graph g = new Graph()){
26+
Ops tf1 = Ops.create(g);
27+
Ops tf2 = Ops.create(g);
28+
29+
tf1.constant(2);
30+
tf1.withName("Constant2").constant(2);
31+
tf1.withSubScope("Scope").constant(2);
32+
tf1.withSubScope("Scope").withName("Constant4").constant(2);
33+
34+
tf2.constant(2);
35+
tf2.withName("Constant2").constant(2);
36+
tf2.withSubScope("Scope").constant(2);
37+
tf2.withSubScope("Scope").withName("Constant4").constant(2);
38+
39+
}
40+
}
41+
}

tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -470,8 +470,8 @@ private static TypeSpec buildTopClass(OpsSpec spec) {
470470

471471
MethodSpec.Builder ctorBuilder =
472472
MethodSpec.constructorBuilder()
473-
.addModifiers(Modifier.PRIVATE)
474473
.addParameter(T_SCOPE, "scope")
474+
.addModifiers(Modifier.PRIVATE)
475475
.addStatement("this.scope = scope", T_SCOPE);
476476

477477
TypeSpec.Builder opsBuilder =
@@ -578,7 +578,7 @@ private static TypeSpec buildTopClass(OpsSpec spec) {
578578
.addModifiers(Modifier.PUBLIC, Modifier.STATIC)
579579
.addParameter(T_EXEC_ENV, "env")
580580
.returns(T_OPS)
581-
.addStatement("return new Ops(new $T(env))", T_SCOPE)
581+
.addStatement("return new Ops(env.baseScope())")
582582
.addJavadoc(
583583
"Creates an API for building operations in the provided execution environment\n")
584584
.build());
@@ -587,7 +587,7 @@ private static TypeSpec buildTopClass(OpsSpec spec) {
587587
MethodSpec.methodBuilder("create")
588588
.addModifiers(Modifier.PUBLIC, Modifier.STATIC)
589589
.returns(T_OPS)
590-
.addStatement("return new Ops(new $T($T.getDefault()))", T_SCOPE, T_EAGER_SESSION)
590+
.addStatement("return create($T.getDefault())", T_EAGER_SESSION)
591591
.addJavadoc(
592592
"Creates an API for building operations in the default eager execution environment\n\n"
593593
+ "<p>Invoking this method is equivalent to {@code Ops.create(EagerSession.getDefault())}.\n")

0 commit comments

Comments
 (0)