Skip to content

Commit 8cd74fc

Browse files
committed
Add validations on signatures and saved models
1 parent c356139 commit 8cd74fc

File tree

8 files changed

+179
-42
lines changed

8 files changed

+179
-42
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,11 +231,11 @@ public Tensor<?> call(Tensor<?> tensor) throws IllegalArgumentException {
231231
*
232232
* <p>This method is convenient shortcut equivalent to
233233
* {@code SavedModel.exporter(exportDir).withFunction(this).export()}
234+
*
235+
* @throws IOException if saved model or variable state cannot be written on disk
234236
*/
235237
public void save(String exportDir) throws IOException {
236-
SavedModelBundle.exporter(exportDir)
237-
.withFunction(this)
238-
.export();
238+
SavedModelBundle.exporter(exportDir).withFunction(this).export();
239239
}
240240

241241
/**

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
import java.io.OutputStream;
2626
import java.nio.file.Path;
2727
import java.nio.file.Paths;
28-
import java.util.ArrayList;
2928
import java.util.Arrays;
3029
import java.util.HashMap;
30+
import java.util.LinkedHashMap;
3131
import java.util.List;
3232
import java.util.Map;
3333
import java.util.stream.Collectors;
@@ -45,6 +45,7 @@
4545
import org.tensorflow.proto.framework.MetaGraphDef.MetaInfoDef;
4646
import org.tensorflow.proto.framework.RunOptions;
4747
import org.tensorflow.proto.framework.SavedModel;
48+
import org.tensorflow.proto.util.SaverDef;
4849

4950
/**
5051
* SavedModelBundle represents a model loaded from storage.
@@ -73,6 +74,7 @@ public SavedModelBundle load() {
7374
* @param options A <a
7475
* href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions
7576
* protocol buffer</a>.
77+
* @return this object
7678
*/
7779
public Loader withRunOptions(RunOptions options) {
7880
this.runOptions = options;
@@ -85,6 +87,7 @@ public Loader withRunOptions(RunOptions options) {
8587
* @param configProto A <a
8688
* href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">ConfigProto
8789
* protocol buffer</a>.
90+
* @return this object
8891
*/
8992
public Loader withConfigProto(ConfigProto configProto) {
9093
this.configProto = configProto;
@@ -97,11 +100,12 @@ public Loader withConfigProto(ConfigProto configProto) {
97100
* <p>Has no effect if {@code tags} is null or empty
98101
*
99102
* @param tags the tags identifying the specific MetaGraphDef to load.
103+
* @return this object
104+
* @throws IllegalArgumentException if tags are invalid
100105
*/
101106
public Loader withTags(String... tags) {
102-
if (tags != null && tags.length > 0) {
103-
this.tags = tags;
104-
}
107+
validateTags(tags);
108+
this.tags = tags;
105109
return this;
106110
}
107111

@@ -110,7 +114,7 @@ private Loader(String exportDir) {
110114
}
111115

112116
private String exportDir = null;
113-
private String[] tags = {DEFAULT_TAG};
117+
private String[] tags = { DEFAULT_TAG };
114118
private ConfigProto configProto = null;
115119
private RunOptions runOptions = null;
116120
}
@@ -125,9 +129,11 @@ public static final class Exporter {
125129
*
126130
* @param tags the tags identifying the specific MetaGraphDef to save.
127131
* @return this object
132+
* @throws IllegalArgumentException if tags are invalid
128133
*/
129134
public Exporter withTags(String... tags) {
130-
this.tags.addAll(Arrays.asList(tags));
135+
validateTags(tags);
136+
this.tags = tags;
131137
return this;
132138
}
133139

@@ -148,6 +154,8 @@ public Exporter withTags(String... tags) {
148154
* @param function a function carrying a signature and a valid session to the graph to be saved
149155
* @return this object
150156
* @throws IllegalArgumentException if a function with the same name has already been added to the model
157+
* @throws UnsupportedOperationException if this function does not share the same session with the other
158+
* functions added to this model
151159
*/
152160
public Exporter withFunction(ConcreteFunction function) {
153161
Signature signature = function.signature();
@@ -167,22 +175,22 @@ public Exporter withFunction(ConcreteFunction function) {
167175
/**
168176
* Save the model into the export directory.
169177
*
170-
* @throws IOException if saved model or variable state can be written on disk
178+
* @throws IOException if saved model or variable state cannot be written on disk
171179
*/
172180
public void export() throws IOException {
173181
if (functions.isEmpty() || session == null) {
174182
throw new IllegalStateException("Model should contain at least one valid function");
175183
}
176-
if (tags.isEmpty()) {
177-
tags.add(DEFAULT_TAG);
178-
}
184+
Graph graph = session.graph();
185+
179186
// It is imperative to retrieve the graphDef after the saverDef, as the former might add
180187
// new ops to the graph for saving and restoring the variables.
181-
Graph graph = session.graph();
188+
SaverDef saverDef = graph.saverDef();
189+
182190
MetaGraphDef.Builder metaGraphDef = metaGraphDefBuilder
183-
.setSaverDef(graph.saverDef())
191+
.setSaverDef(saverDef)
184192
.setGraphDef(graph.toGraphDef())
185-
.setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(tags));
193+
.setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(Arrays.asList(tags)));
186194
functions.forEach((k, f) -> metaGraphDef.putSignatureDef(k, f.signature().asSignatureDef()));
187195

188196
// Make sure saved model directories exist
@@ -205,9 +213,9 @@ public void export() throws IOException {
205213
}
206214

207215
private final String exportDir;
208-
private final List<String> tags = new ArrayList<>();
216+
private String[] tags = { DEFAULT_TAG };
209217
private final MetaGraphDef.Builder metaGraphDefBuilder = MetaGraphDef.newBuilder();
210-
private final Map<String, ConcreteFunction> functions = new HashMap<>();
218+
private final Map<String, ConcreteFunction> functions = new LinkedHashMap<>();
211219
private Session session;
212220
}
213221

@@ -227,7 +235,11 @@ public void export() throws IOException {
227235
* @return a bundle containing the graph and associated session.
228236
*/
229237
public static SavedModelBundle load(String exportDir, String... tags) {
230-
return loader(exportDir).withTags(tags).load();
238+
Loader loader = loader(exportDir);
239+
if (tags != null && tags.length > 0) {
240+
loader.withTags(tags);
241+
}
242+
return loader.load();
231243
}
232244

233245
/**
@@ -416,6 +428,12 @@ opts, runOpts, new BytePointer(exportDir), new PointerPointer(tags),
416428
return bundle;
417429
}
418430

431+
private static void validateTags(String[] tags) {
432+
if (tags == null || tags.length == 0 || Arrays.stream(tags).anyMatch(t -> t == null || t.isEmpty())) {
433+
throw new IllegalArgumentException("Invalid tags: " + Arrays.toString(tags));
434+
}
435+
}
436+
419437
static {
420438
TensorFlow.init();
421439
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,10 @@ public void run(Op op) {
454454
* <i>mymodel/myvariables/variables</i>, then the generated files will be located under
455455
* <i>mymodel/myvariables</i> and named <i>variables.data-*-of-*</i>
456456
*
457-
* @param prefix
457+
* <p/>Note that this method might alter the underlying graph if it is the first time that one
458+
* of its session is saved, see {@link Graph#saverDef()} for more details.
459+
*
460+
* @param prefix prefix to the variable files to save
458461
*/
459462
public void save(String prefix) {
460463
SaverDef saverDef = graph.saverDef();

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,12 @@ public static class Builder {
4545
*
4646
* @param key signature key
4747
* @return this builder
48+
* @throws IllegalArgumentException if the key is invalid
4849
*/
4950
public Builder key(String key) {
51+
if (key == null || key.isEmpty()) {
52+
throw new IllegalArgumentException("Invalid key: " + key);
53+
}
5054
this.key = key;
5155
return this;
5256
}
@@ -57,8 +61,12 @@ public Builder key(String key) {
5761
* @param inputName user-friendly name for this input tensor
5862
* @param input input tensor
5963
* @return this builder
64+
* @throws IllegalArgumentException if {@code inputName} is already mapped to another input
6065
*/
6166
public Builder input(String inputName, Operand<?> input) {
67+
if (signatureBuilder.containsInputs(inputName)) {
68+
throw new IllegalArgumentException("\"" + inputName + "\" is already being mapped to another input");
69+
}
6270
signatureBuilder.putInputs(inputName, toTensorInfo(input.asOutput()));
6371
return this;
6472
}
@@ -69,8 +77,12 @@ public Builder input(String inputName, Operand<?> input) {
6977
* @param inputName user-friendly name for this input tensor
7078
* @param input input tensor
7179
* @return this builder
80+
* @throws IllegalArgumentException if {@code outputName} is already mapped to another output
7281
*/
7382
public Builder output(String outputName, Operand<?> output) {
83+
if (signatureBuilder.containsOutputs(outputName)) {
84+
throw new IllegalArgumentException("\"" + outputName + "\" is already being mapped to another output");
85+
}
7486
signatureBuilder.putOutputs(outputName, toTensorInfo(output.asOutput()));
7587
return this;
7688
}
@@ -79,11 +91,11 @@ public Builder output(String outputName, Operand<?> output) {
7991
* Provide extensible name information enabling third-party users to mark a signature as
8092
* supporting a particular method
8193
*
82-
* @param methodName method name
94+
* @param methodName method name or null for none (default)
8395
* @return this builder
8496
*/
8597
public Builder methodName(String methodName) {
86-
signatureBuilder.setMethodName(methodName);
98+
signatureBuilder.setMethodName(methodName == null ? "" : methodName);
8799
return this;
88100
}
89101

@@ -126,10 +138,10 @@ public String key() {
126138
}
127139

128140
/**
129-
* Returns the method name of this signature (e.g. as exposed by TF serving)
141+
* Returns the method name of this signature (e.g. as exposed by TF serving) or null if none
130142
*/
131143
public String methodName() {
132-
return signatureDef.getMethodName();
144+
return signatureDef.getMethodName().isEmpty() ? null : signatureDef.getMethodName();
133145
}
134146

135147
/**

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,21 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
115
package org.tensorflow;
216

317
import static org.junit.jupiter.api.Assertions.assertEquals;
18+
import static org.junit.jupiter.api.Assertions.assertThrows;
419
import static org.junit.jupiter.api.Assertions.fail;
520

621
import org.junit.jupiter.api.Test;
@@ -75,18 +90,8 @@ public void closingFunctionReleaseAllResourcesItOwns() {
7590
g = f.graph();
7691
s = f.session();
7792
}
78-
try {
79-
s.run("Add");
80-
fail();
81-
} catch (IllegalStateException e) {
82-
// as expected
83-
}
84-
try {
85-
g.toGraphDef();
86-
fail();
87-
} catch (IllegalStateException e) {
88-
// as expected
89-
}
93+
assertThrows(IllegalStateException.class, () -> s.run("Add"));
94+
assertThrows(IllegalStateException.class, () -> g.toGraphDef());
9095
}
9196

9297
@Test
@@ -97,12 +102,7 @@ public void closingFunctionCreatedFromGraphOnlyReleaseResourcesItOwns() {
97102
try (ConcreteFunction f = ConcreteFunction.create(signature, g)) {
98103
s = f.session();
99104
}
100-
try {
101-
s.run(Init.DEFAULT_NAME);
102-
fail();
103-
} catch (IllegalStateException e) {
104-
// as expected
105-
}
105+
assertThrows(IllegalStateException.class, () -> s.run(Init.DEFAULT_NAME));
106106
g.toGraphDef(); // check that graph is still valid
107107
}
108108
}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
115
package org.tensorflow;
216

317
import org.junit.jupiter.api.Test;

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import static org.junit.jupiter.api.Assertions.assertEquals;
1919
import static org.junit.jupiter.api.Assertions.assertNotNull;
20+
import static org.junit.jupiter.api.Assertions.assertThrows;
2021
import static org.junit.jupiter.api.Assertions.assertTrue;
2122
import static org.junit.jupiter.api.Assertions.fail;
2223

@@ -258,6 +259,34 @@ public void cannotExportMultipleFunctionsWithSameSignatureKey() throws IOExcepti
258259
}
259260
}
260261

262+
@Test
263+
public void cannotExportOrImportInvalidTags() {
264+
assertThrows(IllegalArgumentException.class, () ->
265+
SavedModelBundle.loader("/").withTags()
266+
);
267+
assertThrows(IllegalArgumentException.class, () ->
268+
SavedModelBundle.loader("/").withTags(new String[]{})
269+
);
270+
assertThrows(IllegalArgumentException.class, () ->
271+
SavedModelBundle.loader("/").withTags(new String[]{"tag", null})
272+
);
273+
assertThrows(IllegalArgumentException.class, () ->
274+
SavedModelBundle.loader("/").withTags(new String[]{"tag", ""})
275+
);
276+
assertThrows(IllegalArgumentException.class, () ->
277+
SavedModelBundle.exporter("/").withTags()
278+
);
279+
assertThrows(IllegalArgumentException.class, () ->
280+
SavedModelBundle.exporter("/").withTags(new String[]{})
281+
);
282+
assertThrows(IllegalArgumentException.class, () ->
283+
SavedModelBundle.exporter("/").withTags(new String[]{"tag", null})
284+
);
285+
assertThrows(IllegalArgumentException.class, () ->
286+
SavedModelBundle.exporter("/").withTags(new String[]{"tag", ""})
287+
);
288+
}
289+
261290
@Test
262291
public void pythonTfFunction() {
263292
// ConcreteFunctions on models saved using python

0 commit comments

Comments
 (0)