Implement GTE model to support the non-flash-attn version #443


Merged
7 commits merged into huggingface:dev from kozistr:feature/gte-for-cpu on Dec 11, 2024

Conversation

@kozistr (Contributor) commented Nov 30, 2024

What does this PR do?

I implemented the GTE model to support the non-flash-attn version and verified that it works on a T4 GPU.

  • To work with gte-multilingual-base, the checkpoint needs some modifications (a sketch of these edits is shown below):
    • remove NewModelClassification from architectures in config.json.
    • remove the new. prefix from the keys of the weight file.
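A rough, untested sketch of these edits in Python; the local directory name and model.safetensors are assumptions, so adjust them to the actual checkpoint layout:

import json
from safetensors.torch import load_file, save_file

# 1) Drop NewModelClassification from `architectures` in config.json.
with open("gte-multilingual-base/config.json") as f:
    config = json.load(f)
config["architectures"] = [
    a for a in config.get("architectures", []) if a != "NewModelClassification"
]
with open("gte-multilingual-base/config.json", "w") as f:
    json.dump(config, f, indent=2)

# 2) Strip the `new.` prefix from every weight key (removeprefix needs Python 3.9+).
tensors = load_file("gte-multilingual-base/model.safetensors")
renamed = {k.removeprefix("new."): v for k, v in tensors.items()}
save_file(renamed, "gte-multilingual-base/model.safetensors")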
./target/release/text-embeddings-router --model-id ./gte-multilingual-base --port 12345 --dtype float32 --pooling cls
2024-11-30T09:07:38.940983Z  INFO text_embeddings_router: router/src/main.rs:175: Args { model_id: "./gte-************-*ase", revision: None, tokenization_workers: None, dtype: Some(Float32), pooling: Some(Cls), max_concurrent_requests: 512, max_batch_tokens: 16384, max_batch_requests: None, max_client_batch_size: 32, auto_truncate: false, default_prompt_name: None, default_prompt: None, hf_api_token: None, hostname: "e2fd3c35e003", port: 12345, uds_path: "/tmp/text-embeddings-inference-server", huggingface_hub_cache: None, payload_limit: 2000000, api_key: None, json_output: false, otlp_endpoint: None, otlp_service_name: "text-embeddings-inference.server", cors_allow_origin: None }
2024-11-30T09:07:39.613214Z  INFO text_embeddings_router: router/src/lib.rs:188: Maximum number of tokens per request: 8192
2024-11-30T09:07:39.613439Z  INFO text_embeddings_core::tokenization: core/src/tokenization.rs:28: Starting 4 tokenization workers
2024-11-30T09:07:41.166025Z  INFO text_embeddings_router: router/src/lib.rs:230: Starting model backend
2024-11-30T09:07:41.396429Z  INFO text_embeddings_backend_candle: backends/candle/src/lib.rs:352: Starting GTE model on Cuda(CudaDevice(DeviceId(1)))
2024-11-30T09:07:42.289292Z  INFO text_embeddings_router: router/src/lib.rs:248: Warming up model
2024-11-30T09:07:43.962553Z  WARN text_embeddings_router: router/src/lib.rs:310: Invalid hostname, defaulting to 0.0.0.0
2024-11-30T09:07:43.965124Z  INFO text_embeddings_router::http::server: router/src/http/server.rs:1812: Starting HTTP server: 0.0.0.0:12345
2024-11-30T09:07:43.965142Z  INFO text_embeddings_router::http::server: router/src/http/server.rs:1813: Ready

Close #375
Close #431
Close #439

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@OlivierDehaene OR @Narsil

@kozistr changed the base branch from main to dev on November 30, 2024 at 09:56
@superchar commented

Hello @kozistr, thank you for implementing GTE without flash attention—I’m looking forward to this being merged!
I tested your branch on a T4 GPU and observed differences in the output vectors between the flash attention and non-flash attention versions. Specifically, for the query {"input": "Hello!"}, the cosine distance between the two versions on the T4 GPU is 0.52.
Additionally, I ran the same query on an A100 GPU with flash attention and observed a cosine distance of 0.0000001 when compared to the T4 flash attention version, but still 0.52 compared to the T4 non-flash attention version.
Do you know if such a significant difference could be due to precision discrepancies? Could this impact the quality of GTE models without flash attention?
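For reference, a minimal, untested sketch of this comparison: embed the same input on two running TEI instances and compute the cosine distance. The ports (8080 for the flash-attn build, 8081 for the non-flash-attn build) and localhost are placeholders.

import numpy as np
import requests

def embed(port: int, text: str) -> np.ndarray:
    # TEI's /embed route takes {"inputs": ...} and returns a list of embeddings.
    resp = requests.post(f"http://localhost:{port}/embed", json={"inputs": text})
    resp.raise_for_status()
    return np.asarray(resp.json()[0])

a = embed(8080, "Hello!")  # flash-attn instance
b = embed(8081, "Hello!")  # non-flash-attn instance
cos_sim = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
print(f"cosine distance: {1.0 - cos_sim:.6f}")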

@kozistr (Contributor, Author) commented Dec 3, 2024

> (quoting @superchar's comment above)

@superchar Hello! Thanks for the report! Honestly, I didn't test it much this time, so there might be some issues; sorry about that. I guess there may be a discrepancy between the implementations, and the only differences are the flash attention and rotary positional encoding parts. I'll look into it later (it may take some time because I don't have a proper local environment to develop and test on). Thanks for letting me know!
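As a standalone illustration of how the rotary positional encoding part alone can cause such divergence (a NumPy sketch, not TEI code): the "half-split" and "interleaved" RoPE layouts applied to the same vector give different results, so a layout mismatch between two implementations sharing the same weights could produce exactly this kind of large cosine distance.

import numpy as np

def rope_half(x, cos, sin):
    # half-split layout: the first half of the head dim is rotated against the second half
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

def rope_interleaved(x, cos, sin):
    # interleaved layout: even/odd pairs of dimensions are rotated together
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x2 * cos + x1 * sin
    return out

head_dim, pos, theta = 8, 3, 10000.0
inv_freq = 1.0 / theta ** (np.arange(0, head_dim, 2) / head_dim)
cos, sin = np.cos(pos * inv_freq), np.sin(pos * inv_freq)
q = np.arange(head_dim, dtype=np.float64)
print(np.allclose(rope_half(q, cos, sin), rope_interleaved(q, cos, sin)))  # -> False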

@OlivierDehaene changed the base branch from dev to main on December 11, 2024 at 17:06
@OlivierDehaene changed the base branch from main to dev on December 11, 2024 at 17:07
@OlivierDehaene (Contributor) left a comment


Thank you!
Merging this into the dev branch to run the CI and debug on T4.

@OlivierDehaene merged commit 9ced12b into huggingface:dev on Dec 11, 2024
@kozistr deleted the feature/gte-for-cpu branch on December 12, 2024 at 11:11