Saving and loading models in Java with a functional API #101


Closed
karllessard opened this issue Aug 14, 2020 · 22 comments

karllessard commented Aug 14, 2020

Confirmation: let's model the API with functions while still loading/saving session-centric graphs.

I take it that we go with the current branch (after addressing open comments) and add unit tests.

For unit tests, here is a proposal:

  • Check in a SavedModel (created using Python) as a resource
  • [Optionally] add the Python model + instructions (but do not have the test run Python code)

Originally posted by @Shajan in #89 (comment)

karllessard commented Aug 14, 2020

Hi @Shajan, no worries at all about the meeting. I've created a new issue on GitHub to continue our discussion so it won't get lost in a PR that has already been merged.

So the idea is that, while we will continue to support only session-centric saved models (models with a single graph and one or more signatures), we will present them as functions to the user, so that the API stays as close as possible to what it will be once we support real functions.

For the unit test, quick question: how do you build the Python model? Does it use one or more tf.function?

Shajan commented Aug 15, 2020

Yes, a model with a few tf.functions.

Here is a Colab notebook with an example: https://colab.research.google.com/drive/1NrFbhe6do2Hq0VF3hgEk4KsGhY9kkNfU?usp=sharing

The main parts are copied below.

import tensorflow as tf

class MyModel(tf.keras.Model):
  def __init__(self):
    super(MyModel, self).__init__()
    self.const_scalar = tf.constant(0.0)
    self.const_vector = tf.constant([0.0, 0.0, 0.0])
    self.const_matrix = tf.constant([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])

  @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32, name='request')])
  def serve(self, x):
    return self.const_scalar + x

  @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32, name='input')])
  def get_scalar(self, x):
    return self.const_scalar + x

  @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32, name='input')])
  def get_vector(self, x):
    return self.const_vector + x

  @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32, name='input')])
  def get_matrix(self, x):
    return self.const_matrix + x

  @tf.function(input_signature=[
    tf.TensorSpec(shape=None, dtype=tf.float32, name='a'),
    tf.TensorSpec(shape=None, dtype=tf.float32, name='b')])
  def add(self, a, b):
    return a + b

model = MyModel()

signatures = {
  "get_const_scalar": model.get_scalar,
  "get_const_vector": model.get_vector,
  "get_const_matrix": model.get_matrix,
  "add": model.add
}
tf.saved_model.save(obj=model, export_dir='model', signatures=signatures)

aday00 commented Aug 17, 2020

These are great developments, looking forward to a unit test whenever it's ready.

I merged some older code from @karllessard to save/load models, and it seems to work in preliminary testing: #100 (comment). Comments welcome! All my diffs, source, and outputs are posted.

Happy to test new methods for model save/load, if that's helpful. My work here is deadline-driven and needs model save/load. Thanks for everything!

aday00 pushed a commit to aday00/java that referenced this issue Aug 20, 2020
karllessard@bdb0420

Further code discussion is at https://groups.google.com/a/tensorflow.org/forum/#!msg/jvm/gGKO-hVS4Pc/LF4rLJOdAQAJ

This commit manually merges Karl's code into mainline tensorflow-java.
However, Ops.java is left unchanged here, because Karl's code contained no functional changes to Ops.java.
This commit fixes tensorflow#100

Later, model-saving may be expected to be function-based, not session-based.
Function-based model saving is discussed in tensorflow#101
@karllessard

@Shajan, can you please take a look at this new implementation? It turns out that the current design worked well for loading function definitions but not for exporting them, so I've shuffled the classes around a bit.

Also, as we discussed, even if TF Java continues to support only session-centric models, we expose each signature as if it were a real tf.function.

Right now the FunctionGraph (the new, provisional name for TfFunction) carries a session for executing a graph directly, given a list of input tensors, but does not manage the session's lifetime (i.e. the function becomes invalid as soon as the session is released). I guess eventually each function will have its own graph and session, and should therefore become AutoCloseable as well to prevent leaks. But for now, a minimalist implementation seems to be enough to cover both of our cases, i.e. inference with function graphs and saved model export.
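To make the lifetime rule concrete, here is a minimal sketch of the borrowing pattern described above, using hypothetical stub classes (StubSession, and a toy FunctionGraph) rather than the real tensorflow-java types: the function executes through a session it does not own, so it fails once that session is released.

```java
// Hypothetical stand-ins for illustration only; not the real tensorflow-java classes.
class StubSession implements AutoCloseable {
    private boolean closed = false;

    String run(String input) {
        if (closed) {
            throw new IllegalStateException("session has been released");
        }
        return "result(" + input + ")";
    }

    @Override
    public void close() {
        closed = true;  // the real API would release the native handle here
    }
}

class FunctionGraph {
    private final StubSession session;  // borrowed, never closed by the function

    FunctionGraph(StubSession session) {
        this.session = session;
    }

    String call(String input) {
        // Executes through the borrowed session: valid only while the session lives.
        return session.run(input);
    }
}
```

Calling the function after closing the session throws, mirroring the "function is invalid as soon as the session is released" behavior.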

Please let me know what you think of it and see if we should still go in that direction or not, thank you!

CC @aday00, @yzhuang

@karllessard

I woke up this morning and thought about this again: I think the resource management concerns shared by @yzhuang should be addressed properly. I'll push a new version of this code once I'm done with my coffee...

@karllessard

New PR draft available at #103

karllessard commented Aug 25, 2020

Hi @Shajan,

I still think we should have a unit test loading a model saved from Python using tf.function, like you previously suggested. Even though we can now save models in Java, we won't otherwise be able to test inference on graphs created from real functions yet. Let me know if you intend to do it or if I should, thanks!

saudet commented Aug 25, 2020

I woke up this morning and rethought about this, I think resource management concerns shared by @yzhuang should be addressed properly. I'll push a new version of this code once I'm done with my coffee...

Hum, time to start doing some reference counting? :)

@karllessard

@saudet, I certainly thought of it! I think we should tackle reference counting as a whole rather than piecemeal, though (i.e. use it to manage all native resources, not only a few of them). It is probably better to wait until after the first alpha release and treat this as a future improvement.

My latest draft handles function resource management properly, to the extent of what we can do with try-with-resources blocks.

Shajan commented Aug 27, 2020

Hi @Shajan,

I still think that we should have a unit test loading a model saved from Python using tf.function, like you were previously suggesting. Because even if we can now save models in Java, we won't be able to test the inference on graphs created out of real functions yet. Let me know if you intend to do it or if I should, thanks!

Agree.

For the unit test, here is my suggestion:
Check in, as resources, saved models created in Python using TensorFlow 1.x and 2.x, then have a Java unit test load these models. The Python code will be checked in with instructions, but the test run will not create the saved models.

What do you think?

@karllessard

@Shajan: sounds good! Do you think you can take care of this part? I'll merge my changes to the common branch, probably today.

Shajan commented Aug 27, 2020

@karllessard yes, I will add the test.

@karllessard

OK, thanks @Shajan. I've merged my changes; please let me know when you are done with your unit tests so we can make an official PR for the master branch.

yzhuang commented Aug 28, 2020

@saudet, I certainly thought of it! I think we should attack reference counting as a whole rather that just by pieces though (i.e. using it for managing all native resources, not only a few of them). It is probably better to wait after the first alpha release and see this as a future improvement.

My latest draft handles function resource management properly, at the extent of what we are able to do with try-with-resource blocks.

Hi @karllessard, try-with-resources is potentially a worse user experience. Session already does reference counting to track active runs (so that the session is not closed while a run is in flight); see: https://github.com/tensorflow/java/blob/master/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java#L361

Can't we simply add a numActiveOwners counter beside numActiveRuns (and have ConcreteFunction increment it and Session.close() decrement it)? We would get ref counting almost for free this way, because the concurrency handling logic for ref counting is already implemented to handle numActiveRuns.
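A rough sketch of that counting scheme, with hypothetical names (RefCountedSession, numActiveOwners) and a stubbed-out native release rather than the real Session implementation, could look like this:

```java
import java.util.concurrent.atomic.AtomicInteger;

// Illustration only: a session-like resource whose native handle is released
// when the last owner closes it. All names here are hypothetical.
class RefCountedSession implements AutoCloseable {
    private final AtomicInteger numActiveOwners = new AtomicInteger(1); // the creator holds one reference
    private boolean released = false;

    /** Called e.g. when a ConcreteFunction starts sharing this session. */
    void incrementOwners() {
        numActiveOwners.incrementAndGet();
    }

    @Override
    public synchronized void close() {
        // Only the last close actually releases the native resource.
        if (numActiveOwners.decrementAndGet() == 0 && !released) {
            released = true;  // the real code would delete the native handle here
        }
    }

    synchronized boolean isReleased() {
        return released;
    }
}
```

With this scheme, a user closing the session while a function still holds a reference would not invalidate the function; the underlying resource goes away only once every owner has closed it.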

@karllessard

Hi @yzhuang, thanks for your valuable feedback,

The proposal follows the current memory management paradigm, where only resources explicitly allocated by the user must be protected by try-with-resources. For example, if you are loading a saved model bundle, you don't need to protect the retrieved ConcreteFunction instances, the same way you don't need to protect the Session, since they are owned by the bundle itself.

// Current session-based API
try (SavedModelBundle s = SavedModelBundle.load(...);
    Tensor<TFloat32> in = TFloat32.tensorOf(...);
    Tensor<TFloat32> out = s.session().runner().feed("input1", in).fetch("output1").run().get(0).expect(TFloat32.DTYPE)) {
    System.out.println("Result is " + out.data().getFloat());
}

// New functions
try (SavedModelBundle s = SavedModelBundle.load(...);
    Tensor<TFloat32> in = TFloat32.tensorOf(...);
    Tensor<TFloat32> out = s.function("func1").call(in).expect(TFloat32.DTYPE)) {
    System.out.println("Result is " + out.data().getFloat());
}

On the other hand, if the user explicitly creates the function, they should then release it, the same way we do with graphs and sessions.

// Current session-based API
try (Graph g = new Graph()) {
    // ... fill graph
    try (Session s = new Session(g);
        Tensor<TFloat32> in = TFloat32.tensorOf(...);
        Tensor<TFloat32> out = s.runner().feed("input1", in).fetch("output1").run().get(0).expect(TFloat32.DTYPE)) {
        System.out.println("Result is " + out.data().getFloat());
    }
}

// New functions
try (ConcreteFunction f = ConcreteFunction.create(...)) {
    try (Tensor<TFloat32> in = TFloat32.tensorOf(...);
        Tensor<TFloat32> out = f.call(in).expect(TFloat32.DTYPE)) {
        System.out.println("Result is " + out.data().getFloat());
    }
}

The only exceptional case is when a function is created from an existing session, in which case protecting it is not required and has no effect, e.g.

try (Graph g = new Graph()) {
    // ... fill the graph
    try (Session s = new Session(g)) {
        try (ConcreteFunction f = ConcreteFunction.create(Signature.builder()..., s)) {
            // this block was not required but won't break anything either
        }
        ConcreteFunction f = ConcreteFunction.create(Signature.builder()..., s);  // this is legit in this case
    }
}

We already plan to improve the memory management user experience as a whole in the future by using reference counting on the native resources themselves, but this is not in the scope of this feature. I think following the current rules is acceptable for now.

Did you identify a case where the experience would be worsened by the introduction or use of functions? And can you provide an example where keeping track of the number of session owners would improve it, given that the Graph, Session and SavedModelBundle instances still need to be protected by try-with-resources?

yzhuang commented Sep 1, 2020

@karllessard Hi Karl, the above usage patterns look good to me, thanks for the clarifications!
(Sorry for the late reply -- I've been traveling and my Internet access has been spotty.)

aday00 commented Sep 9, 2020

Just to clarify, are these try-with-resources semantics required in Scala? E.g. #100 (comment)

It seems the Scala code compiles without all the try blocks, but is not using the try blocks a resource-usage mistake?

Shajan commented Sep 9, 2020

I'm unsure if I need additional permissions: pushing the unit test to origin failed.

java [shared-saved-model] $ git push origin
Username for 'https://github.com': [email protected]
Password for 'https://[email protected]@github.com':
...
remote: Resolving deltas: 100% (6/6), completed with 6 local objects.
remote: error: GH006: Protected branch update failed for refs/heads/shared-saved-model.
remote: error: You're not authorized to push to this branch. Visit https://docs.github.com/articles/about-protected-branches/ for more information.
To https://github.com/tensorflow/java
! [remote rejected] shared-saved-model -> shared-saved-model (protected branch hook declined)
error: failed to push some refs to 'https://github.com/tensorflow/java'

@karllessard

@Shajan , can you please try again?

Shajan commented Sep 10, 2020

Thanks, merged.

@karllessard

Great! I think I can now rebase that branch and make a PR out of it

aday00 commented Oct 15, 2020

Just to clarify, are these try-with-resource semantics required in Scala, e.g. #100 (comment)
It seems the Scala code compiles without all the try-blocks, but is not using the try-blocks a resource-usage-related mistake?

Yes, required, and a mistake otherwise. Tensors and other AutoCloseable native resources must be close()d to prevent resource leaks, but Scala 2.11.8 does not support try-with-resources semantics.
Scala 2.13 introduced Using for this purpose: https://stackoverflow.com/questions/39866000/java-try-with-resource-not-working-with-scala

So, for Scala 2.11.8, I wrote a manager that maintains a list of AutoCloseable objects and closes them properly. This may be worth doing properly if there is ever a TensorFlow Scala.
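The kind of manager described above can be sketched as follows (shown in Java for illustration since AutoCloseable is the same JVM interface; ResourceScope is a hypothetical name, and it assumes resources should close in reverse registration order, like nested try-with-resources blocks):

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Hypothetical sketch of a manager for AutoCloseable resources. It closes
// registered resources in reverse registration order, mimicking the
// semantics of nested try-with-resources blocks.
class ResourceScope implements AutoCloseable {
    private final Deque<AutoCloseable> resources = new ArrayDeque<>();

    /** Registers a resource and returns it, so calls can be chained inline. */
    <T extends AutoCloseable> T register(T resource) {
        resources.push(resource);
        return resource;
    }

    @Override
    public void close() {
        while (!resources.isEmpty()) {
            try {
                resources.pop().close();  // last registered, first closed
            } catch (Exception e) {
                // Keep closing the remaining resources even if one close fails.
            }
        }
    }
}
```

A Scala 2.11 version is essentially the same code, since Tensor, Graph, and Session all implement the same AutoCloseable interface on the JVM.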

Also, I've tested this new functional API in Scala, and saving works for me -- thanks for your hard work @karllessard @Shajan! With the older session-based model-saving code, I had some issues loading a model that used padding operations; I haven't tested that yet with this new API.
