Skip to content

Adds a closeable session result #411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 8, 2022
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020-2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -295,8 +295,8 @@ public Operand<?> call(Scope scope, Operand<?> argument) {
}

@Override
public Map<String, Tensor> call(Map<String, Tensor> arguments) {
// FIXME need to manage input/output operand lifetimes
public Result call(Map<String, Tensor> arguments) {
// FIXME need to manage input operand lifetimes
Ops tf = Ops.create();
Map<String, Operand<?>> inputs = new LinkedHashMap<>(arguments.size());

Expand All @@ -305,11 +305,11 @@ public Map<String, Tensor> call(Map<String, Tensor> arguments) {
inputs.put(inputName, tf.constantOf((TType) argument));
}
Map<String, Operand<?>> outputs = tf.call(this, inputs);
Map<String, Tensor> tensorOutputs = new LinkedHashMap<>(outputs.size());
LinkedHashMap<String, Tensor> tensorOutputs = new LinkedHashMap<>(outputs.size());
for (String outputName : outputs.keySet()) {
tensorOutputs.put(outputName, outputs.get(outputName).asTensor());
}
return tensorOutputs;
return new Result(tensorOutputs);
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
/*
Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================
*/
package org.tensorflow;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.proto.framework.RunMetadata;

/**
* An {@link AutoCloseable} wrapper around a {@link Map} containing {@link Tensor}s.
*
* <p>When this is closed it closes all the {@link Tensor}s inside it. If you maintain a reference
* to a value after this object has been closed it will throw an {@link IllegalStateException} upon
* access.
*
* <p>This class is not thread-safe with respect to the close operation. Multiple closers or one
* thread closing a tensor while another is reading may throw exceptions.
*
* <p>Note this class is used to manage the lifetimes of tensors produced by the TensorFlow runtime,
* from sessions and function calls. It is not used as an argument to {@code session.run} or
* function calls as users are in control of the creation of input tensors.
*/
public final class Result implements AutoCloseable, Iterable<Map.Entry<String, Tensor>> {
@Override
public void close() {
if (!closed) {
for (Tensor t : list) {
try {
t.close();
} catch (TensorFlowException e) {
logger.log(Level.WARNING, "Exception raised when closing tensor inside result.", e);
}
}
closed = true;
} else {
logger.warning("Closing an already closed Result");
}
}

@Override
public Iterator<Map.Entry<String, Tensor>> iterator() {
if (!closed) {
return map.entrySet().iterator();
} else {
throw new IllegalStateException("Result is closed");
}
}

/**
* Returns the number of outputs in this Result.
*
* @return The number of outputs.
*/
public int size() {
return map.size();
}

/**
* Gets the set containing all the tensor names.
*
* @return The tensor names set.
*/
public Set<String> keySet() {
return Collections.unmodifiableSet(map.keySet());
}

/**
* Does this result object have a tensor for the supplied key?
*
* @param key The key to check.
* @return True if this result object has a tensor for this key.
*/
public boolean containsKey(String key) {
return map.containsKey(key);
}

/**
* Gets the value from the container at the specified index.
*
* <p>Throws {@link IllegalStateException} if the container has been closed, and {@link
* IndexOutOfBoundsException} if the index is invalid.
*
* @param index The index to lookup.
* @return The value at the index.
*/
public Tensor get(int index) {
if (!closed) {
return list.get(index);
} else {
throw new IllegalStateException("Result is closed");
}
}

/**
* Gets the value from the container assuming it's not been closed.
*
* <p>Throws {@link IllegalStateException} if the container has been closed.
*
* @param key The key to lookup.
* @return Optional.of the value if it exists.
*/
public Optional<Tensor> get(String key) {
if (!closed) {
return Optional.ofNullable(map.get(key));
} else {
throw new IllegalStateException("Result is closed");
}
}

/**
* Metadata about the run.
*
* <p>A <a
* href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata
* protocol buffer</a>.
*/
public Optional<RunMetadata> getMetadata() {
return Optional.ofNullable(metadata);
}

/**
* Creates a Result from the names and values produced by {@link Session.Runner#run()}.
*
* @param names The output names.
* @param values The output values.
* @param metadata The run metadata, may be null.
*/
Result(List<String> names, List<Tensor> values, RunMetadata metadata) {
this.map = new LinkedHashMap<>();
this.list = new ArrayList<>(values);

if (names.size() != values.size()) {
throw new IllegalArgumentException(
"Expected same number of names and values, found names.length = "
+ names.size()
+ ", values.length = "
+ values.size());
}

for (int i = 0; i < names.size(); i++) {
Tensor old = this.map.put(names.get(i), values.get(i));
if (old != null) {
throw new IllegalArgumentException(
"Name collision in the result set, two outputs are named '" + names.get(i) + "'");
}
}
this.metadata = metadata;
this.closed = false;
}

/**
* Creates a Result from the names and values.
*
* @param outputs The run outputs.
*/
Result(LinkedHashMap<String, Tensor> outputs) {
this.map = outputs;
this.list = new ArrayList<>(outputs.size());
for (Map.Entry<String, Tensor> e : outputs.entrySet()) {
list.add(e.getValue());
}
this.metadata = null;
this.closed = false;
}

private final Map<String, Tensor> map;

private final List<Tensor> list;

private final RunMetadata metadata;

private boolean closed;

private static final Logger logger = Logger.getLogger(Result.class.getName());
}
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ public List<SessionFunction> functions() {
* @return list of output tensors, mapped by the signature name
* @throws IllegalArgumentException if no function can be selected by default
*/
public Map<String, Tensor> call(Map<String, Tensor> arguments) {
public Result call(Map<String, Tensor> arguments) {
SessionFunction function = null;
if (functions.size() == 1) {
function = functions.values().iterator().next();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2019-2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -306,7 +306,9 @@ public Runner feed(Operand<?> operand, Tensor t) {
* @throws IllegalArgumentException if no output exists with the provided name
*/
public Runner fetch(String operation) {
return fetch(graph.outputOrThrow(operation));
Runner r = fetch(graph.outputOrThrow(operation), false);
outputNames.add(operation);
return r;
}

/**
Expand Down Expand Up @@ -336,6 +338,20 @@ public Runner fetch(String operation, int index) {
* @return this session runner
*/
public Runner fetch(Output<?> output) {
return fetch(output, true);
}

/**
* Makes {@link #run()} return the Tensor referred to by {@code output}.
*
* <p>If {@code output} is a resource variable, will fetch the value.
*
* @param output the node to fetch the tensor from
* @param recordName Records the output name. If false the output name must be recorded by the
* calling method as otherwise the result object will throw on construction.
* @return this session runner
*/
private Runner fetch(Output<?> output, boolean recordName) {
if (output.env() != graph) {
throw new IllegalStateException(
"Can't fetch output "
Expand Down Expand Up @@ -378,6 +394,9 @@ public Runner fetch(Output<?> output) {
} else {
outputs.add(output);
}
if (recordName) {
outputNames.add(output.name());
}
return this;
}

Expand Down Expand Up @@ -490,13 +509,13 @@ private void doInit() {
*
* @return list of resulting tensors fetched by this session runner
*/
public List<Tensor> run() {
public Result run() {
doInit();
return runNoInit();
}

List<Tensor> runNoInit() {
return runHelper(false).outputs;
Result runNoInit() {
return runHelper(false);
}

/**
Expand All @@ -509,12 +528,12 @@ List<Tensor> runNoInit() {
*
* @return list of resulting tensors fetched by this session runner, with execution metadata
*/
public Run runAndFetchMetadata() {
public Result runAndFetchMetadata() {
doInit();
return runHelper(true);
}

private Run runHelper(boolean wantMetadata) {
private Result runHelper(boolean wantMetadata) {
TF_Tensor[] inputTensorHandles = new TF_Tensor[inputTensors.size()];
TF_Operation[] inputOpHandles = new TF_Operation[inputs.size()];
int[] inputOpIndices = new int[inputs.size()];
Expand Down Expand Up @@ -569,10 +588,7 @@ private Run runHelper(boolean wantMetadata) {
} finally {
runRef.close();
}
Run ret = new Run();
ret.outputs = outputs;
ret.metadata = metadata;
return ret;
return new Result(outputNames, outputs, metadata);
}

private class Reference implements AutoCloseable {
Expand Down Expand Up @@ -602,6 +618,7 @@ public void close() {
private final ArrayList<Output<?>> inputs = new ArrayList<>();
private final ArrayList<Tensor> inputTensors = new ArrayList<>();
private final ArrayList<Output<?>> outputs = new ArrayList<>();
private final ArrayList<String> outputNames = new ArrayList<>();
private final ArrayList<GraphOperation> targets = new ArrayList<>();
private RunOptions runOptions = null;
}
Expand Down Expand Up @@ -648,8 +665,9 @@ public SessionFunction function(Signature signature) {
*
* @param signature the signature of the function
* @param arguments the arguments to call with.
* @return The results of the function call.
*/
public Map<String, Tensor> run(Signature signature, Map<String, Tensor> arguments) {
public Result run(Signature signature, Map<String, Tensor> arguments) {
return function(signature).call(arguments);
}

Expand Down Expand Up @@ -698,26 +716,6 @@ public void restore(String prefix) {
setInitialized();
}

/**
* Output tensors and metadata obtained when executing a session.
*
* <p>See {@link Runner#runAndFetchMetadata()}
*/
public static final class Run {

/** Tensors from requested fetches. */
public List<Tensor> outputs;

/**
* Metadata about the run.
*
* <p>A <a
* href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata
* protocol buffer</a>.
*/
public RunMetadata metadata;
}

Graph graph() {
return graph;
}
Expand Down
Loading