|
17 | 17 | */
|
18 | 18 | package org.tensorflow;
|
19 | 19 |
|
20 |
| -import org.tensorflow.exceptions.TensorFlowException; |
21 |
| -import org.tensorflow.proto.framework.RunMetadata; |
22 |
| - |
23 | 20 | import java.util.ArrayList;
|
24 | 21 | import java.util.Collections;
|
25 | 22 | import java.util.Iterator;
|
|
30 | 27 | import java.util.Set;
|
31 | 28 | import java.util.logging.Level;
|
32 | 29 | import java.util.logging.Logger;
|
| 30 | +import org.tensorflow.exceptions.TensorFlowException; |
| 31 | +import org.tensorflow.proto.framework.RunMetadata; |
33 | 32 |
|
34 | 33 | /**
|
35 | 34 | * An {@link AutoCloseable} wrapper around a {@link Map} containing {@link Tensor}s.
|
36 | 35 | *
|
37 |
| - * <p>When this is closed it closes all the {@link Tensor}s inside it. If you maintain a |
38 |
| - * reference to a value after this object has been closed it will throw an {@link |
39 |
| - * IllegalStateException} upon access. |
| 36 | + * <p>When this is closed it closes all the {@link Tensor}s inside it. If you maintain a reference |
| 37 | + * to a value after this object has been closed it will throw an {@link IllegalStateException} upon |
| 38 | + * access. |
40 | 39 | *
|
41 |
| - * <p>This class is not thread-safe with respect to the close operation. Multiple closers |
42 |
| - * or one thread closing a tensor while another is reading may throw exceptions. |
| 40 | + * <p>This class is not thread-safe with respect to the close operation. Multiple closers or one |
| 41 | + * thread closing a tensor while another is reading may throw exceptions. |
43 | 42 | *
|
44 |
| - * <p>Note this class is used to manage the lifetimes of tensors produced by the |
45 |
| - * TensorFlow runtime, from sessions and function calls. It is not used as an argument |
46 |
| - * to {@code session.run} or function calls as users are in control of the creation |
47 |
| - * of input tensors. |
| 43 | + * <p>Note this class is used to manage the lifetimes of tensors produced by the TensorFlow runtime, |
| 44 | + * from sessions and function calls. It is not used as an argument to {@code session.run} or |
| 45 | + * function calls as users are in control of the creation of input tensors. |
48 | 46 | */
|
49 | 47 | public final class Result implements AutoCloseable, Iterable<Map.Entry<String, Tensor>> {
|
50 |
| - @Override |
51 |
| - public void close() { |
52 |
| - if (!closed) { |
53 |
| - for (Tensor t : list) { |
54 |
| - try { |
55 |
| - t.close(); |
56 |
| - } catch (TensorFlowException e) { |
57 |
| - logger.log(Level.WARNING, "Exception raised when closing tensor inside result.", e); |
58 |
| - } |
59 |
| - } |
60 |
| - closed = true; |
61 |
| - } else { |
62 |
| - logger.warning("Closing an already closed Result"); |
| 48 | + @Override |
| 49 | + public void close() { |
| 50 | + if (!closed) { |
| 51 | + for (Tensor t : list) { |
| 52 | + try { |
| 53 | + t.close(); |
| 54 | + } catch (TensorFlowException e) { |
| 55 | + logger.log(Level.WARNING, "Exception raised when closing tensor inside result.", e); |
63 | 56 | }
|
| 57 | + } |
| 58 | + closed = true; |
| 59 | + } else { |
| 60 | + logger.warning("Closing an already closed Result"); |
64 | 61 | }
|
65 |
| - |
66 |
| - @Override |
67 |
| - public Iterator<Map.Entry<String, Tensor>> iterator() { |
68 |
| - if (!closed) { |
69 |
| - return map.entrySet().iterator(); |
70 |
| - } else { |
71 |
| - throw new IllegalStateException("Result is closed"); |
72 |
| - } |
| 62 | + } |
| 63 | + |
| 64 | + @Override |
| 65 | + public Iterator<Map.Entry<String, Tensor>> iterator() { |
| 66 | + if (!closed) { |
| 67 | + return map.entrySet().iterator(); |
| 68 | + } else { |
| 69 | + throw new IllegalStateException("Result is closed"); |
73 | 70 | }
|
74 |
| - |
75 |
| - /** |
76 |
| - * Returns the number of outputs in this Result. |
77 |
| - * |
78 |
| - * @return The number of outputs. |
79 |
| - */ |
80 |
| - public int size() { |
81 |
| - return map.size(); |
| 71 | + } |
| 72 | + |
| 73 | + /** |
| 74 | + * Returns the number of outputs in this Result. |
| 75 | + * |
| 76 | + * @return The number of outputs. |
| 77 | + */ |
| 78 | + public int size() { |
| 79 | + return map.size(); |
| 80 | + } |
| 81 | + |
| 82 | + /** |
| 83 | + * Gets the set containing all the tensor names. |
| 84 | + * |
| 85 | + * @return The tensor names set. |
| 86 | + */ |
| 87 | + public Set<String> keySet() { |
| 88 | + return Collections.unmodifiableSet(map.keySet()); |
| 89 | + } |
| 90 | + |
| 91 | + /** |
| 92 | + * Does this result object have a tensor for the supplied key? |
| 93 | + * |
| 94 | + * @param key The key to check. |
| 95 | + * @return True if this result object has a tensor for this key. |
| 96 | + */ |
| 97 | + public boolean containsKey(String key) { |
| 98 | + return map.containsKey(key); |
| 99 | + } |
| 100 | + |
| 101 | + /** |
| 102 | + * Gets the value from the container at the specified index. |
| 103 | + * |
| 104 | + * <p>Throws {@link IllegalStateException} if the container has been closed, and {@link |
| 105 | + * IndexOutOfBoundsException} if the index is invalid. |
| 106 | + * |
| 107 | + * @param index The index to lookup. |
| 108 | + * @return The value at the index. |
| 109 | + */ |
| 110 | + public Tensor get(int index) { |
| 111 | + if (!closed) { |
| 112 | + return list.get(index); |
| 113 | + } else { |
| 114 | + throw new IllegalStateException("Result is closed"); |
82 | 115 | }
|
83 |
| - |
84 |
| - /** |
85 |
| - * Gets the set containing all the tensor names. |
86 |
| - * @return The tensor names set. |
87 |
| - */ |
88 |
| - public Set<String> keySet() { |
89 |
| - return Collections.unmodifiableSet(map.keySet()); |
| 116 | + } |
| 117 | + |
| 118 | + /** |
| 119 | + * Gets the value from the container assuming it's not been closed. |
| 120 | + * |
| 121 | + * <p>Throws {@link IllegalStateException} if the container has been closed. |
| 122 | + * |
| 123 | + * @param key The key to lookup. |
| 124 | + * @return Optional.of the value if it exists. |
| 125 | + */ |
| 126 | + public Optional<Tensor> get(String key) { |
| 127 | + if (!closed) { |
| 128 | + return Optional.ofNullable(map.get(key)); |
| 129 | + } else { |
| 130 | + throw new IllegalStateException("Result is closed"); |
90 | 131 | }
|
91 |
| - |
92 |
| - /** |
93 |
| - * Does this result object have a tensor for the supplied key? |
94 |
| - * @param key The key to check. |
95 |
| - * @return True if this result object has a tensor for this key. |
96 |
| - */ |
97 |
| - public boolean containsKey(String key) { |
98 |
| - return map.containsKey(key); |
| 132 | + } |
| 133 | + |
| 134 | + /** |
| 135 | + * Metadata about the run. |
| 136 | + * |
| 137 | + * <p>A <a |
| 138 | + * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata |
| 139 | + * protocol buffer</a>. |
| 140 | + */ |
| 141 | + public Optional<RunMetadata> getMetadata() { |
| 142 | + return Optional.ofNullable(metadata); |
| 143 | + } |
| 144 | + |
| 145 | + /** |
| 146 | + * Creates a Result from the names and values produced by {@link Session.Runner#run()}. |
| 147 | + * |
| 148 | + * @param names The output names. |
| 149 | + * @param values The output values. |
| 150 | + * @param metadata The run metadata, may be null. |
| 151 | + */ |
| 152 | + Result(List<String> names, List<Tensor> values, RunMetadata metadata) { |
| 153 | + this.map = new LinkedHashMap<>(); |
| 154 | + this.list = new ArrayList<>(values); |
| 155 | + |
| 156 | + if (names.size() != values.size()) { |
| 157 | + throw new IllegalArgumentException( |
| 158 | + "Expected same number of names and values, found names.length = " |
| 159 | + + names.size() |
| 160 | + + ", values.length = " |
| 161 | + + values.size()); |
99 | 162 | }
|
100 | 163 |
|
101 |
| - /** |
102 |
| - * Gets the value from the container at the specified index. |
103 |
| - * |
104 |
| - * <p>Throws {@link IllegalStateException} if the container has been closed, and {@link |
105 |
| - * IndexOutOfBoundsException} if the index is invalid. |
106 |
| - * |
107 |
| - * @param index The index to lookup. |
108 |
| - * @return The value at the index. |
109 |
| - */ |
110 |
| - public Tensor get(int index) { |
111 |
| - if (!closed) { |
112 |
| - return list.get(index); |
113 |
| - } else { |
114 |
| - throw new IllegalStateException("Result is closed"); |
115 |
| - } |
| 164 | + for (int i = 0; i < names.size(); i++) { |
| 165 | + Tensor old = this.map.put(names.get(i), values.get(i)); |
| 166 | + if (old != null) { |
| 167 | + throw new IllegalArgumentException( |
| 168 | + "Name collision in the result set, two outputs are named '" + names.get(i) + "'"); |
| 169 | + } |
116 | 170 | }
|
117 |
| - |
118 |
| - /** |
119 |
| - * Gets the value from the container assuming it's not been closed. |
120 |
| - * |
121 |
| - * <p>Throws {@link IllegalStateException} if the container has been closed. |
122 |
| - * |
123 |
| - * @param key The key to lookup. |
124 |
| - * @return Optional.of the value if it exists. |
125 |
| - */ |
126 |
| - public Optional<Tensor> get(String key) { |
127 |
| - if (!closed) { |
128 |
| - return Optional.ofNullable(map.get(key)); |
129 |
| - } else { |
130 |
| - throw new IllegalStateException("Result is closed"); |
131 |
| - } |
132 |
| - } |
133 |
| - |
134 |
| - /** |
135 |
| - * Metadata about the run. |
136 |
| - * |
137 |
| - * <p>A <a |
138 |
| - * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata |
139 |
| - * protocol buffer</a>. |
140 |
| - */ |
141 |
| - public Optional<RunMetadata> getMetadata() { |
142 |
| - return Optional.ofNullable(metadata); |
143 |
| - } |
144 |
| - |
145 |
| - /** |
146 |
| - * Creates a Result from the names and values produced by {@link Session.Runner#run()}. |
147 |
| - * |
148 |
| - * @param names The output names. |
149 |
| - * @param values The output values. |
150 |
| - * @param metadata The run metadata, may be null. |
151 |
| - */ |
152 |
| - Result(List<String> names, List<Tensor> values, RunMetadata metadata) { |
153 |
| - this.map = new LinkedHashMap<>(); |
154 |
| - this.list = new ArrayList<>(values); |
155 |
| - |
156 |
| - if (names.size() != values.size()) { |
157 |
| - throw new IllegalArgumentException( |
158 |
| - "Expected same number of names and values, found names.length = " |
159 |
| - + names.size() |
160 |
| - + ", values.length = " |
161 |
| - + values.size()); |
162 |
| - } |
163 |
| - |
164 |
| - for (int i = 0; i < names.size(); i++) { |
165 |
| - Tensor old = this.map.put(names.get(i), values.get(i)); |
166 |
| - if (old != null) { |
167 |
| - throw new IllegalArgumentException("Name collision in the result set, two outputs are named '" + names.get(i) + "'"); |
168 |
| - } |
169 |
| - } |
170 |
| - this.metadata = metadata; |
171 |
| - this.closed = false; |
172 |
| - } |
173 |
| - |
174 |
| - /** |
175 |
| - * Creates a Result from the names and values. |
176 |
| - * |
177 |
| - * @param outputs The run outputs. |
178 |
| - */ |
179 |
| - Result(LinkedHashMap<String,Tensor> outputs) { |
180 |
| - this.map = outputs; |
181 |
| - this.list = new ArrayList<>(outputs.size()); |
182 |
| - for (Map.Entry<String, Tensor> e : outputs.entrySet()) { |
183 |
| - list.add(e.getValue()); |
184 |
| - } |
185 |
| - this.metadata = null; |
186 |
| - this.closed = false; |
| 171 | + this.metadata = metadata; |
| 172 | + this.closed = false; |
| 173 | + } |
| 174 | + |
| 175 | + /** |
| 176 | + * Creates a Result from the names and values. |
| 177 | + * |
| 178 | + * @param outputs The run outputs. |
| 179 | + */ |
| 180 | + Result(LinkedHashMap<String, Tensor> outputs) { |
| 181 | + this.map = outputs; |
| 182 | + this.list = new ArrayList<>(outputs.size()); |
| 183 | + for (Map.Entry<String, Tensor> e : outputs.entrySet()) { |
| 184 | + list.add(e.getValue()); |
187 | 185 | }
|
| 186 | + this.metadata = null; |
| 187 | + this.closed = false; |
| 188 | + } |
188 | 189 |
|
189 |
| - private final Map<String, Tensor> map; |
| 190 | + private final Map<String, Tensor> map; |
190 | 191 |
|
191 |
| - private final List<Tensor> list; |
| 192 | + private final List<Tensor> list; |
192 | 193 |
|
193 |
| - private final RunMetadata metadata; |
| 194 | + private final RunMetadata metadata; |
194 | 195 |
|
195 |
| - private boolean closed; |
| 196 | + private boolean closed; |
196 | 197 |
|
197 |
| - private static final Logger logger = Logger.getLogger(Result.class.getName()); |
| 198 | + private static final Logger logger = Logger.getLogger(Result.class.getName()); |
198 | 199 | }
|
0 commit comments