-
Notifications
You must be signed in to change notification settings - Fork 280
Update requirements-jax-cuda.txt #2252
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Upgrading jax version. For CI tests for keras-hub on jax GPU
For CI tests for JAX GPU
@divyashreepathihalli the segmentation fault has been fixed. Now there are few unrelated bugs that need to be fixed independently. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you fix the test failures here as well: https://btx.cloud.google.com/invocations/e69da1da-13de-4e1a-89cd-073188c13ebb/targets/keras_hub%2Fgithub%2Fubuntu%2Fgpu%2Fjax%2Fpresubmit/log?
@@ -16,15 +16,16 @@ fi | |||
set -x | |||
cd "${KOKORO_ROOT}/" | |||
|
|||
PYTHON_BINARY="/usr/bin/python3.9" | |||
PYTHON_BINARY="/usr/bin/python3.10" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did we do this because JAX 0.6 is not there with Python 3.9? And does this mean we won't support 3.9 going forward?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We may continue using it, but for testing we can use 3.10
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah the new jax version will need python to be upgraded
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:" | ||
# setting the LD_LIBRARY_PATH manually is causing segmentation fault | ||
# export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this the sole cause of the segmentation error? If yes, should we go without updating the JAX version?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updating jax causes installation of cuda drivers, and this was causing old cuda to be used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good to update the JAX version!
Thanks @pctablet505! Qwen and mixtral errors that were merged when the JAX tests were not running. I will disable the tests for these models and ping @kanpuriyanawab to fix these. And looks like Gemma is not calling flash attention, I wonder if this is because of the changes merged in Keras keras-team/keras#21254 - not sure but please take a look. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
@@ -16,15 +16,16 @@ fi | |||
set -x | |||
cd "${KOKORO_ROOT}/" | |||
|
|||
PYTHON_BINARY="/usr/bin/python3.9" | |||
PYTHON_BINARY="/usr/bin/python3.10" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah the new jax version will need python to be upgraded
Merging this so we can run JAX GPU tests for KerasHub |
4bed9ae
into
keras-team:master
Upgrading jax version. For CI tests for keras-hub on jax GPU