-
Notifications
You must be signed in to change notification settings - Fork 215
Save models as functions #103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Save models as functions #103
Conversation
…ow#89) Python models that contain tf.function is inconvenient to be consumed by Java clients. This proposal provides an API to (a) Invoke a tf.function, given the signature name (b) Retrieve the node name in the graph corresponding to a tf.function Co-authored-by: Shajan Dasan <[email protected]>
# Conflicts: # tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java # tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java
All (the pull request submitter and all commit authors) CLAs are signed, but one or more commits were authored or co-authored by someone other than the pull request submitter. We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that by leaving a comment that contains only Note to project maintainer: There may be cases where the author cannot leave a comment, or the comment is not properly detected as consent. In those cases, you can manually confirm consent of the commit author(s), and set the ℹ️ Googlers: Go here for more info. |
} | ||
|
||
private enum Ownership { | ||
GRAPH, SESSION, NONE; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggestion: Rename GRAPH to GRAPH_AND_SESSION
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The change looks good.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few comments on names and other small things.
* Map<String, Tensor<?>> outputTensorMap = myFunction.call(inputTensorMap); | ||
* }</pre> | ||
*/ | ||
public class FunctionGraph implements AutoCloseable { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe FunctionSession
as it's a specialised session? That way it conceptually lives next to EagerSession
and Session
, rather than next to Graph
. I feel like this is much closer to a session than a graph.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right now, it is true that all functions loaded from a saved model share the same graph and are just used to execute it with a given signature (therefore they are acting more as a session). Nonetheless, I still have a preference to preserve the FunctionGraph
naming.
If we ignore the saving part (which is limited due to the actual state of the C API), conceptually these functions also allow you to build your graphs, replacing the need of allocating explicitly a Graph
instance, e.g.
public class MyModel {
// Function graph #1
private static Signature addTwo(Ops tf) {
Placeholder<TFloat32> input = tf.placeholder(TFloat32.DTYPE);
Add<TFloat32> output = tf.math.add(input, tf.constant(2.0f));
return Signature.builder("addTwo").input("x", input).output("y", output).build();
}
// Function graph #2
private static Signature minusTwo(Ops tf) {
Placeholder<TFloat32> input = tf.placeholder(TFloat32.DTYPE);
Add<TFloat32> output = tf.math.sub(input, tf.constant(2.0f));
return Signature.builder("subTwo").input("x", input).output("y", output).build();
}
public static void main(String args[]) {
try (FunctionGraph f1 = FunctionGraph.create(MyModel::addTwo);
FunctionGraph f2 = FunctionGraph.create(MyModel::minusTwo);
Tensor<TFloat32> x = TFloat32.scalarOf(3.0f)) {
assertEquals(3.0f, f2.call(f1.call(x)).expect(TFloat32.DTYPE).data().getFloat());
}
}
}
Each function has its own graph and therefore if very coupled with that concept (where EagerSession
has no graph at all). They will also appear as separate graphs when exporting them into a saved model the same way Python does it (in fact, Python saves each function as one or more "objects", which are then linked to their distinct graph in the function library).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another idea would be to call this class a ConcreteFunction
. Concrete functions and functions are two very similar but distinct concepts in TF Python. The former is a typed realization of a function graph while the latter is its polymorphic version, only acting as a facade (e.g. a function is eventually composed of multiple concrete functions, one for each type of operands the function has been called with).
In our scenario here, the function graphs are strongly typed and refer to a single graph, therefore more behave like a concrete function in the Python paradigm and it would make sense to name them after it.
We can probably support polymorphic functions too in the future, where we pass the type of the input tensors in parameter to the function builder. So we will still need to find a proper name for this concept, which will probably result in class encapsulating one or more ConcreteFunction
. Maybe simply PolymorphicFunction
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My preference for a non-graph name is that Graph
s aren't executable, and are missing necessary state because the Variable
s live in the Session
. So FunctionGraph
is more like a Session
because it contains all the necessary bits to execute. However I'm also fine with the ConcreteFunction
name as that's just a new concept that people have to learn.
if (saverDef == null) { | ||
synchronized (this) { | ||
if (saverDef == null) { | ||
saverDef = addVariableSaver(this); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
saverDef
needs to be declared volatile for this to work - http://www.cs.umd.edu/~pugh/java/memoryModel/DoubleCheckedLocking.html
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or we make the whole method synchronized, which also seems fine as it's unlikely to be heavily contended.
* | ||
* @param exportDir the directory path containing a saved model. | ||
*/ | ||
public static Exporter exporter(String exportDir) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we add a helper that takes an exportDir, a FunctionGraph and then performs all the necessary operations? Seems like there are a bunch of hoops to jump through if you just want to save a single FunctionGraph
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess we could, what about making that endpoint available at the function level? Something like function.save(exportDir)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like that.
import java.nio.file.Paths; | ||
import jdk.nashorn.internal.codegen.FunctionSignature; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wrong import.
45cb250
to
175d9e6
Compare
Thanks @Shajan , I'll clean this up and merge it to our shared branch. It would be great if you can then validate that this merged solution still successfully cover your initial use case. If so, we will then make an official PR to the master branch out of it. |
@googlebot I consent. |
/** | ||
* Returns the signature of this function | ||
*/ | ||
public Signature signature() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor feature request: adding a toString() function to this ConcreteFunction class that prints a readable text representation of the signature would be really useful.
Python models that contain tf.function is inconvenient to be consumed by Java clients. This proposal provides an API to (a) Invoke a tf.function, given the signature name (b) Retrieve the node name in the graph corresponding to a tf.function Co-authored-by: Shajan Dasan <[email protected]> Save models as functions (#103) * Draft: Java API to use tf.function available on SavedModel. (#89) Python models that contain tf.function is inconvenient to be consumed by Java clients. This proposal provides an API to (a) Invoke a tf.function, given the signature name (b) Retrieve the node name in the graph corresponding to a tf.function Co-authored-by: Shajan Dasan <[email protected]> * Change API for creating concrete functions and exporting them to a saved model Co-authored-by: Karl Lessard <[email protected]> Rename signature name to key Print function signature when converting to String Add method that returns the signature of all functions in a saved model Add unit tests for python created SavedModel with tf.function
Python models that contain tf.function is inconvenient to be consumed by Java clients. This proposal provides an API to (a) Invoke a tf.function, given the signature name (b) Retrieve the node name in the graph corresponding to a tf.function Co-authored-by: Shajan Dasan <[email protected]> Save models as functions (#103) * Draft: Java API to use tf.function available on SavedModel. (#89) Python models that contain tf.function is inconvenient to be consumed by Java clients. This proposal provides an API to (a) Invoke a tf.function, given the signature name (b) Retrieve the node name in the graph corresponding to a tf.function Co-authored-by: Shajan Dasan <[email protected]> * Change API for creating concrete functions and exporting them to a saved model Co-authored-by: Karl Lessard <[email protected]> Rename signature name to key Print function signature when converting to String Add method that returns the signature of all functions in a saved model Add unit tests for python created SavedModel with tf.function
* Create, save and load models using functional API Python models that contain tf.function is inconvenient to be consumed by Java clients. This proposal provides an API to (a) Invoke a tf.function, given the signature name (b) Retrieve the node name in the graph corresponding to a tf.function Co-authored-by: Shajan Dasan <[email protected]> Save models as functions (#103) * Draft: Java API to use tf.function available on SavedModel. (#89) Python models that contain tf.function is inconvenient to be consumed by Java clients. This proposal provides an API to (a) Invoke a tf.function, given the signature name (b) Retrieve the node name in the graph corresponding to a tf.function Co-authored-by: Shajan Dasan <[email protected]> * Change API for creating concrete functions and exporting them to a saved model Co-authored-by: Karl Lessard <[email protected]> Rename signature name to key Print function signature when converting to String Add method that returns the signature of all functions in a saved model Add unit tests for python created SavedModel with tf.function * Add validations on signatures and saved models * Convert text file to Python * Add copyright on Python sample Co-authored-by: Shajan Dasan <[email protected]>
Ok I've modified a little this solution compared to the previous one, now
FunctionGraph
are auto-closeable and own the session/graph only if the function has been created before them. See examples provided in theFunctionGraph
documentation for more details.There are still a few things I need to look at before making this PR more official but I would like to get some feedbacks before continuing in that direction.
In summary:
FunctionGraph
is the new (disposable) name forTfFunction
as submitted previously by @ShajanI'll provide more and better examples later, thanks!