
update JAX GPU version #21293


Open · wants to merge 6 commits into master
Conversation

@sachinprasadhs (Collaborator) commented May 16, 2025

Update JAX GPU version and fix failing GPU tests

@sachinprasadhs sachinprasadhs marked this pull request as draft May 16, 2025 18:00
@codecov-commenter commented May 16, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.57%. Comparing base (dfaca0e) to head (644641d).

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #21293   +/-   ##
=======================================
  Coverage   82.57%   82.57%           
=======================================
  Files         564      564           
  Lines       54677    54677           
  Branches     8500     8500           
=======================================
  Hits        45152    45152           
  Misses       7435     7435           
  Partials     2090     2090           
Flag               Coverage   Δ
keras              82.39%     <ø> (ø)
keras-jax          63.61%     <ø> (ø)
keras-numpy        58.74%     <ø> (ø)
keras-openvino     33.04%     <ø> (ø)
keras-tensorflow   64.02%     <ø> (ø)
keras-torch        63.68%     <ø> (ø)

Flags with carried forward coverage won't be shown.


@sachinprasadhs (Collaborator, Author) commented

@hertschuh, I replicated similar changes in keras-team/keras-hub#2252, and JAX GPU is working.
Additionally, when I try to update the TensorFlow version as well, I get a MemoryError; I'm not sure whether it is a TensorFlow package issue or a build-script issue.

Locally, when I build the package with TensorFlow 2.19.0 under Python 3.11, I am able to generate the .whl build artifact.
Here is the build error for my TensorFlow 2.19 commit:
https://github.com/keras-team/keras/actions/runs/15076576568/job/42385286142

@pctablet505, do you have any suggestions here?

@hertschuh (Collaborator) commented

> @hertschuh, I replicated similar changes in keras-team/keras-hub#2252, and JAX GPU is working. Additionally, when I try to update the TensorFlow version as well, I get a MemoryError; I'm not sure whether it is a TensorFlow package issue or a build-script issue.
>
> Locally, when I build the package with TensorFlow 2.19.0 under Python 3.11, I am able to generate the .whl build artifact. Here is the build error for my TensorFlow 2.19 commit: https://github.com/keras-team/keras/actions/runs/15076576568/job/42385286142
>
> @pctablet505, do you have any suggestions here?

@sachinprasadhs

I tried this in #21026 and it doesn't work.

I spent quite a bit of time and hit a roadblock (there might be others):

- JAX >= 0.5.1 requires TensorFlow 2.19 (at least for the saved_model export / reload tests).
- TensorFlow 2.19 requires protobuf >= 4.21.6 (with lower versions, it crashes).
- tf2onnx requires protobuf ~= 3.20; context on this in onnx/tensorflow-onnx#2261: they want to support old TF versions. tf2onnx may work OK with a more recent version, but they claim some other dependencies need 3.20.

The trouble is that the requirements.txt format doesn't allow overrides to resolve conflicts, or targeted --no-deps options.

Possible solutions:

1. Force tf2onnx to run with protobuf 4.21.6. That means we would have to remove tf2onnx from requirements.txt and do some extra "manual" pip installs in the CI after the requirements.txt install. We would also have to document that for ONNX users.
2. Run export tests (ONNX, but potentially also jax2tf) separately with a custom requirements-??.txt. We might even need multiple of those. Export tests have generally been difficult because that's where JAX meets TF, Torch meets TF, etc., and the versions have to be compatible. Then there is the question of whether we also run those on CUDA, requiring even more separate requirements-??-cuda.txt files.
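Option 1 could be sketched as a CI configuration fragment along these lines — this is only an illustration of the "manual pip installs after requirements.txt" idea, not the final CI change, and the exact pins are the ones mentioned above:

```shell
# Assumes the tf2onnx line has been removed from requirements.txt.
pip install -r requirements.txt

# Install tf2onnx without its dependency resolution, so its
# protobuf~=3.20 pin cannot downgrade the environment.
pip install --no-deps tf2onnx

# Re-assert the protobuf version TensorFlow 2.19 needs.
pip install "protobuf>=4.21.6"
```

ONNX users would then need the same `--no-deps` dance documented, since a plain `pip install tf2onnx` would drag protobuf back down to 3.20.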

@hertschuh (Collaborator) commented

@sachinprasadhs

Oh, so the MemoryError you're facing is because you're running TensorFlow 2.19 with the wrong version of protobuf (3.20 instead of >= 4.21.6).
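A quick way to sanity-check this in any environment is to compare the installed protobuf version against the minimum TensorFlow 2.19 needs. The helper below is just an illustration, not project code:

```python
from importlib import metadata


def meets_minimum(installed: str, minimum: str) -> bool:
    """Compare plain dotted version strings numerically, e.g. '4.21.6'."""
    parse = lambda v: tuple(int(part) for part in v.split("."))
    return parse(installed) >= parse(minimum)


# TensorFlow 2.19 needs protobuf >= 4.21.6; a 3.20.x environment fails this.
print(meets_minimum("3.20.3", "4.21.6"))  # False
print(meets_minimum("4.25.3", "4.21.6"))  # True

# Check what the current environment actually resolved:
try:
    print("installed protobuf:", metadata.version("protobuf"))
except metadata.PackageNotFoundError:
    print("protobuf is not installed in this environment")
```

Note that string comparison ("3.20" < "4.21" happens to work, but "3.9" > "3.20" does not) is exactly the kind of mistake this tuple comparison avoids.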

@sachinprasadhs (Collaborator, Author) commented

Well, that is a bit complicated then!
I'm not sure; with protobuf 3.20 in a Python 3.11 environment, there was no build issue.

I hadn't seen your PR; in that case, this PR is passing all the tests, and the failing JAX test is not related to this PR.

To solve the core-dump issue on JAX GPU and keep the JAX GPU tests running, maybe we can merge this first, and then you can think of a workaround for updating TensorFlow.

@sachinprasadhs sachinprasadhs marked this pull request as ready for review May 16, 2025 23:16
@sachinprasadhs sachinprasadhs changed the title test JAX version update JAX GPU version May 16, 2025
@@ -13,7 +13,7 @@ source venv/bin/activate
 python --version
 python3 --version
 
-export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:"
+#export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:"
Collaborator:

Oh, so this is what was causing the errors on JAX GPU?

Collaborator:

If so, remove this line or add a comment explaining.

Collaborator Author:

Yes, detailed discussion here: keras-team/keras-hub#2252

Collaborator Author:

Added a comment.

@james77777778 (Contributor) commented

Hey @sachinprasadhs @hertschuh,
I think tf2onnx is not actively maintained (the last release was about a year ago).

Maybe we can test the export functionality only on CPU and provide guidance for users who try to export a model using JAX on GPU (e.g. set JAX_PLATFORMS=cpu).
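The workaround above can be sketched as follows. JAX reads JAX_PLATFORMS when it initializes, so the variable must be set before `jax` is first imported; the model and export path here are hypothetical:

```python
import os

# Must happen before `import jax`: JAX picks its backend at import time,
# so setting this afterwards has no effect on an already-initialized process.
os.environ["JAX_PLATFORMS"] = "cpu"

# With the variable in place, the export path stays off the GPU:
# import jax
# import keras  # assuming KERAS_BACKEND=jax
# model.export("exported_model")  # hypothetical model; runs on CPU only

assert os.environ["JAX_PLATFORMS"] == "cpu"
```

Equivalently, users can run the export script as `JAX_PLATFORMS=cpu python export.py` without touching the code.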


6 participants