Skip to content

Update requirements-jax-cuda.txt #2252

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
May 16, 2025

Conversation

pctablet505
Copy link
Collaborator

Upgrading jax version. For CI tests for keras-hub on jax GPU

Upgrading jax version. For CI tests for keras-hub on jax GPU
@pctablet505 pctablet505 added the kokoro:force-run Runs Tests on GPU label May 16, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label May 16, 2025
For CI tests for JAX GPU
@pctablet505 pctablet505 added the kokoro:force-run Runs Tests on GPU label May 16, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label May 16, 2025
@pctablet505 pctablet505 added the kokoro:force-run Runs Tests on GPU label May 16, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label May 16, 2025
@pctablet505 pctablet505 added the kokoro:force-run Runs Tests on GPU label May 16, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label May 16, 2025
@pctablet505 pctablet505 added the kokoro:force-run Runs Tests on GPU label May 16, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label May 16, 2025
@pctablet505 pctablet505 added the kokoro:force-run Runs Tests on GPU label May 16, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label May 16, 2025
@pctablet505 pctablet505 added the kokoro:force-run Runs Tests on GPU label May 16, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label May 16, 2025
@pctablet505 pctablet505 added the kokoro:force-run Runs Tests on GPU label May 16, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label May 16, 2025
@pctablet505 pctablet505 added the kokoro:force-run Runs Tests on GPU label May 16, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label May 16, 2025
@pctablet505 pctablet505 added the kokoro:force-run Runs Tests on GPU label May 16, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label May 16, 2025
@pctablet505 pctablet505 requested a review from abheesht17 May 16, 2025 09:11
@pctablet505
Copy link
Collaborator Author

@divyashreepathihalli the segmentation fault has been fixed. Now there are few unrelated bugs that need to be fixed independently.

Copy link
Collaborator

@abheesht17 abheesht17 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -16,15 +16,16 @@ fi
set -x
cd "${KOKORO_ROOT}/"

PYTHON_BINARY="/usr/bin/python3.9"
PYTHON_BINARY="/usr/bin/python3.10"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did we do this because JAX 0.6 is not there with Python 3.9? And does this mean we won't support 3.9 going forward?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may continue using it, but for testing we can use 3.10

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah the new jax version will need python to be upgraded

Comment on lines -27 to +28
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:"
# setting the LD_LIBRARY_PATH manually is causing segmentation fault
# export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the sole cause of the segmentation error? If yes, should we go without updating the JAX version?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updating jax causes installation of cuda drivers, and this was causing old cuda to be used.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good to update the JAX version!

@divyashreepathihalli
Copy link
Collaborator

divyashreepathihalli commented May 16, 2025

Thanks @pctablet505! Qwen and mixtral errors that were merged when the JAX tests were not running. I will disable the tests for these models and ping @kanpuriyanawab to fix these.

And looks like Gemma is not calling flash attention, I wonder if this is because of the changes merged in Keras keras-team/keras#21254 - not sure but please take a look.

Copy link
Collaborator

@divyashreepathihalli divyashreepathihalli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@@ -16,15 +16,16 @@ fi
set -x
cd "${KOKORO_ROOT}/"

PYTHON_BINARY="/usr/bin/python3.9"
PYTHON_BINARY="/usr/bin/python3.10"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah the new jax version will need python to be upgraded

@divyashreepathihalli
Copy link
Collaborator

Merging this so we can run JAX GPU tests for KerasHub

@divyashreepathihalli divyashreepathihalli merged commit 4bed9ae into keras-team:master May 16, 2025
9 of 10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants