update JAX GPU version #21293
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

```
@@           Coverage Diff           @@
##           master   #21293   +/-   ##
=======================================
  Coverage   82.57%   82.57%
=======================================
  Files         564      564
  Lines       54677    54677
  Branches     8500     8500
=======================================
  Hits        45152    45152
  Misses       7435     7435
  Partials     2090     2090
```

Flags with carried forward coverage won't be shown.
@hertschuh, I replicated similar changes in #2252 and JAX GPU is working. In my local instance, after updating the TensorFlow version to 2.19.0 with Python 3.11, I was able to build the package and generate the .whl file. @pctablet505, do you have any suggestions here?
I tried this in #21026 and it doesn't work. I spent quite a bit of time and hit a roadblock (there might be others). Possible solution: force tf2onnx to run with protobuf 4.21.6. That means we have to remove tf2onnx from requirements.txt and do some extra "manual" pip installs in the CI after installing requirements.txt. We also have to document that for ONNX users.
Oh, so the MemoryError you're facing is because you're running TensorFlow 2.19 with the wrong version of protobuf (3.20 instead of >= 4.21.6).
Well, that is a bit complicated then! I didn't see your PR. In that case, this PR is passing all the tests, and the failing JAX test is not related to this PR. To fix the core dump in JAX GPU and keep the JAX GPU tests running, maybe we can merge this first, and then you can think of a workaround for updating TensorFlow.
```diff
@@ -13,7 +13,7 @@ source venv/bin/activate
 python --version
 python3 --version

-export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:"
+#export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:"
```
Oh, so this is what was causing the errors on JAX GPU?
If so, remove this line or add a comment explaining why.
Yes, there is a detailed discussion here: keras-team/keras-hub#2252
Added a comment.
Hey @sachinprasadhs @hertschuh, maybe we can test the export functionality only on CPU and provide guidance for users who try to export a model using JAX GPU (e.g. `JAX_PLATFORMS=cpu`).
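As a minimal sketch of that guidance (assuming we rely on the `JAX_PLATFORMS` environment variable, which JAX reads when it is first imported), the CPU-only fallback could look like:

```python
import os

# JAX picks up JAX_PLATFORMS at import time, so this must run before
# any `import jax` in the process -- e.g. at the top of the test entry
# point, or as `JAX_PLATFORMS=cpu pytest ...` from the shell.
os.environ["JAX_PLATFORMS"] = "cpu"

# Any subsequent `import jax` now uses the CPU backend, so export paths
# that core-dump on GPU can still be exercised.
assert os.environ["JAX_PLATFORMS"] == "cpu"
```

Setting the variable from the shell when invoking the test runner avoids touching the test code at all.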
Update JAX GPU version and fix failing GPU tests