From 155fdd6af021b250004d38007ad3dcfc0b5727eb Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 30 Jan 2022 16:02:27 -0500 Subject: [PATCH 1/9] Adding an autocloseable result class for the output of Session.Runner.run. --- .../src/main/java/org/tensorflow/Session.java | 141 ++++++++++++++++-- .../test/java/org/tensorflow/SessionTest.java | 32 ++-- 2 files changed, 141 insertions(+), 32 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 71fdcec3f41..dbbf40622db 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -21,12 +21,19 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig; import com.google.protobuf.InvalidProtocolBufferException; + +import java.sql.Array; import java.util.ArrayList; import java.util.Collections; +import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; +import java.util.logging.Logger; + import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; @@ -490,13 +497,13 @@ private void doInit() { * * @return list of resulting tensors fetched by this session runner */ - public List run() { + public Result run() { doInit(); return runNoInit(); } - List runNoInit() { - return runHelper(false).outputs; + Result runNoInit() { + return runHelper(false); } /** @@ -509,18 +516,19 @@ List runNoInit() { * * @return list of resulting tensors fetched by this session runner, with execution metadata */ - public Run runAndFetchMetadata() { + public Result runAndFetchMetadata() { doInit(); return runHelper(true); } - private Run runHelper(boolean wantMetadata) { + private Result runHelper(boolean wantMetadata) { TF_Tensor[] inputTensorHandles = new TF_Tensor[inputTensors.size()]; TF_Operation[] inputOpHandles = new TF_Operation[inputs.size()]; int[] inputOpIndices = new int[inputs.size()]; TF_Operation[] outputOpHandles = new TF_Operation[outputs.size()]; int[] outputOpIndices = new int[outputs.size()]; TF_Operation[] targetOpHandles = new TF_Operation[targets.size()]; + List outputNames = new ArrayList<>(); // It's okay to use Operation.getUnsafeNativeHandle() here since the safety depends on the // validity of the Graph and graphRef ensures that. @@ -538,6 +546,7 @@ private Run runHelper(boolean wantMetadata) { for (Output o : outputs) { outputOpHandles[idx] = (TF_Operation) o.getUnsafeNativeHandle(); outputOpIndices[idx] = o.index(); + outputNames.add(o.name()); idx++; } idx = 0; @@ -569,10 +578,7 @@ private Run runHelper(boolean wantMetadata) { } finally { runRef.close(); } - Run ret = new Run(); - ret.outputs = outputs; - ret.metadata = metadata; - return ret; + return new Result(outputNames,outputs,metadata); } private class Reference implements AutoCloseable { @@ -699,14 +705,117 @@ public void restore(String prefix) { } /** - * Output tensors and metadata obtained when executing a session. + * An {@link AutoCloseable} wrapper around a {@link Map} containing {@link Tensor}s. * - *

See {@link Runner#runAndFetchMetadata()} + *

When this is closed it closes all the {@link Tensor}s inside it. If you maintain a + * reference to a value after this object has been closed it will throw an {@link + * IllegalStateException} upon access. */ - public static final class Run { + public static class Result implements AutoCloseable, Iterable> { + + private static final Logger logger = Logger.getLogger(Result.class.getName()); + + private final Map map; + + private final List list; + + private final RunMetadata metadata; + + private boolean closed; + + /** + * Creates a Result from the names and values produced by {@link Session.Runner#run()}. + * + * @param names The output names. + * @param values The output values. + * @param metadata The run metadata, may be null. + */ + Result(List names, List values, RunMetadata metadata) { + this.map = new LinkedHashMap<>(); + this.list = new ArrayList<>(values); + + if (names.size() != values.size()) { + throw new IllegalArgumentException( + "Expected same number of names and values, found names.length = " + + names.size() + + ", values.length = " + + values.size()); + } - /** Tensors from requested fetches. */ - public List outputs; + for (int i = 0; i < names.size(); i++) { + this.map.put(names.get(i), values.get(i)); + } + this.metadata = metadata; + this.closed = false; + } + + @Override + public void close() { + if (!closed) { + closed = true; + for (Tensor t : map.values()) { + t.close(); + } + } else { + logger.warning("Closing an already closed Result"); + } + } + + @Override + public Iterator> iterator() { + if (!closed) { + return map.entrySet().iterator(); + } else { + throw new IllegalStateException("Result is closed"); + } + } + + /** + * Gets the value from the container at the specified index. + * + *

Throws {@link IllegalStateException} if the container has been closed, and {@link + * IndexOutOfBoundsException} if the index is invalid. + * + * @param index The index to lookup. + * @return The value at the index. + */ + public Tensor get(int index) { + if (!closed) { + return list.get(index); + } else { + throw new IllegalStateException("Result is closed"); + } + } + + /** + * Returns the number of outputs in this Result. + * + * @return The number of outputs. + */ + public int size() { + return map.size(); + } + + /** + * Gets the value from the container assuming it's not been closed. + * + *

Throws {@link IllegalStateException} if the container has been closed. + * + * @param key The key to lookup. + * @return Optional.of the value if it exists. + */ + public Optional get(String key) { + if (!closed) { + Tensor value = map.get(key); + if (value != null) { + return Optional.of(value); + } else { + return Optional.empty(); + } + } else { + throw new IllegalStateException("Result is closed"); + } + } /** * Metadata about the run. @@ -715,7 +824,9 @@ public static final class Run { * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata * protocol buffer. */ - public RunMetadata metadata; + public Optional getMetadata() { + return Optional.ofNullable(metadata); + } } Graph graph() { diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index 95da0520f7d..ddff36de91b 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -26,6 +26,8 @@ import java.nio.file.Path; import java.util.Comparator; import java.util.Iterator; +import java.util.Optional; + import org.junit.jupiter.api.Test; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -38,6 +40,7 @@ import org.tensorflow.op.math.Add; import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.GraphDef; +import org.tensorflow.proto.framework.RunMetadata; import org.tensorflow.proto.framework.RunOptions; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; @@ -69,8 +72,7 @@ public void runUsingOperationNames() { Ops tf = Ops.create(g); transpose_A_times_X(tf, new int[][] {{2}, {3}}); try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); - AutoCloseableList outputs = - new AutoCloseableList<>(s.runner().feed("X", x).fetch("Y").run())) { + Session.Result outputs = s.runner().feed("X", x).fetch("Y").run()) { assertEquals(1, outputs.size()); assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0)); } @@ -86,8 +88,7 @@ public void runUsingOperationHandles() { Output feed = g.operation("X").output(0); Output fetch = g.operation("Y").output(0); try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); - AutoCloseableList outputs = - new AutoCloseableList<>(s.runner().feed(feed, x).fetch(fetch).run())) { + Session.Result outputs = s.runner().feed(feed, x).fetch(fetch).run()) { assertEquals(1, outputs.size()); assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0)); } @@ -124,20 +125,20 @@ public void runWithMetadata() { Ops tf = Ops.create(g); transpose_A_times_X(tf, new int[][] {{2}, {3}}); try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}))) { - Session.Run result = + Session.Result result = s.runner() .feed("X", x) .fetch("Y") .setOptions(fullTraceRunOptions()) .runAndFetchMetadata(); // Sanity check on outputs. - AutoCloseableList outputs = new AutoCloseableList<>(result.outputs); - assertEquals(1, outputs.size()); - assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0)); + assertEquals(1, result.size()); + assertEquals(31, ((TInt32) result.get(0)).getInt(0, 0)); // Sanity check on metadata - assertNotNull(result.metadata); - assertTrue(result.metadata.hasStepStats(), result.metadata.toString()); - outputs.close(); + Optional metadata = result.getMetadata(); + assertTrue(metadata.isPresent()); + assertTrue(metadata.get().hasStepStats(), metadata.get().toString()); + result.close(); } } } @@ -149,8 +150,7 @@ public void runMultipleOutputs() { Ops tf = Ops.create(g); tf.withName("c1").constant(2718); tf.withName("c2").constant(31415); - AutoCloseableList outputs = - new AutoCloseableList<>(s.runner().fetch("c2").fetch("c1").run()); + Session.Result outputs = s.runner().fetch("c2").fetch("c1").run(); assertEquals(2, outputs.size()); assertEquals(31415, ((TInt32) outputs.get(0)).getInt()); assertEquals(2718, ((TInt32) outputs.get(1)).getInt()); @@ -227,10 +227,8 @@ public void saveAndRestore() throws IOException { restoredGraph.importGraphDef(graphDef); try (Session restoredSession = new Session(restoredGraph)) { restoredSession.restore(testFolder.resolve("checkpoint").toString()); - try (AutoCloseableList oldList = - new AutoCloseableList<>(s.runner().fetch("x").fetch("y").run()); - AutoCloseableList newList = - new AutoCloseableList<>(restoredSession.runner().fetch("x").fetch("y").run())) { + try (Session.Result oldList = s.runner().fetch("x").fetch("y").run(); + Session.Result newList = restoredSession.runner().fetch("x").fetch("y").run()) { assertEquals(oldList.get(0), newList.get(0)); assertEquals(oldList.get(1), newList.get(1)); } From 14509801e819e249a4cb13b9e563192cbaa56a4b Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 30 Jan 2022 16:18:03 -0500 Subject: [PATCH 2/9] Fix framework tests. --- .../java/org/tensorflow/SessionFunction.java | 2 +- .../framework/data/DatasetIteratorTest.java | 13 +-- .../framework/data/MapDatasetTest.java | 12 +-- .../metrics/impl/AssertBroadcastableTest.java | 9 ++- .../metrics/impl/BroadcastWeightsTest.java | 80 ++++++++++--------- .../optimizers/GradientDescentTest.java | 22 +++-- 6 files changed, 70 insertions(+), 68 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java index 07bc418ac51..3c6a11fd954 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java @@ -113,7 +113,7 @@ public Map call(Map arguments) { signature.getOutputs().values().forEach(x -> runner.fetch(x.name)); - List results = runner.run(); + Session.Result results = runner.run(); Map outputs = new LinkedHashMap<>(results.size()); int i = 0; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java index 1f8503829b7..9a281305703 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java @@ -51,15 +51,10 @@ public void testGraphIteration() { int batches = 0; while (true) { - try { - List outputs = session.runner().fetch(x).fetch(y).run(); - - try (TInt32 xBatch = (TInt32) outputs.get(0); - TInt32 yBatch = (TInt32) outputs.get(1)) { - assertEquals(testMatrix1.get(batches), xBatch); - assertEquals(testMatrix2.get(batches), yBatch); - batches++; - } + try (Session.Result outputs = session.runner().fetch(x).fetch(y).run()) { + assertEquals(testMatrix1.get(batches), outputs.get(0)); + assertEquals(testMatrix2.get(batches), outputs.get(1)); + batches++; } catch (TFOutOfRangeException e) { break; } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java index afa38e04ee8..9cac8fe54c4 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java @@ -76,17 +76,11 @@ public void testGraphIteration() { int batches = 0; while (true) { - try { - List outputs = session.runner().fetch(X).fetch(y).run(); - - try (TInt32 XBatch = (TInt32) outputs.get(0); - TInt32 yBatch = (TInt32) outputs.get(1)) { - - assertEquals(mapped1.get(batches), XBatch); - assertEquals(mapped2.get(batches), yBatch); + try (Session.Result outputs = session.runner().fetch(X).fetch(y).run()) { + assertEquals(mapped1.get(batches), outputs.get(0)); + assertEquals(mapped2.get(batches), outputs.get(1)); batches++; - } } catch (TFOutOfRangeException e) { break; } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java index 4330fa0aed7..18b8a6254ce 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java @@ -16,6 +16,7 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Operand; +import org.tensorflow.Session; import org.tensorflow.Tensor; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Op; @@ -69,10 +70,10 @@ private void testValid( Operand weightsPlaceholder = tf.placeholder(type); Operand valuesPlaceholder = tf.placeholder(type); - List tensors = - testSession.getGraphSession().runner().fetch(weights).fetch(values).run(); - try (Tensor weightsTensor = tensors.get(0); - Tensor valuesTensor = tensors.get(1)) { + try (Session.Result tensors = + testSession.getGraphSession().runner().fetch(weights).fetch(values).run()) { + Tensor weightsTensor = tensors.get(0); + Tensor valuesTensor = tensors.get(1); Op dynamicOp = MetricsHelper.assertBroadcastable(tf, weightsPlaceholder, valuesPlaceholder); testSession diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java index 3322a81fe5b..e72d3534ade 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java @@ -16,6 +16,7 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Operand; +import org.tensorflow.Session; import org.tensorflow.Tensor; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; @@ -78,55 +79,56 @@ private void testValid( Operand weightsPlaceholder = tf.placeholder(type); Operand valuesPlaceholder = tf.placeholder(type); - List tensors = - testSession.getGraphSession().runner().fetch(weights).fetch(values).run(); - try (Tensor weightsTensor = tensors.get(0); - Tensor valuesTensor = tensors.get(1)) { + try (Session.Result tensors = + testSession.getGraphSession().runner().fetch(weights).fetch(values).run()) { + Tensor weightsTensor = tensors.get(0); + Tensor valuesTensor = tensors.get(1); Operand dynamicOp = MetricsHelper.broadcastWeights(tf, weightsPlaceholder, valuesPlaceholder); - List result = + try (Session.Result result = testSession .getGraphSession() .runner() .feed(weightsPlaceholder, weightsTensor) .feed(valuesPlaceholder, valuesTensor) .fetch(dynamicOp) - .run(); - - if (expected != null) { - if (type.equals(TInt32.class)) { - TInt32 intT = (TInt32) result.get(0); - AtomicInteger i = new AtomicInteger(); - intT.scalars() - .forEachIndexed( - (idx, f) -> assertEquals(expected[i.getAndIncrement()].intValue(), f.getInt())); - } else if (type.equals(TInt64.class)) { - TInt64 floatT = (TInt64) result.get(0); - AtomicInteger i = new AtomicInteger(); - floatT - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(expected[i.getAndIncrement()].longValue(), f.getLong())); - } else if (type.equals(TFloat32.class)) { - TFloat32 floatT = (TFloat32) result.get(0); - AtomicInteger i = new AtomicInteger(); - floatT - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals( - expected[i.getAndIncrement()].floatValue(), f.getFloat(), 1e-5F)); - } else if (type.equals(TFloat64.class)) { - TFloat64 doubleT = (TFloat64) result.get(0); - AtomicInteger i = new AtomicInteger(); - doubleT - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals( - expected[i.getAndIncrement()].doubleValue(), f.getDouble(), 1e-5F)); + .run()) { + + if (expected != null) { + if (type.equals(TInt32.class)) { + TInt32 intT = (TInt32) result.get(0); + AtomicInteger i = new AtomicInteger(); + intT.scalars() + .forEachIndexed( + (idx, f) -> assertEquals(expected[i.getAndIncrement()].intValue(), f.getInt())); + } else if (type.equals(TInt64.class)) { + TInt64 floatT = (TInt64) result.get(0); + AtomicInteger i = new AtomicInteger(); + floatT + .scalars() + .forEachIndexed( + (idx, f) -> assertEquals(expected[i.getAndIncrement()].longValue(), f.getLong())); + } else if (type.equals(TFloat32.class)) { + TFloat32 floatT = (TFloat32) result.get(0); + AtomicInteger i = new AtomicInteger(); + floatT + .scalars() + .forEachIndexed( + (idx, f) -> + assertEquals( + expected[i.getAndIncrement()].floatValue(), f.getFloat(), 1e-5F)); + } else if (type.equals(TFloat64.class)) { + TFloat64 doubleT = (TFloat64) result.get(0); + AtomicInteger i = new AtomicInteger(); + doubleT + .scalars() + .forEachIndexed( + (idx, f) -> + assertEquals( + expected[i.getAndIncrement()].doubleValue(), f.getDouble(), 1e-5F)); + } } } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index 909fd53ca27..c90cd03763e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -4,6 +4,8 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; + import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; @@ -189,13 +191,17 @@ public void testDeterminism() { g.importGraphDef(def); s.initialize(); - initialized.add( - s.runner() + Session.Result initializationRes = s.runner() .fetch(fcWeightName) .fetch(fcBiasName) .fetch(outputWeightName) .fetch(outputBiasName) - .run()); + .run(); + List initializedRun = new ArrayList<>(); + for (Map.Entry e : initializationRes) { + initializedRun.add(e.getValue()); + } + initialized.add(initializedRun); TFloat32 lossVal = (TFloat32) @@ -209,13 +215,17 @@ public void testDeterminism() { initialLoss[i] = lossVal.getFloat(); lossVal.close(); - trained.add( - s.runner() + Session.Result trainedRes = s.runner() .fetch(fcWeightName) .fetch(fcBiasName) .fetch(outputWeightName) .fetch(outputBiasName) - .run()); + .run(); + List trainedRun = new ArrayList<>(); + for (Map.Entry e : trainedRes) { + trainedRun.add(e.getValue()); + } + trained.add(trainedRun); lossVal = (TFloat32) From 1035b1dc441e0a96ab84453276caf8dd94af2816 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 30 Jan 2022 16:23:14 -0500 Subject: [PATCH 3/9] Making Session.Result final. --- .../src/main/java/org/tensorflow/Session.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index dbbf40622db..d5637929159 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -711,7 +711,7 @@ public void restore(String prefix) { * reference to a value after this object has been closed it will throw an {@link * IllegalStateException} upon access. */ - public static class Result implements AutoCloseable, Iterable> { + public static final class Result implements AutoCloseable, Iterable> { private static final Logger logger = Logger.getLogger(Result.class.getName()); From 5e92bc13060b03687c4fe4d201ea71350390b6d4 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 30 Jan 2022 16:25:21 -0500 Subject: [PATCH 4/9] Rearranging the members in Session.Result. --- .../src/main/java/org/tensorflow/Session.java | 73 +++++++++---------- 1 file changed, 36 insertions(+), 37 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index d5637929159..3feb7770ec2 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -712,43 +712,6 @@ public void restore(String prefix) { * IllegalStateException} upon access. */ public static final class Result implements AutoCloseable, Iterable> { - - private static final Logger logger = Logger.getLogger(Result.class.getName()); - - private final Map map; - - private final List list; - - private final RunMetadata metadata; - - private boolean closed; - - /** - * Creates a Result from the names and values produced by {@link Session.Runner#run()}. - * - * @param names The output names. - * @param values The output values. - * @param metadata The run metadata, may be null. - */ - Result(List names, List values, RunMetadata metadata) { - this.map = new LinkedHashMap<>(); - this.list = new ArrayList<>(values); - - if (names.size() != values.size()) { - throw new IllegalArgumentException( - "Expected same number of names and values, found names.length = " - + names.size() - + ", values.length = " - + values.size()); - } - - for (int i = 0; i < names.size(); i++) { - this.map.put(names.get(i), values.get(i)); - } - this.metadata = metadata; - this.closed = false; - } - @Override public void close() { if (!closed) { @@ -827,6 +790,42 @@ public Optional get(String key) { public Optional getMetadata() { return Optional.ofNullable(metadata); } + + /** + * Creates a Result from the names and values produced by {@link Session.Runner#run()}. + * + * @param names The output names. + * @param values The output values. + * @param metadata The run metadata, may be null. + */ + Result(List names, List values, RunMetadata metadata) { + this.map = new LinkedHashMap<>(); + this.list = new ArrayList<>(values); + + if (names.size() != values.size()) { + throw new IllegalArgumentException( + "Expected same number of names and values, found names.length = " + + names.size() + + ", values.length = " + + values.size()); + } + + for (int i = 0; i < names.size(); i++) { + this.map.put(names.get(i), values.get(i)); + } + this.metadata = metadata; + this.closed = false; + } + + private final Map map; + + private final List list; + + private final RunMetadata metadata; + + private boolean closed; + + private static final Logger logger = Logger.getLogger(Result.class.getName()); } Graph graph() { From 5f89ee184e9bb179d5b2268258ee96706d3279bb Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 30 Jan 2022 16:59:19 -0500 Subject: [PATCH 5/9] Fix the tests. --- .../org/tensorflow/AutoCloseableList.java | 27 ------------------- .../org/tensorflow/ConcreteFunctionTest.java | 6 ++--- .../org/tensorflow/CustomGradientTest.java | 3 +-- .../java/org/tensorflow/DeviceSpecTest.java | 15 ++++------- .../test/java/org/tensorflow/GraphTest.java | 18 +++++-------- .../op/core/BooleanMaskUpdateTest.java | 6 ++--- .../org/tensorflow/op/core/ConstantTest.java | 17 ++++-------- .../org/tensorflow/op/core/GradientsTest.java | 15 ++++------- .../org/tensorflow/op/core/ZerosTest.java | 2 +- 9 files changed, 28 insertions(+), 81 deletions(-) delete mode 100644 tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java deleted file mode 100644 index 330a40bae6b..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java +++ /dev/null @@ -1,27 +0,0 @@ -package org.tensorflow; - -import java.util.ArrayList; -import java.util.Collection; - -public final class AutoCloseableList extends ArrayList - implements AutoCloseable { - - public AutoCloseableList(Collection c) { - super(c); - } - - @Override - public void close() { - Exception toThrow = null; - for (AutoCloseable c : this) { - try { - c.close(); - } catch (Exception e) { - toThrow = e; - } - } - if (toThrow != null) { - throw new RuntimeException(toThrow); - } - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java index 250ff9cc383..75503977f68 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java @@ -205,15 +205,13 @@ public void testGradientsGraph() { try (TFloat32 c1 = TFloat32.scalarOf(3.0f); TFloat32 c2 = TFloat32.scalarOf(2.0f); - AutoCloseableList outputs = - new AutoCloseableList<>( - s.runner() + Session.Result outputs = s.runner() .feed(x1, c1) .feed(x2, c2) .fetch(grads0[0]) .fetch(grads1[0]) .fetch(grads1[1]) - .run())) { + .run()) { assertEquals(3, outputs.size()); assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java index 62626c35641..abfd6bffc8d 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java @@ -66,8 +66,7 @@ public void testCustomGradient() { assertEquals(DataType.DT_FLOAT, grads0[0].dataType()); try (TFloat32 c1 = TFloat32.vectorOf(3.0f, 2.0f, 1.0f, 0.0f); - AutoCloseableList outputs = - new AutoCloseableList<>(s.runner().feed(x, c1).fetch(grads0[0]).run())) { + Session.Result outputs = s.runner().feed(x, c1).fetch(grads0[0]).run()) { assertEquals(1, outputs.size()); assertEquals(0.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java index e4340da3275..88e77a4022a 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java @@ -53,8 +53,7 @@ public void withDeviceMethod() { .abs(aOps) .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(absOps).run())) { + try (Session.Result t = session.runner().fetch(absOps).run()) { assertEquals(1, ((TInt32)t.get(0)).getInt()); } } @@ -85,8 +84,7 @@ public void withEmptyDeviceSpec() { .abs(aOps) .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(absOps).run())) { + try (Session.Result t = session.runner().fetch(absOps).run()) { assertEquals(1, ((TInt32)t.get(0)).getInt()); } } @@ -131,8 +129,7 @@ public void withTwoScopes() { .mul(absOps, bOps) .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(mulOps).run())) { + try (Session.Result t = session.runner().fetch(mulOps).run()) { assertEquals(10, ((TInt32)t.get(0)).getInt()); } } @@ -179,8 +176,7 @@ public void withIncorrectDeviceSpec() { .mul(absOps, bOps) .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(mulOps).run())) { + try (Session.Result t = session.runner().fetch(mulOps).run()) { fail(); } catch (TFInvalidArgumentException e) { // ok @@ -212,8 +208,7 @@ public void withDeviceSpecInScope() { .abs(aOps) .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(absOps).run())) { + try (Session.Result t = session.runner().fetch(absOps).run()) { assertEquals(1, ((TInt32)t.get(0)).getInt()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java index 154d3903dcd..11f66cb8d8d 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java @@ -84,15 +84,13 @@ public void graphDefRoundTripWithInit() { Operand variable2 = init.withName("var2").variable(init.constant(4)); - try (Session s = new Session(g, true)) { - List results = s.runner().fetch("result").fetch("var2").run(); + try (Session s = new Session(g, true); + Session.Result results = s.runner().fetch("result").fetch("var2").run()) { TInt32 result = (TInt32) results.get(0); assertEquals(6, result.getInt()); TInt32 var2Result = (TInt32) results.get(1); assertEquals(4, var2Result.getInt()); - - results.forEach(Tensor::close); } } } @@ -266,15 +264,13 @@ public void addGradientsToGraph() { try (TFloat32 c1 = TFloat32.scalarOf(3.0f); TFloat32 c2 = TFloat32.scalarOf(2.0f); - AutoCloseableList outputs = - new AutoCloseableList<>( - s.runner() + Session.Result outputs = s.runner() .feed(x1, c1) .feed(x2, c2) .fetch(grads0[0]) .fetch(grads1[0]) .fetch(grads1[1]) - .run())) { + .run()) { assertEquals(3, outputs.size()); assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); assertEquals(6.0f, ((TFloat32) outputs.get(1)).getFloat(), 0.0f); @@ -418,14 +414,12 @@ public void buildWhileLoopMultipleInputs() { try (TInt32 c1 = TInt32.scalarOf(2); TInt32 c2 = TInt32.scalarOf(5); - AutoCloseableList outputs = - new AutoCloseableList<>( - s.runner() + Session.Result outputs = s.runner() .feed(input1, c1) .feed(input2, c2) .fetch(loopOutputs[0]) .fetch(loopOutputs[1]) - .run())) { + .run()) { assertEquals(2, outputs.size()); assertEquals(16, ((TInt32) outputs.get(0)).getInt()); // ((2^2)^2) assertEquals(625, ((TInt32) outputs.get(1)).getInt()); // ((5^2)^2) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java index 16c14f7a9a3..736bcbd0c63 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java @@ -50,7 +50,7 @@ public void testBooleanMaskUpdateSlice() { Operand bcastOutput = BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1)); - List results = sess.runner().fetch(output).fetch(bcastOutput).run(); + Session.Result results = sess.runner().fetch(output).fetch(bcastOutput).run(); try (TInt32 result = (TInt32) results.get(0); TInt32 bcastResult = (TInt32) results.get(1)) { @@ -89,7 +89,7 @@ public void testBooleanMaskUpdateSliceWithBroadcast() { Operand bcastOutput = BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1)); - List results = sess.runner().fetch(output).fetch(bcastOutput).run(); + Session.Result results = sess.runner().fetch(output).fetch(bcastOutput).run(); try (TInt32 result = (TInt32) results.get(0); TInt32 bcastResult = (TInt32) results.get(1)) { @@ -131,7 +131,7 @@ public void testBooleanMaskUpdateAxis() { BooleanMaskUpdate.create( scope, input, mask, Constant.scalarOf(scope, -1), BooleanMaskUpdate.axis(2)); - List results = sess.runner().fetch(output).fetch(bcastOutput).run(); + Session.Result results = sess.runner().fetch(output).fetch(bcastOutput).run(); try (TInt32 result = (TInt32) results.get(0); TInt32 bcastResult = (TInt32) results.get(1)) { diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java index 5c413b3abeb..33391747ae0 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -19,12 +19,10 @@ import java.io.IOException; import org.junit.jupiter.api.Test; -import org.tensorflow.AutoCloseableList; import org.tensorflow.EagerSession; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; -import org.tensorflow.Tensor; import org.tensorflow.ndarray.DoubleNdArray; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.IntNdArray; @@ -66,8 +64,7 @@ public void createInts() { Scope scope = new OpScope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { + try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } @@ -85,8 +82,7 @@ public void createFloats() { Scope scope = new OpScope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { + try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } @@ -104,8 +100,7 @@ public void createDoubles() { Scope scope = new OpScope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { + try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } @@ -123,8 +118,7 @@ public void createLongs() { Scope scope = new OpScope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { + try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } @@ -142,8 +136,7 @@ public void createStrings() throws IOException { Scope scope = new OpScope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { + try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java index 80150b64bb6..e1ecaa7ea5f 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java @@ -21,11 +21,9 @@ import java.util.Arrays; import org.junit.jupiter.api.Test; -import org.tensorflow.AutoCloseableList; import org.tensorflow.Graph; import org.tensorflow.Output; import org.tensorflow.Session; -import org.tensorflow.Tensor; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; @@ -48,9 +46,8 @@ public void createGradients() { assertEquals(2, grads.dy().size()); try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = - new AutoCloseableList<>( - sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) { + Session.Result outputs = sess.runner() + .feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run()) { assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); assertEquals(18.0f, ((TFloat32)outputs.get(1)).getFloat(), 0.0f); @@ -75,8 +72,7 @@ public void createGradientsWithSum() { assertEquals(1, grads.dy().size()); try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = - new AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) { + Session.Result outputs = sess.runner().feed(x, c).fetch(grads.dy(0)).run()) { assertEquals(114.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); } @@ -101,9 +97,8 @@ public void createGradientsWithInitialValues() { assertEquals(1, grads1.dy().size()); try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = - new AutoCloseableList<>( - sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) { + Session.Result outputs = + sess.runner().feed(x, c).fetch(grads1.dy(0)).run()) { assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java index b4d36702c93..30cffc8b51d 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java @@ -134,7 +134,7 @@ public void operationsComposingZerosAreCorrectlyNamed() { long[] shape = {2, 2}; Zeros zeros = Zeros.create(scope.withSubScope("test"), Constant.vectorOf(scope, shape), TFloat32.class); - List results = + Session.Result results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(); } } From b122dd0055e0dea1fdbe873c19d8521b6a93c6da Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sat, 5 Feb 2022 15:37:21 -0500 Subject: [PATCH 6/9] Change the signature of TensorFunction.call to return Result. Migrate Session.Result to be a top level class. --- .../java/org/tensorflow/ConcreteFunction.java | 8 +- .../src/main/java/org/tensorflow/Result.java | 186 ++++++++++++++++++ .../src/main/java/org/tensorflow/Session.java | 134 +------------ .../java/org/tensorflow/SessionFunction.java | 16 +- .../java/org/tensorflow/TensorFunction.java | 7 +- .../org/tensorflow/ConcreteFunctionTest.java | 8 +- .../org/tensorflow/CustomGradientTest.java | 2 +- .../java/org/tensorflow/DeviceSpecTest.java | 10 +- .../test/java/org/tensorflow/GraphTest.java | 7 +- .../org/tensorflow/SavedModelBundleTest.java | 10 +- .../test/java/org/tensorflow/SessionTest.java | 12 +- .../op/core/BooleanMaskUpdateTest.java | 9 +- .../org/tensorflow/op/core/ConstantTest.java | 11 +- .../org/tensorflow/op/core/GradientsTest.java | 7 +- .../org/tensorflow/op/core/ZerosTest.java | 4 +- .../framework/data/DatasetIteratorTest.java | 3 +- .../framework/data/MapDatasetTest.java | 3 +- .../metrics/impl/AssertBroadcastableTest.java | 6 +- .../metrics/impl/BroadcastWeightsTest.java | 7 +- .../optimizers/GradientDescentTest.java | 5 +- 20 files changed, 249 insertions(+), 206 deletions(-) create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Result.java diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 4d07b678811..70dec2d9533 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -1,4 +1,4 @@ -/* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020-2022 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -295,7 +295,7 @@ public Operand call(Scope scope, Operand argument) { } @Override - public Map call(Map arguments) { + public Result call(Map arguments) { // FIXME need to manage input/output operand lifetimes Ops tf = Ops.create(); Map> inputs = new LinkedHashMap<>(arguments.size()); @@ -305,11 +305,11 @@ public Map call(Map arguments) { inputs.put(inputName, tf.constantOf((TType) argument)); } Map> outputs = tf.call(this, inputs); - Map tensorOutputs = new LinkedHashMap<>(outputs.size()); + LinkedHashMap tensorOutputs = new LinkedHashMap<>(outputs.size()); for (String outputName : outputs.keySet()) { tensorOutputs.put(outputName, outputs.get(outputName).asTensor()); } - return tensorOutputs; + return new Result(tensorOutputs); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Result.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Result.java new file mode 100644 index 00000000000..99703f517b1 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Result.java @@ -0,0 +1,186 @@ +/* +Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. +Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow; + +import org.tensorflow.proto.framework.RunMetadata; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.logging.Logger; + +/** + * An {@link AutoCloseable} wrapper around a {@link Map} containing {@link Tensor}s. + * + *

When this is closed it closes all the {@link Tensor}s inside it. If you maintain a + * reference to a value after this object has been closed it will throw an {@link + * IllegalStateException} upon access. + */ +public final class Result implements AutoCloseable, Iterable> { + @Override + public void close() { + if (!closed) { + closed = true; + for (Tensor t : map.values()) { + t.close(); + } + } else { + logger.warning("Closing an already closed Result"); + } + } + + @Override + public Iterator> iterator() { + if (!closed) { + return map.entrySet().iterator(); + } else { + throw new IllegalStateException("Result is closed"); + } + } + + /** + * Returns the number of outputs in this Result. + * + * @return The number of outputs. + */ + public int size() { + return map.size(); + } + + /** + * Gets the set containing all the tensor names. + * @return The tensor names set. + */ + public Set keySet() { + return Collections.unmodifiableSet(map.keySet()); + } + + /** + * Does this result object have a tensor for the supplied key? + * @param key The key to check. + * @return True if this result object has a tensor for this key. + */ + public boolean containsKey(String key) { + return map.containsKey(key); + } + + /** + * Gets the value from the container at the specified index. + * + *

Throws {@link IllegalStateException} if the container has been closed, and {@link + * IndexOutOfBoundsException} if the index is invalid. + * + * @param index The index to lookup. + * @return The value at the index. + */ + public Tensor get(int index) { + if (!closed) { + return list.get(index); + } else { + throw new IllegalStateException("Result is closed"); + } + } + + /** + * Gets the value from the container assuming it's not been closed. + * + *

Throws {@link IllegalStateException} if the container has been closed. + * + * @param key The key to lookup. + * @return Optional.of the value if it exists. + */ + public Optional get(String key) { + if (!closed) { + Tensor value = map.get(key); + if (value != null) { + return Optional.of(value); + } else { + return Optional.empty(); + } + } else { + throw new IllegalStateException("Result is closed"); + } + } + + /** + * Metadata about the run. + * + *

A RunMetadata + * protocol buffer. + */ + public Optional getMetadata() { + return Optional.ofNullable(metadata); + } + + /** + * Creates a Result from the names and values produced by {@link Session.Runner#run()}. + * + * @param names The output names. + * @param values The output values. + * @param metadata The run metadata, may be null. + */ + Result(List names, List values, RunMetadata metadata) { + this.map = new LinkedHashMap<>(); + this.list = new ArrayList<>(values); + + if (names.size() != values.size()) { + throw new IllegalArgumentException( + "Expected same number of names and values, found names.length = " + + names.size() + + ", values.length = " + + values.size()); + } + + for (int i = 0; i < names.size(); i++) { + this.map.put(names.get(i), values.get(i)); + } + this.metadata = metadata; + this.closed = false; + } + + /** + * Creates a Result from the names and values. + * + * @param outputs The run outputs. + */ + Result(LinkedHashMap outputs) { + this.map = outputs; + this.list = new ArrayList<>(outputs.size()); + for (Map.Entry e : outputs.entrySet()) { + list.add(e.getValue()); + } + this.metadata = null; + this.closed = false; + } + + private final Map map; + + private final List list; + + private final RunMetadata metadata; + + private boolean closed; + + private static final Logger logger = Logger.getLogger(Result.class.getName()); +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 3feb7770ec2..c4ede27df8b 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -1,4 +1,4 @@ -/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2022 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,17 +22,12 @@ import com.google.protobuf.InvalidProtocolBufferException; -import java.sql.Array; import java.util.ArrayList; import java.util.Collections; -import java.util.Iterator; -import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; -import java.util.logging.Logger; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; @@ -654,8 +649,9 @@ public SessionFunction function(Signature signature) { * * @param signature the signature of the function * @param arguments the arguments to call with. + * @return The results of the function call. */ - public Map run(Signature signature, Map arguments) { + public Result run(Signature signature, Map arguments) { return function(signature).call(arguments); } @@ -704,130 +700,6 @@ public void restore(String prefix) { setInitialized(); } - /** - * An {@link AutoCloseable} wrapper around a {@link Map} containing {@link Tensor}s. - * - *

When this is closed it closes all the {@link Tensor}s inside it. If you maintain a - * reference to a value after this object has been closed it will throw an {@link - * IllegalStateException} upon access. - */ - public static final class Result implements AutoCloseable, Iterable> { - @Override - public void close() { - if (!closed) { - closed = true; - for (Tensor t : map.values()) { - t.close(); - } - } else { - logger.warning("Closing an already closed Result"); - } - } - - @Override - public Iterator> iterator() { - if (!closed) { - return map.entrySet().iterator(); - } else { - throw new IllegalStateException("Result is closed"); - } - } - - /** - * Gets the value from the container at the specified index. - * - *

Throws {@link IllegalStateException} if the container has been closed, and {@link - * IndexOutOfBoundsException} if the index is invalid. - * - * @param index The index to lookup. - * @return The value at the index. - */ - public Tensor get(int index) { - if (!closed) { - return list.get(index); - } else { - throw new IllegalStateException("Result is closed"); - } - } - - /** - * Returns the number of outputs in this Result. - * - * @return The number of outputs. - */ - public int size() { - return map.size(); - } - - /** - * Gets the value from the container assuming it's not been closed. - * - *

Throws {@link IllegalStateException} if the container has been closed. - * - * @param key The key to lookup. - * @return Optional.of the value if it exists. - */ - public Optional get(String key) { - if (!closed) { - Tensor value = map.get(key); - if (value != null) { - return Optional.of(value); - } else { - return Optional.empty(); - } - } else { - throw new IllegalStateException("Result is closed"); - } - } - - /** - * Metadata about the run. - * - *

A RunMetadata - * protocol buffer. - */ - public Optional getMetadata() { - return Optional.ofNullable(metadata); - } - - /** - * Creates a Result from the names and values produced by {@link Session.Runner#run()}. - * - * @param names The output names. - * @param values The output values. - * @param metadata The run metadata, may be null. - */ - Result(List names, List values, RunMetadata metadata) { - this.map = new LinkedHashMap<>(); - this.list = new ArrayList<>(values); - - if (names.size() != values.size()) { - throw new IllegalArgumentException( - "Expected same number of names and values, found names.length = " - + names.size() - + ", values.length = " - + values.size()); - } - - for (int i = 0; i < names.size(); i++) { - this.map.put(names.get(i), values.get(i)); - } - this.metadata = metadata; - this.closed = false; - } - - private final Map map; - - private final List list; - - private final RunMetadata metadata; - - private boolean closed; - - private static final Logger logger = Logger.getLogger(Result.class.getName()); - } - Graph graph() { return graph; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java index 3c6a11fd954..8338eac4710 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021-2022 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,7 +17,6 @@ import java.io.IOException; import java.util.LinkedHashMap; -import java.util.List; import java.util.Map; /** @@ -89,7 +88,7 @@ public SessionFunction withNewSession(Session session) { } @Override - public Map call(Map arguments) { + public Result call(Map arguments) { Session.Runner runner = session.runner(); signature .getInputs() @@ -113,15 +112,6 @@ public Map call(Map arguments) { signature.getOutputs().values().forEach(x -> runner.fetch(x.name)); - Session.Result results = runner.run(); - - Map outputs = new LinkedHashMap<>(results.size()); - int i = 0; - for (String outputName : signature.outputNames()) { - outputs.put(outputName, results.get(i)); - i++; - } - - return outputs; + return runner.run(); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFunction.java index 0304d786494..e58437e6f8f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFunction.java @@ -28,7 +28,7 @@ public interface TensorFunction { /** * Invokes a function using the default eager session. * - *

Caller is responsible for closing all Tensors. + *

Caller is responsible for close the result object. * * @param arguments list of tensors to pass in input to the function, mapped by their signature * name @@ -37,7 +37,7 @@ public interface TensorFunction { * @throws IllegalArgumentException if the passed arguments don't match up to the function's * parameters. */ - Map call(Map arguments); + Result call(Map arguments); /** * Invokes a function with a single input and output using the default eager session. @@ -76,12 +76,11 @@ default Tensor call(Tensor tensor) { } String inputName = signature().inputNames().iterator().next(); - String outputName = signature().outputNames().iterator().next(); Map inputMap = new LinkedHashMap<>(); inputMap.put(inputName, tensor); - return call(inputMap).get(outputName); + return call(inputMap).get(0); } static Operand validateDescription( diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java index 75503977f68..6a900b7b6cf 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java @@ -162,9 +162,9 @@ public void testFunctionWithTwoOutputs() { Map inputs = new HashMap<>(); inputs.put("x", TInt32.scalarOf(2)); - Map outputs = cf.call(inputs); - assertEquals(4, ((TInt32) outputs.get("dbl")).getInt()); - assertEquals(6, ((TInt32) outputs.get("trpl")).getInt()); + Result outputs = cf.call(inputs); + assertEquals(4, ((TInt32) outputs.get("dbl").get()).getInt()); + assertEquals(6, ((TInt32) outputs.get("trpl").get()).getInt()); } private static Signature square(Ops tf) { @@ -205,7 +205,7 @@ public void testGradientsGraph() { try (TFloat32 c1 = TFloat32.scalarOf(3.0f); TFloat32 c2 = TFloat32.scalarOf(2.0f); - Session.Result outputs = s.runner() + Result outputs = s.runner() .feed(x1, c1) .feed(x2, c2) .fetch(grads0[0]) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java index abfd6bffc8d..0ad94ad2130 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java @@ -66,7 +66,7 @@ public void testCustomGradient() { assertEquals(DataType.DT_FLOAT, grads0[0].dataType()); try (TFloat32 c1 = TFloat32.vectorOf(3.0f, 2.0f, 1.0f, 0.0f); - Session.Result outputs = s.runner().feed(x, c1).fetch(grads0[0]).run()) { + Result outputs = s.runner().feed(x, c1).fetch(grads0[0]).run()) { assertEquals(1, outputs.size()); assertEquals(0.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java index 88e77a4022a..9d2316603d1 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java @@ -53,7 +53,7 @@ public void withDeviceMethod() { .abs(aOps) .asOutput(); - try (Session.Result t = session.runner().fetch(absOps).run()) { + try (Result t = session.runner().fetch(absOps).run()) { assertEquals(1, ((TInt32)t.get(0)).getInt()); } } @@ -84,7 +84,7 @@ public void withEmptyDeviceSpec() { .abs(aOps) .asOutput(); - try (Session.Result t = session.runner().fetch(absOps).run()) { + try (Result t = session.runner().fetch(absOps).run()) { assertEquals(1, ((TInt32)t.get(0)).getInt()); } } @@ -129,7 +129,7 @@ public void withTwoScopes() { .mul(absOps, bOps) .asOutput(); - try (Session.Result t = session.runner().fetch(mulOps).run()) { + try (Result t = session.runner().fetch(mulOps).run()) { assertEquals(10, ((TInt32)t.get(0)).getInt()); } } @@ -176,7 +176,7 @@ public void withIncorrectDeviceSpec() { .mul(absOps, bOps) .asOutput(); - try (Session.Result t = session.runner().fetch(mulOps).run()) { + try (Result t = session.runner().fetch(mulOps).run()) { fail(); } catch (TFInvalidArgumentException e) { // ok @@ -208,7 +208,7 @@ public void withDeviceSpecInScope() { .abs(aOps) .asOutput(); - try (Session.Result t = session.runner().fetch(absOps).run()) { + try (Result t = session.runner().fetch(absOps).run()) { assertEquals(1, ((TInt32)t.get(0)).getInt()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java index 11f66cb8d8d..c8055ddae14 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java @@ -25,7 +25,6 @@ import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashSet; -import java.util.List; import java.util.Set; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TFInvalidArgumentException; @@ -85,7 +84,7 @@ public void graphDefRoundTripWithInit() { Operand variable2 = init.withName("var2").variable(init.constant(4)); try (Session s = new Session(g, true); - Session.Result results = s.runner().fetch("result").fetch("var2").run()) { + Result results = s.runner().fetch("result").fetch("var2").run()) { TInt32 result = (TInt32) results.get(0); assertEquals(6, result.getInt()); @@ -264,7 +263,7 @@ public void addGradientsToGraph() { try (TFloat32 c1 = TFloat32.scalarOf(3.0f); TFloat32 c2 = TFloat32.scalarOf(2.0f); - Session.Result outputs = s.runner() + Result outputs = s.runner() .feed(x1, c1) .feed(x2, c2) .fetch(grads0[0]) @@ -414,7 +413,7 @@ public void buildWhileLoopMultipleInputs() { try (TInt32 c1 = TInt32.scalarOf(2); TInt32 c2 = TInt32.scalarOf(5); - Session.Result outputs = s.runner() + Result outputs = s.runner() .feed(input1, c1) .feed(input2, c2) .fetch(loopOutputs[0]) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index be6f952fb6a..328118a9568 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -293,9 +293,9 @@ public void pythonTfFunction() { System.out.println(add.signature()); args.put("a", a); args.put("b", b); - Map result = add.call(args); + Result result = add.call(args); assertEquals(result.size(), 1); - try (TFloat32 c = (TFloat32) result.values().iterator().next()) { + try (TFloat32 c = (TFloat32) result.get(0)) { assertEquals(25.5f, c.getFloat()); } } @@ -307,11 +307,7 @@ public void pythonTfFunction() { args.put("dummy", dummy); // TF functions always require an input, so we supply a dummy one here // This test actually checks that resource variables can be loaded correctly. - try (TFloat32 v = - (TFloat32) - getVariable - .call(args) - .get(getVariable.signature().outputNames().iterator().next())) { + try (TFloat32 v = (TFloat32) getVariable.call(args).get(0)) { assertEquals(2f, v.getFloat()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index ddff36de91b..aefb048db7e 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -72,7 +72,7 @@ public void runUsingOperationNames() { Ops tf = Ops.create(g); transpose_A_times_X(tf, new int[][] {{2}, {3}}); try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); - Session.Result outputs = s.runner().feed("X", x).fetch("Y").run()) { + Result outputs = s.runner().feed("X", x).fetch("Y").run()) { assertEquals(1, outputs.size()); assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0)); } @@ -88,7 +88,7 @@ public void runUsingOperationHandles() { Output feed = g.operation("X").output(0); Output fetch = g.operation("Y").output(0); try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); - Session.Result outputs = s.runner().feed(feed, x).fetch(fetch).run()) { + Result outputs = s.runner().feed(feed, x).fetch(fetch).run()) { assertEquals(1, outputs.size()); assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0)); } @@ -125,7 +125,7 @@ public void runWithMetadata() { Ops tf = Ops.create(g); transpose_A_times_X(tf, new int[][] {{2}, {3}}); try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}))) { - Session.Result result = + Result result = s.runner() .feed("X", x) .fetch("Y") @@ -150,7 +150,7 @@ public void runMultipleOutputs() { Ops tf = Ops.create(g); tf.withName("c1").constant(2718); tf.withName("c2").constant(31415); - Session.Result outputs = s.runner().fetch("c2").fetch("c1").run(); + Result outputs = s.runner().fetch("c2").fetch("c1").run(); assertEquals(2, outputs.size()); assertEquals(31415, ((TInt32) outputs.get(0)).getInt()); assertEquals(2718, ((TInt32) outputs.get(1)).getInt()); @@ -227,8 +227,8 @@ public void saveAndRestore() throws IOException { restoredGraph.importGraphDef(graphDef); try (Session restoredSession = new Session(restoredGraph)) { restoredSession.restore(testFolder.resolve("checkpoint").toString()); - try (Session.Result oldList = s.runner().fetch("x").fetch("y").run(); - Session.Result newList = restoredSession.runner().fetch("x").fetch("y").run()) { + try (Result oldList = s.runner().fetch("x").fetch("y").run(); + Result newList = restoredSession.runner().fetch("x").fetch("y").run()) { assertEquals(oldList.get(0), newList.get(0)); assertEquals(oldList.get(1), newList.get(1)); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java index 736bcbd0c63..4edbea33b0d 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java @@ -18,12 +18,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals; -import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Operand; +import org.tensorflow.Result; import org.tensorflow.Session; -import org.tensorflow.Tensor; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.OpScope; import org.tensorflow.op.Scope; @@ -50,7 +49,7 @@ public void testBooleanMaskUpdateSlice() { Operand bcastOutput = BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1)); - Session.Result results = sess.runner().fetch(output).fetch(bcastOutput).run(); + Result results = sess.runner().fetch(output).fetch(bcastOutput).run(); try (TInt32 result = (TInt32) results.get(0); TInt32 bcastResult = (TInt32) results.get(1)) { @@ -89,7 +88,7 @@ public void testBooleanMaskUpdateSliceWithBroadcast() { Operand bcastOutput = BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1)); - Session.Result results = sess.runner().fetch(output).fetch(bcastOutput).run(); + Result results = sess.runner().fetch(output).fetch(bcastOutput).run(); try (TInt32 result = (TInt32) results.get(0); TInt32 bcastResult = (TInt32) results.get(1)) { @@ -131,7 +130,7 @@ public void testBooleanMaskUpdateAxis() { BooleanMaskUpdate.create( scope, input, mask, Constant.scalarOf(scope, -1), BooleanMaskUpdate.axis(2)); - Session.Result results = sess.runner().fetch(output).fetch(bcastOutput).run(); + Result results = sess.runner().fetch(output).fetch(bcastOutput).run(); try (TInt32 result = (TInt32) results.get(0); TInt32 bcastResult = (TInt32) results.get(1)) { diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java index 33391747ae0..5194fccd707 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -22,6 +22,7 @@ import org.tensorflow.EagerSession; import org.tensorflow.Graph; import org.tensorflow.Operand; +import org.tensorflow.Result; import org.tensorflow.Session; import org.tensorflow.ndarray.DoubleNdArray; import org.tensorflow.ndarray.FloatNdArray; @@ -64,7 +65,7 @@ public void createInts() { Scope scope = new OpScope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) { + try (Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } @@ -82,7 +83,7 @@ public void createFloats() { Scope scope = new OpScope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) { + try (Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } @@ -100,7 +101,7 @@ public void createDoubles() { Scope scope = new OpScope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) { + try (Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } @@ -118,7 +119,7 @@ public void createLongs() { Scope scope = new OpScope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) { + try (Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } @@ -136,7 +137,7 @@ public void createStrings() throws IOException { Scope scope = new OpScope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) { + try (Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java index e1ecaa7ea5f..7ba4c4da26c 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Output; +import org.tensorflow.Result; import org.tensorflow.Session; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; @@ -46,7 +47,7 @@ public void createGradients() { assertEquals(2, grads.dy().size()); try (TFloat32 c = TFloat32.scalarOf(3.0f); - Session.Result outputs = sess.runner() + Result outputs = sess.runner() .feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run()) { assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); @@ -72,7 +73,7 @@ public void createGradientsWithSum() { assertEquals(1, grads.dy().size()); try (TFloat32 c = TFloat32.scalarOf(3.0f); - Session.Result outputs = sess.runner().feed(x, c).fetch(grads.dy(0)).run()) { + Result outputs = sess.runner().feed(x, c).fetch(grads.dy(0)).run()) { assertEquals(114.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); } @@ -97,7 +98,7 @@ public void createGradientsWithInitialValues() { assertEquals(1, grads1.dy().size()); try (TFloat32 c = TFloat32.scalarOf(3.0f); - Session.Result outputs = + Result outputs = sess.runner().feed(x, c).fetch(grads1.dy(0)).run()) { assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java index 30cffc8b51d..73b7e0a551c 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java @@ -19,9 +19,9 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; -import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.Graph; +import org.tensorflow.Result; import org.tensorflow.Session; import org.tensorflow.op.OpScope; import org.tensorflow.op.Scope; @@ -134,7 +134,7 @@ public void operationsComposingZerosAreCorrectlyNamed() { long[] shape = {2, 2}; Zeros zeros = Zeros.create(scope.withSubScope("test"), Constant.vectorOf(scope, shape), TFloat32.class); - Session.Result results = + Result results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java index 9a281305703..1bbeb1a3f0a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java @@ -22,6 +22,7 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Operand; +import org.tensorflow.Result; import org.tensorflow.Session; import org.tensorflow.exceptions.TFOutOfRangeException; import org.tensorflow.op.Ops; @@ -51,7 +52,7 @@ public void testGraphIteration() { int batches = 0; while (true) { - try (Session.Result outputs = session.runner().fetch(x).fetch(y).run()) { + try (Result outputs = session.runner().fetch(x).fetch(y).run()) { assertEquals(testMatrix1.get(batches), outputs.get(0)); assertEquals(testMatrix2.get(batches), outputs.get(1)); batches++; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java index 9cac8fe54c4..7cb3ed1cc18 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Operand; +import org.tensorflow.Result; import org.tensorflow.Session; import org.tensorflow.exceptions.TFOutOfRangeException; import org.tensorflow.ndarray.IntNdArray; @@ -76,7 +77,7 @@ public void testGraphIteration() { int batches = 0; while (true) { - try (Session.Result outputs = session.runner().fetch(X).fetch(y).run()) { + try (Result outputs = session.runner().fetch(X).fetch(y).run()) { assertEquals(mapped1.get(batches), outputs.get(0)); assertEquals(mapped2.get(batches), outputs.get(1)); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java index 18b8a6254ce..278c9dbf1c4 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java @@ -16,7 +16,7 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Operand; -import org.tensorflow.Session; +import org.tensorflow.Result; import org.tensorflow.Tensor; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Op; @@ -27,8 +27,6 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import java.util.List; - import static org.junit.jupiter.api.Assertions.assertThrows; public class AssertBroadcastableTest { @@ -70,7 +68,7 @@ private void testValid( Operand weightsPlaceholder = tf.placeholder(type); Operand valuesPlaceholder = tf.placeholder(type); - try (Session.Result tensors = + try (Result tensors = testSession.getGraphSession().runner().fetch(weights).fetch(values).run()) { Tensor weightsTensor = tensors.get(0); Tensor valuesTensor = tensors.get(1); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java index e72d3534ade..a8ff95a5e15 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java @@ -16,7 +16,7 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Operand; -import org.tensorflow.Session; +import org.tensorflow.Result; import org.tensorflow.Tensor; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; @@ -26,7 +26,6 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -79,7 +78,7 @@ private void testValid( Operand weightsPlaceholder = tf.placeholder(type); Operand valuesPlaceholder = tf.placeholder(type); - try (Session.Result tensors = + try (Result tensors = testSession.getGraphSession().runner().fetch(weights).fetch(values).run()) { Tensor weightsTensor = tensors.get(0); Tensor valuesTensor = tensors.get(1); @@ -87,7 +86,7 @@ private void testValid( Operand dynamicOp = MetricsHelper.broadcastWeights(tf, weightsPlaceholder, valuesPlaceholder); - try (Session.Result result = + try (Result result = testSession .getGraphSession() .runner() diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index c90cd03763e..da48aee9f78 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -13,6 +13,7 @@ import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.tensorflow.Graph; +import org.tensorflow.Result; import org.tensorflow.Session; import org.tensorflow.Tensor; import org.tensorflow.framework.initializers.Glorot; @@ -191,7 +192,7 @@ public void testDeterminism() { g.importGraphDef(def); s.initialize(); - Session.Result initializationRes = s.runner() + Result initializationRes = s.runner() .fetch(fcWeightName) .fetch(fcBiasName) .fetch(outputWeightName) @@ -215,7 +216,7 @@ public void testDeterminism() { initialLoss[i] = lossVal.getFloat(); lossVal.close(); - Session.Result trainedRes = s.runner() + Result trainedRes = s.runner() .fetch(fcWeightName) .fetch(fcBiasName) .fetch(outputWeightName) From ad9172c68b556b3aad206cde72cd12caf6c2fdc5 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sat, 5 Feb 2022 15:47:38 -0500 Subject: [PATCH 7/9] Missed a bit in the refactor. --- .../src/main/java/org/tensorflow/SavedModelBundle.java | 2 +- .../src/test/java/org/tensorflow/SavedModelBundleTest.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 4295dbb6c4a..35d81e7bc16 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -468,7 +468,7 @@ public List functions() { * @return list of output tensors, mapped by the signature name * @throws IllegalArgumentException if no function can be selected by default */ - public Map call(Map arguments) { + public Result call(Map arguments) { SessionFunction function = null; if (functions.size() == 1) { function = functions.values().iterator().next(); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 328118a9568..a5191182ffb 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -215,7 +215,7 @@ public void exportFunctionWithVariables() throws IOException { // Now call the same function directly from the model try (TFloat32 zTensor = (TFloat32) - savedModel.call(Collections.singletonMap("input", xTensor)).get("reducedSum")) { + savedModel.call(Collections.singletonMap("input", xTensor)).get("reducedSum").get()) { assertEquals(reducedSum, zTensor.getFloat(), EPSILON); } } From 022d103dd07f6cb59214c10dad8b5d457e9a2b7c Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Mon, 7 Feb 2022 11:44:53 -0500 Subject: [PATCH 8/9] Fixes for review comments, and to ensure that tensors are named as expected. --- .../java/org/tensorflow/ConcreteFunction.java | 2 +- .../src/main/java/org/tensorflow/Result.java | 32 +++++++++++++------ .../src/main/java/org/tensorflow/Session.java | 24 ++++++++++++-- .../java/org/tensorflow/SessionFunction.java | 12 ++++++- .../src/main/java/org/tensorflow/Tensor.java | 2 +- 5 files changed, 56 insertions(+), 16 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 70dec2d9533..c822678fda6 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -296,7 +296,7 @@ public Operand call(Scope scope, Operand argument) { @Override public Result call(Map arguments) { - // FIXME need to manage input/output operand lifetimes + // FIXME need to manage input operand lifetimes Ops tf = Ops.create(); Map> inputs = new LinkedHashMap<>(arguments.size()); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Result.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Result.java index 99703f517b1..825a52a56b5 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Result.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Result.java @@ -17,6 +17,7 @@ */ package org.tensorflow; +import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.proto.framework.RunMetadata; import java.util.ArrayList; @@ -27,6 +28,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.logging.Level; import java.util.logging.Logger; /** @@ -35,15 +37,27 @@ *

When this is closed it closes all the {@link Tensor}s inside it. If you maintain a * reference to a value after this object has been closed it will throw an {@link * IllegalStateException} upon access. + * + *

This class is not thread-safe with respect to the close operation. Multiple closers + * or one thread closing a tensor while another is reading may throw exceptions. + * + *

Note this class is used to manage the lifetimes of tensors produced by the + * TensorFlow runtime, from sessions and function calls. It is not used as an argument + * to {@code session.run} or function calls as users are in control of the creation + * of input tensors. */ public final class Result implements AutoCloseable, Iterable> { @Override public void close() { if (!closed) { - closed = true; - for (Tensor t : map.values()) { - t.close(); + for (Tensor t : list) { + try { + t.close(); + } catch (TensorFlowException e) { + logger.log(Level.WARNING, "Exception raised when closing tensor inside result.", e); + } } + closed = true; } else { logger.warning("Closing an already closed Result"); } @@ -111,12 +125,7 @@ public Tensor get(int index) { */ public Optional get(String key) { if (!closed) { - Tensor value = map.get(key); - if (value != null) { - return Optional.of(value); - } else { - return Optional.empty(); - } + return Optional.ofNullable(map.get(key)); } else { throw new IllegalStateException("Result is closed"); } @@ -153,7 +162,10 @@ public Optional getMetadata() { } for (int i = 0; i < names.size(); i++) { - this.map.put(names.get(i), values.get(i)); + Tensor old = this.map.put(names.get(i), values.get(i)); + if (old != null) { + throw new IllegalArgumentException("Name collision in the result set, two outputs are named '" + names.get(i) + "'"); + } } this.metadata = metadata; this.closed = false; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index c4ede27df8b..067a261938e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -308,7 +308,9 @@ public Runner feed(Operand operand, Tensor t) { * @throws IllegalArgumentException if no output exists with the provided name */ public Runner fetch(String operation) { - return fetch(graph.outputOrThrow(operation)); + Runner r = fetch(graph.outputOrThrow(operation),false); + outputNames.add(operation); + return r; } /** @@ -338,6 +340,20 @@ public Runner fetch(String operation, int index) { * @return this session runner */ public Runner fetch(Output output) { + return fetch(output, true); + } + + /** + * Makes {@link #run()} return the Tensor referred to by {@code output}. + * + *

If {@code output} is a resource variable, will fetch the value. + * + * @param output the node to fetch the tensor from + * @param recordName Records the output name. If false the output name must be recorded by the + * calling method as otherwise the result object will throw on construction. + * @return this session runner + */ + private Runner fetch(Output output, boolean recordName) { if (output.env() != graph) { throw new IllegalStateException( "Can't fetch output " @@ -380,6 +396,9 @@ public Runner fetch(Output output) { } else { outputs.add(output); } + if (recordName) { + outputNames.add(output.name()); + } return this; } @@ -523,7 +542,6 @@ private Result runHelper(boolean wantMetadata) { TF_Operation[] outputOpHandles = new TF_Operation[outputs.size()]; int[] outputOpIndices = new int[outputs.size()]; TF_Operation[] targetOpHandles = new TF_Operation[targets.size()]; - List outputNames = new ArrayList<>(); // It's okay to use Operation.getUnsafeNativeHandle() here since the safety depends on the // validity of the Graph and graphRef ensures that. @@ -541,7 +559,6 @@ private Result runHelper(boolean wantMetadata) { for (Output o : outputs) { outputOpHandles[idx] = (TF_Operation) o.getUnsafeNativeHandle(); outputOpIndices[idx] = o.index(); - outputNames.add(o.name()); idx++; } idx = 0; @@ -603,6 +620,7 @@ public void close() { private final ArrayList> inputs = new ArrayList<>(); private final ArrayList inputTensors = new ArrayList<>(); private final ArrayList> outputs = new ArrayList<>(); + private final ArrayList outputNames = new ArrayList<>(); private final ArrayList targets = new ArrayList<>(); private RunOptions runOptions = null; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java index 8338eac4710..bb8f58adaf1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java @@ -112,6 +112,16 @@ public Result call(Map arguments) { signature.getOutputs().values().forEach(x -> runner.fetch(x.name)); - return runner.run(); + Result results = runner.run(); + + // Unpack the result object and rebuild it with the expected names. + LinkedHashMap outputs = new LinkedHashMap<>(results.size()); + int i = 0; + for (String outputName : signature.outputNames()) { + outputs.put(outputName, results.get(i)); + i++; + } + + return new Result(outputs); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index 3b9deff9cd4..2ba3dc0a906 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -210,7 +210,7 @@ static T of(Class type, Shape shape, ByteDataBuffer rawData *

When this methods retuns {@code true}, the tensor could be cast to a {@link SparseTensor * SparseTensor} to access its indices, values and denseShape tensors. * - * @retrun true if this tensor is a sparse + * @return true if this tensor is a sparse */ default boolean isSparse() { return false; From 48ea5452aa49d0ff56dd8da2597399f801dc2a6f Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Mon, 7 Feb 2022 17:14:34 -0500 Subject: [PATCH 9/9] Applying spotless. --- .../src/main/java/org/tensorflow/Result.java | 293 +++++++++--------- .../src/main/java/org/tensorflow/Session.java | 8 +- .../java/org/tensorflow/SessionFunction.java | 26 +- .../java/org/tensorflow/TensorFunction.java | 22 +- .../org/tensorflow/ConcreteFunctionTest.java | 15 +- .../java/org/tensorflow/DeviceSpecTest.java | 115 ++++--- .../test/java/org/tensorflow/GraphTest.java | 30 +- .../org/tensorflow/SavedModelBundleTest.java | 5 +- .../test/java/org/tensorflow/SessionTest.java | 4 +- .../org/tensorflow/op/core/GradientsTest.java | 17 +- .../framework/data/MapDatasetTest.java | 6 +- .../metrics/impl/AssertBroadcastableTest.java | 4 +- .../metrics/impl/BroadcastWeightsTest.java | 40 +-- .../optimizers/GradientDescentTest.java | 7 +- 14 files changed, 294 insertions(+), 298 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Result.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Result.java index 825a52a56b5..a3560b068b1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Result.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Result.java @@ -17,9 +17,6 @@ */ package org.tensorflow; -import org.tensorflow.exceptions.TensorFlowException; -import org.tensorflow.proto.framework.RunMetadata; - import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; @@ -30,169 +27,173 @@ import java.util.Set; import java.util.logging.Level; import java.util.logging.Logger; +import org.tensorflow.exceptions.TensorFlowException; +import org.tensorflow.proto.framework.RunMetadata; /** * An {@link AutoCloseable} wrapper around a {@link Map} containing {@link Tensor}s. * - *

When this is closed it closes all the {@link Tensor}s inside it. If you maintain a - * reference to a value after this object has been closed it will throw an {@link - * IllegalStateException} upon access. + *

When this is closed it closes all the {@link Tensor}s inside it. If you maintain a reference + * to a value after this object has been closed it will throw an {@link IllegalStateException} upon + * access. * - *

This class is not thread-safe with respect to the close operation. Multiple closers - * or one thread closing a tensor while another is reading may throw exceptions. + *

This class is not thread-safe with respect to the close operation. Multiple closers or one + * thread closing a tensor while another is reading may throw exceptions. * - *

Note this class is used to manage the lifetimes of tensors produced by the - * TensorFlow runtime, from sessions and function calls. It is not used as an argument - * to {@code session.run} or function calls as users are in control of the creation - * of input tensors. + *

Note this class is used to manage the lifetimes of tensors produced by the TensorFlow runtime, + * from sessions and function calls. It is not used as an argument to {@code session.run} or + * function calls as users are in control of the creation of input tensors. */ public final class Result implements AutoCloseable, Iterable> { - @Override - public void close() { - if (!closed) { - for (Tensor t : list) { - try { - t.close(); - } catch (TensorFlowException e) { - logger.log(Level.WARNING, "Exception raised when closing tensor inside result.", e); - } - } - closed = true; - } else { - logger.warning("Closing an already closed Result"); + @Override + public void close() { + if (!closed) { + for (Tensor t : list) { + try { + t.close(); + } catch (TensorFlowException e) { + logger.log(Level.WARNING, "Exception raised when closing tensor inside result.", e); } + } + closed = true; + } else { + logger.warning("Closing an already closed Result"); } - - @Override - public Iterator> iterator() { - if (!closed) { - return map.entrySet().iterator(); - } else { - throw new IllegalStateException("Result is closed"); - } + } + + @Override + public Iterator> iterator() { + if (!closed) { + return map.entrySet().iterator(); + } else { + throw new IllegalStateException("Result is closed"); } - - /** - * Returns the number of outputs in this Result. - * - * @return The number of outputs. - */ - public int size() { - return map.size(); + } + + /** + * Returns the number of outputs in this Result. + * + * @return The number of outputs. + */ + public int size() { + return map.size(); + } + + /** + * Gets the set containing all the tensor names. + * + * @return The tensor names set. + */ + public Set keySet() { + return Collections.unmodifiableSet(map.keySet()); + } + + /** + * Does this result object have a tensor for the supplied key? + * + * @param key The key to check. + * @return True if this result object has a tensor for this key. + */ + public boolean containsKey(String key) { + return map.containsKey(key); + } + + /** + * Gets the value from the container at the specified index. + * + *

Throws {@link IllegalStateException} if the container has been closed, and {@link + * IndexOutOfBoundsException} if the index is invalid. + * + * @param index The index to lookup. + * @return The value at the index. + */ + public Tensor get(int index) { + if (!closed) { + return list.get(index); + } else { + throw new IllegalStateException("Result is closed"); } - - /** - * Gets the set containing all the tensor names. - * @return The tensor names set. - */ - public Set keySet() { - return Collections.unmodifiableSet(map.keySet()); + } + + /** + * Gets the value from the container assuming it's not been closed. + * + *

Throws {@link IllegalStateException} if the container has been closed. + * + * @param key The key to lookup. + * @return Optional.of the value if it exists. + */ + public Optional get(String key) { + if (!closed) { + return Optional.ofNullable(map.get(key)); + } else { + throw new IllegalStateException("Result is closed"); } - - /** - * Does this result object have a tensor for the supplied key? - * @param key The key to check. - * @return True if this result object has a tensor for this key. - */ - public boolean containsKey(String key) { - return map.containsKey(key); + } + + /** + * Metadata about the run. + * + *

A RunMetadata + * protocol buffer. + */ + public Optional getMetadata() { + return Optional.ofNullable(metadata); + } + + /** + * Creates a Result from the names and values produced by {@link Session.Runner#run()}. + * + * @param names The output names. + * @param values The output values. + * @param metadata The run metadata, may be null. + */ + Result(List names, List values, RunMetadata metadata) { + this.map = new LinkedHashMap<>(); + this.list = new ArrayList<>(values); + + if (names.size() != values.size()) { + throw new IllegalArgumentException( + "Expected same number of names and values, found names.length = " + + names.size() + + ", values.length = " + + values.size()); } - /** - * Gets the value from the container at the specified index. - * - *

Throws {@link IllegalStateException} if the container has been closed, and {@link - * IndexOutOfBoundsException} if the index is invalid. - * - * @param index The index to lookup. - * @return The value at the index. - */ - public Tensor get(int index) { - if (!closed) { - return list.get(index); - } else { - throw new IllegalStateException("Result is closed"); - } + for (int i = 0; i < names.size(); i++) { + Tensor old = this.map.put(names.get(i), values.get(i)); + if (old != null) { + throw new IllegalArgumentException( + "Name collision in the result set, two outputs are named '" + names.get(i) + "'"); + } } - - /** - * Gets the value from the container assuming it's not been closed. - * - *

Throws {@link IllegalStateException} if the container has been closed. - * - * @param key The key to lookup. - * @return Optional.of the value if it exists. - */ - public Optional get(String key) { - if (!closed) { - return Optional.ofNullable(map.get(key)); - } else { - throw new IllegalStateException("Result is closed"); - } - } - - /** - * Metadata about the run. - * - *

A RunMetadata - * protocol buffer. - */ - public Optional getMetadata() { - return Optional.ofNullable(metadata); - } - - /** - * Creates a Result from the names and values produced by {@link Session.Runner#run()}. - * - * @param names The output names. - * @param values The output values. - * @param metadata The run metadata, may be null. - */ - Result(List names, List values, RunMetadata metadata) { - this.map = new LinkedHashMap<>(); - this.list = new ArrayList<>(values); - - if (names.size() != values.size()) { - throw new IllegalArgumentException( - "Expected same number of names and values, found names.length = " - + names.size() - + ", values.length = " - + values.size()); - } - - for (int i = 0; i < names.size(); i++) { - Tensor old = this.map.put(names.get(i), values.get(i)); - if (old != null) { - throw new IllegalArgumentException("Name collision in the result set, two outputs are named '" + names.get(i) + "'"); - } - } - this.metadata = metadata; - this.closed = false; - } - - /** - * Creates a Result from the names and values. - * - * @param outputs The run outputs. - */ - Result(LinkedHashMap outputs) { - this.map = outputs; - this.list = new ArrayList<>(outputs.size()); - for (Map.Entry e : outputs.entrySet()) { - list.add(e.getValue()); - } - this.metadata = null; - this.closed = false; + this.metadata = metadata; + this.closed = false; + } + + /** + * Creates a Result from the names and values. + * + * @param outputs The run outputs. + */ + Result(LinkedHashMap outputs) { + this.map = outputs; + this.list = new ArrayList<>(outputs.size()); + for (Map.Entry e : outputs.entrySet()) { + list.add(e.getValue()); } + this.metadata = null; + this.closed = false; + } - private final Map map; + private final Map map; - private final List list; + private final List list; - private final RunMetadata metadata; + private final RunMetadata metadata; - private boolean closed; + private boolean closed; - private static final Logger logger = Logger.getLogger(Result.class.getName()); + private static final Logger logger = Logger.getLogger(Result.class.getName()); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 067a261938e..76be5597cc1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -21,14 +21,12 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig; import com.google.protobuf.InvalidProtocolBufferException; - import java.util.ArrayList; import java.util.Collections; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; - import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; @@ -308,7 +306,7 @@ public Runner feed(Operand operand, Tensor t) { * @throws IllegalArgumentException if no output exists with the provided name */ public Runner fetch(String operation) { - Runner r = fetch(graph.outputOrThrow(operation),false); + Runner r = fetch(graph.outputOrThrow(operation), false); outputNames.add(operation); return r; } @@ -350,7 +348,7 @@ public Runner fetch(Output output) { * * @param output the node to fetch the tensor from * @param recordName Records the output name. If false the output name must be recorded by the - * calling method as otherwise the result object will throw on construction. + * calling method as otherwise the result object will throw on construction. * @return this session runner */ private Runner fetch(Output output, boolean recordName) { @@ -590,7 +588,7 @@ private Result runHelper(boolean wantMetadata) { } finally { runRef.close(); } - return new Result(outputNames,outputs,metadata); + return new Result(outputNames, outputs, metadata); } private class Reference implements AutoCloseable { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java index bb8f58adaf1..877ba1b2f2c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java @@ -1,18 +1,18 @@ /* Copyright 2021-2022 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow; import java.io.IOException; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFunction.java index e58437e6f8f..1b83a1176ca 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFunction.java @@ -1,18 +1,18 @@ /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow; import java.util.LinkedHashMap; diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java index 6a900b7b6cf..b303618eae2 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java @@ -205,13 +205,14 @@ public void testGradientsGraph() { try (TFloat32 c1 = TFloat32.scalarOf(3.0f); TFloat32 c2 = TFloat32.scalarOf(2.0f); - Result outputs = s.runner() - .feed(x1, c1) - .feed(x2, c2) - .fetch(grads0[0]) - .fetch(grads1[0]) - .fetch(grads1[1]) - .run()) { + Result outputs = + s.runner() + .feed(x1, c1) + .feed(x2, c2) + .fetch(grads0[0]) + .fetch(grads1[0]) + .fetch(grads1[1]) + .run()) { assertEquals(3, outputs.size()); assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java index 9d2316603d1..28a549d72ef 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java @@ -14,6 +14,11 @@ ==============================================================================*/ package org.tensorflow; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; +import static org.tensorflow.DeviceSpec.DeviceType; + import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TFInvalidArgumentException; import org.tensorflow.op.Ops; @@ -21,90 +26,87 @@ import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.types.TInt32; -import static com.google.common.truth.Truth.assertThat; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.fail; -import static org.tensorflow.DeviceSpec.DeviceType; - /** Tests for {@link DeviceSpec}. */ public class DeviceSpecTest { @Test public void withDeviceMethod() { - ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) + ConfigProto config = + ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) .setLogDevicePlacement(true) .build(); - try (Graph g = new Graph(); Session session = new Session(g, config)) { + try (Graph g = new Graph(); + Session session = new Session(g, config)) { Ops tf = Ops.create(g).withSubScope("testScope"); Constant aOps = tf.constant(-1); - DeviceSpec deviceSpec = DeviceSpec.newBuilder() + DeviceSpec deviceSpec = + DeviceSpec.newBuilder() .job("localhost") .replica(0) .task(0) .deviceType(DeviceSpec.DeviceType.CPU) .build(); - Output absOps = tf - .withName("absWithDevice") - .withDevice(deviceSpec) - .math - .abs(aOps) - .asOutput(); + Output absOps = + tf.withName("absWithDevice").withDevice(deviceSpec).math.abs(aOps).asOutput(); try (Result t = session.runner().fetch(absOps).run()) { - assertEquals(1, ((TInt32)t.get(0)).getInt()); + assertEquals(1, ((TInt32) t.get(0)).getInt()); } } } @Test public void withEmptyDeviceSpec() { - ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) + ConfigProto config = + ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) .setLogDevicePlacement(true) .build(); - try (Graph g = new Graph(); Session session = new Session(g, config)) { + try (Graph g = new Graph(); + Session session = new Session(g, config)) { Ops tf = Ops.create(g).withSubScope("testScope"); Constant aOps = tf.constant(-1); - DeviceSpec deviceSpec = DeviceSpec.newBuilder() + DeviceSpec deviceSpec = + DeviceSpec.newBuilder() .job("localhost") .replica(0) .task(0) .deviceType(DeviceSpec.DeviceType.CPU) .build(); - Output absOps = tf - .withName("absWithDevice") - .withDevice(deviceSpec) - .math - .abs(aOps) - .asOutput(); + Output absOps = + tf.withName("absWithDevice").withDevice(deviceSpec).math.abs(aOps).asOutput(); try (Result t = session.runner().fetch(absOps).run()) { - assertEquals(1, ((TInt32)t.get(0)).getInt()); + assertEquals(1, ((TInt32) t.get(0)).getInt()); } } } @Test public void withTwoScopes() { - ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) + ConfigProto config = + ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) .setLogDevicePlacement(true) .build(); - try (Graph g = new Graph(); Session session = new Session(g, config)) { - DeviceSpec deviceSpec1 = DeviceSpec.newBuilder() + try (Graph g = new Graph(); + Session session = new Session(g, config)) { + DeviceSpec deviceSpec1 = + DeviceSpec.newBuilder() .job("localhost") .replica(0) .task(0) .deviceType(DeviceSpec.DeviceType.CPU) .build(); - DeviceSpec deviceSpec2 = DeviceSpec.newBuilder() + DeviceSpec deviceSpec2 = + DeviceSpec.newBuilder() .job("localhost") .replica(0) .task(0) @@ -117,32 +119,27 @@ public void withTwoScopes() { Constant aOps = tf1.constant(-1); Constant bOps = tf2.constant(10); - Output absOps = tf1 - .withName("absWithDevice") - .math - .abs(aOps) - .asOutput(); + Output absOps = tf1.withName("absWithDevice").math.abs(aOps).asOutput(); - Output mulOps = tf2 - .withName("mulWithDevice") - .math - .mul(absOps, bOps) - .asOutput(); + Output mulOps = tf2.withName("mulWithDevice").math.mul(absOps, bOps).asOutput(); try (Result t = session.runner().fetch(mulOps).run()) { - assertEquals(10, ((TInt32)t.get(0)).getInt()); + assertEquals(10, ((TInt32) t.get(0)).getInt()); } } } @Test public void withIncorrectDeviceSpec() { - ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) + ConfigProto config = + ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) .setLogDevicePlacement(true) .build(); - try (Graph g = new Graph(); Session session = new Session(g, config)) { - DeviceSpec correctDeviceSpec = DeviceSpec.newBuilder() + try (Graph g = new Graph(); + Session session = new Session(g, config)) { + DeviceSpec correctDeviceSpec = + DeviceSpec.newBuilder() .job("localhost") .replica(0) .task(0) @@ -150,7 +147,8 @@ public void withIncorrectDeviceSpec() { .build(); // Incorrect device spec, it will never be executed - DeviceSpec incorrectDeviceSpec = DeviceSpec.newBuilder() + DeviceSpec incorrectDeviceSpec = + DeviceSpec.newBuilder() .job("UNKNOWN") .replica(1) .task(1000) @@ -162,15 +160,11 @@ public void withIncorrectDeviceSpec() { Constant aOps = tf.constant(-1); Constant bOps = tf.constant(10); - Output absOps = tf - .withName("absWithDevice") - .withDevice(incorrectDeviceSpec) - .math - .abs(aOps) - .asOutput(); + Output absOps = + tf.withName("absWithDevice").withDevice(incorrectDeviceSpec).math.abs(aOps).asOutput(); - Output mulOps = tf - .withName("mulWithDevice") + Output mulOps = + tf.withName("mulWithDevice") .withDevice(correctDeviceSpec) .math .mul(absOps, bOps) @@ -186,12 +180,15 @@ public void withIncorrectDeviceSpec() { @Test public void withDeviceSpecInScope() { - ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) + ConfigProto config = + ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) .setLogDevicePlacement(true) .build(); - try (Graph g = new Graph(); Session session = new Session(g, config)) { - DeviceSpec deviceSpec = DeviceSpec.newBuilder() + try (Graph g = new Graph(); + Session session = new Session(g, config)) { + DeviceSpec deviceSpec = + DeviceSpec.newBuilder() .job("localhost") .replica(0) .task(0) @@ -202,14 +199,10 @@ public void withDeviceSpecInScope() { Constant aOps = tf.constant(-1); - Output absOps = tf - .withName("absWithDevice") - .math - .abs(aOps) - .asOutput(); + Output absOps = tf.withName("absWithDevice").math.abs(aOps).asOutput(); try (Result t = session.runner().fetch(absOps).run()) { - assertEquals(1, ((TInt32)t.get(0)).getInt()); + assertEquals(1, ((TInt32) t.get(0)).getInt()); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java index c8055ddae14..ff691e30adb 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java @@ -84,7 +84,7 @@ public void graphDefRoundTripWithInit() { Operand variable2 = init.withName("var2").variable(init.constant(4)); try (Session s = new Session(g, true); - Result results = s.runner().fetch("result").fetch("var2").run()) { + Result results = s.runner().fetch("result").fetch("var2").run()) { TInt32 result = (TInt32) results.get(0); assertEquals(6, result.getInt()); @@ -263,13 +263,14 @@ public void addGradientsToGraph() { try (TFloat32 c1 = TFloat32.scalarOf(3.0f); TFloat32 c2 = TFloat32.scalarOf(2.0f); - Result outputs = s.runner() - .feed(x1, c1) - .feed(x2, c2) - .fetch(grads0[0]) - .fetch(grads1[0]) - .fetch(grads1[1]) - .run()) { + Result outputs = + s.runner() + .feed(x1, c1) + .feed(x2, c2) + .fetch(grads0[0]) + .fetch(grads1[0]) + .fetch(grads1[1]) + .run()) { assertEquals(3, outputs.size()); assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); assertEquals(6.0f, ((TFloat32) outputs.get(1)).getFloat(), 0.0f); @@ -413,12 +414,13 @@ public void buildWhileLoopMultipleInputs() { try (TInt32 c1 = TInt32.scalarOf(2); TInt32 c2 = TInt32.scalarOf(5); - Result outputs = s.runner() - .feed(input1, c1) - .feed(input2, c2) - .fetch(loopOutputs[0]) - .fetch(loopOutputs[1]) - .run()) { + Result outputs = + s.runner() + .feed(input1, c1) + .feed(input2, c2) + .fetch(loopOutputs[0]) + .fetch(loopOutputs[1]) + .run()) { assertEquals(2, outputs.size()); assertEquals(16, ((TInt32) outputs.get(0)).getInt()); // ((2^2)^2) assertEquals(625, ((TInt32) outputs.get(1)).getInt()); // ((5^2)^2) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index a5191182ffb..deff52ffbeb 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -215,7 +215,10 @@ public void exportFunctionWithVariables() throws IOException { // Now call the same function directly from the model try (TFloat32 zTensor = (TFloat32) - savedModel.call(Collections.singletonMap("input", xTensor)).get("reducedSum").get()) { + savedModel + .call(Collections.singletonMap("input", xTensor)) + .get("reducedSum") + .get()) { assertEquals(reducedSum, zTensor.getFloat(), EPSILON); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index aefb048db7e..918ccac5fe2 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -16,7 +16,6 @@ package org.tensorflow; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; @@ -27,7 +26,6 @@ import java.util.Comparator; import java.util.Iterator; import java.util.Optional; - import org.junit.jupiter.api.Test; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -228,7 +226,7 @@ public void saveAndRestore() throws IOException { try (Session restoredSession = new Session(restoredGraph)) { restoredSession.restore(testFolder.resolve("checkpoint").toString()); try (Result oldList = s.runner().fetch("x").fetch("y").run(); - Result newList = restoredSession.runner().fetch("x").fetch("y").run()) { + Result newList = restoredSession.runner().fetch("x").fetch("y").run()) { assertEquals(oldList.get(0), newList.get(0)); assertEquals(oldList.get(1), newList.get(1)); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java index 7ba4c4da26c..fb52b2d1059 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java @@ -47,11 +47,10 @@ public void createGradients() { assertEquals(2, grads.dy().size()); try (TFloat32 c = TFloat32.scalarOf(3.0f); - Result outputs = sess.runner() - .feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run()) { + Result outputs = sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run()) { - assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); - assertEquals(18.0f, ((TFloat32)outputs.get(1)).getFloat(), 0.0f); + assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); + assertEquals(18.0f, ((TFloat32) outputs.get(1)).getFloat(), 0.0f); } } } @@ -75,7 +74,7 @@ public void createGradientsWithSum() { try (TFloat32 c = TFloat32.scalarOf(3.0f); Result outputs = sess.runner().feed(x, c).fetch(grads.dy(0)).run()) { - assertEquals(114.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); + assertEquals(114.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); } } } @@ -91,17 +90,17 @@ public void createGradientsWithInitialValues() { Output y1 = tf.math.square(y0).y(); Gradients grads0 = Gradients.create(tf.scope(), y1, Arrays.asList(y0)); - Gradients grads1 = Gradients.create(tf.scope(), y0, Arrays.asList(x), Gradients.dx(grads0.dy())); + Gradients grads1 = + Gradients.create(tf.scope(), y0, Arrays.asList(x), Gradients.dx(grads0.dy())); assertNotNull(grads1); assertNotNull(grads1.dy()); assertEquals(1, grads1.dy().size()); try (TFloat32 c = TFloat32.scalarOf(3.0f); - Result outputs = - sess.runner().feed(x, c).fetch(grads1.dy(0)).run()) { + Result outputs = sess.runner().feed(x, c).fetch(grads1.dy(0)).run()) { - assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); + assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); } } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java index 7cb3ed1cc18..e75bdde766e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java @@ -78,10 +78,10 @@ public void testGraphIteration() { int batches = 0; while (true) { try (Result outputs = session.runner().fetch(X).fetch(y).run()) { - assertEquals(mapped1.get(batches), outputs.get(0)); - assertEquals(mapped2.get(batches), outputs.get(1)); + assertEquals(mapped1.get(batches), outputs.get(0)); + assertEquals(mapped2.get(batches), outputs.get(1)); - batches++; + batches++; } catch (TFOutOfRangeException e) { break; } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java index 278c9dbf1c4..fc1e2fe9573 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics.impl; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.Result; @@ -27,8 +29,6 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import static org.junit.jupiter.api.Assertions.assertThrows; - public class AssertBroadcastableTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java index a8ff95a5e15..9df29436e31 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java @@ -14,6 +14,10 @@ =======================================================================*/ package org.tensorflow.framework.metrics.impl; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.Result; @@ -26,11 +30,6 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import java.util.concurrent.atomic.AtomicInteger; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; - public class BroadcastWeightsTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; @@ -100,33 +99,34 @@ private void testValid( TInt32 intT = (TInt32) result.get(0); AtomicInteger i = new AtomicInteger(); intT.scalars() - .forEachIndexed( - (idx, f) -> assertEquals(expected[i.getAndIncrement()].intValue(), f.getInt())); + .forEachIndexed( + (idx, f) -> assertEquals(expected[i.getAndIncrement()].intValue(), f.getInt())); } else if (type.equals(TInt64.class)) { TInt64 floatT = (TInt64) result.get(0); AtomicInteger i = new AtomicInteger(); floatT - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(expected[i.getAndIncrement()].longValue(), f.getLong())); + .scalars() + .forEachIndexed( + (idx, f) -> + assertEquals(expected[i.getAndIncrement()].longValue(), f.getLong())); } else if (type.equals(TFloat32.class)) { TFloat32 floatT = (TFloat32) result.get(0); AtomicInteger i = new AtomicInteger(); floatT - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals( - expected[i.getAndIncrement()].floatValue(), f.getFloat(), 1e-5F)); + .scalars() + .forEachIndexed( + (idx, f) -> + assertEquals( + expected[i.getAndIncrement()].floatValue(), f.getFloat(), 1e-5F)); } else if (type.equals(TFloat64.class)) { TFloat64 doubleT = (TFloat64) result.get(0); AtomicInteger i = new AtomicInteger(); doubleT - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals( - expected[i.getAndIncrement()].doubleValue(), f.getDouble(), 1e-5F)); + .scalars() + .forEachIndexed( + (idx, f) -> + assertEquals( + expected[i.getAndIncrement()].doubleValue(), f.getDouble(), 1e-5F)); } } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index da48aee9f78..a59f67f5a99 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -5,7 +5,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; - import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; @@ -192,7 +191,8 @@ public void testDeterminism() { g.importGraphDef(def); s.initialize(); - Result initializationRes = s.runner() + Result initializationRes = + s.runner() .fetch(fcWeightName) .fetch(fcBiasName) .fetch(outputWeightName) @@ -216,7 +216,8 @@ public void testDeterminism() { initialLoss[i] = lossVal.getFloat(); lossVal.close(); - Result trainedRes = s.runner() + Result trainedRes = + s.runner() .fetch(fcWeightName) .fetch(fcBiasName) .fetch(outputWeightName)