Fix Qwen3 Embedding Float16 DType #663
Conversation
Hey @tpendragon thanks for the PR, but we just ran it from …
@alvarobartt Ah, sorry, I'm running on a Mac with CPU only and using the Rust code as a crate (https://github.com/pulibrary/dpul-collections/blob/334456f531b41ce084dd8dd822d792199df298c8/native/embedanything/src/lib.rs#L44). Maybe I missed a step in setting that up? I'd expect this branch of code to get hit when it runs with the candle backend on CPU with dtype F16.
@alvarobartt Oh, and importantly it only breaks if two inputs with different token lengths are batched together, so padding happens. If the inputs are processed one at a time it doesn't trigger the bug.
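For readers trying to reproduce this, a minimal sketch of the batching condition described above (the helper name and the lengths are illustrative, not the actual TEI code):

```rust
// Illustrative only: the attention bias is built only when at least one
// sequence in the batch is shorter than the longest one, i.e. when padding
// is needed. Single inputs or equal-length batches never create the F32
// bias tensor, so they never hit the F16/F32 mismatch.
fn needs_masking(seq_lengths: &[usize]) -> bool {
    let max_length = seq_lengths.iter().copied().max().unwrap_or(0);
    seq_lengths.iter().any(|&len| len < max_length)
}

fn main() {
    assert!(!needs_masking(&[7]));    // single input: no padding
    assert!(!needs_masking(&[5, 5])); // equal lengths: no padding
    assert!(needs_masking(&[5, 9]));  // mixed lengths: padding, bug triggers before the fix
}
```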
Fair enough @tpendragon, thanks for the details on reproducing; I'll try on both CPU and MPS with the aforementioned settings and come back to you! I'm also adding a comment on the PR about a change that IMO is not required 🤗 Thanks again for the PR!
Thanks again for the PR! I haven't reproduced your issue yet, but the PR looks fair; I only left some nits / comments 👍🏻
backends/candle/src/models/qwen3.rs
@@ -514,7 +519,7 @@ impl Qwen3Model {
         let attention_bias = if masking {
             let attention_bias =
-                Tensor::from_vec(attention_bias, (batch_size, 1, 1, max_length), &self.device)?;
+                Tensor::from_vec(attention_bias, (batch_size, 1, 1, max_length), &self.device)?.to_dtype(self.dtype)?;
Also, the linting is failing here. You can install it with `pip install pre-commit --upgrade`, then run `pre-commit install` from the root directory of the repository so it runs on every commit, or just run it once to fix the issue with `pre-commit run --all-files` 🤗
Hey @tpendragon I already tested and it's indeed fixed with your PR, thanks! Could you please run the formatter so the following changes are applied?

diff --git a/backends/candle/src/models/qwen3.rs b/backends/candle/src/models/qwen3.rs
index 064ab6d..1330992 100644
--- a/backends/candle/src/models/qwen3.rs
+++ b/backends/candle/src/models/qwen3.rs
@@ -519,7 +519,8 @@ impl Qwen3Model {
         let attention_bias = if masking {
             let attention_bias =
-                Tensor::from_vec(attention_bias, (batch_size, 1, 1, max_length), &self.device)?.to_dtype(self.dtype)?;
+                Tensor::from_vec(attention_bias, (batch_size, 1, 1, max_length), &self.device)?
+                    .to_dtype(self.dtype)?;
             // Broadcast once instead of at every layer
             let attention_bias = attention_bias
                 .broadcast_as((batch_size, self.num_attention_heads, max_length, max_length))?
diff --git a/backends/src/lib.rs b/backends/src/lib.rs
index d333951..be40b09 100644
--- a/backends/src/lib.rs
+++ b/backends/src/lib.rs
@@ -150,11 +150,8 @@ impl Backend {
         }
         max_input_length = std::cmp::min(max_input_length, max_warmup_length);
-        let mut seq_lengths: Vec<usize> = generate_bucket_sizes(
-            seq_bucket_size,
-            max_input_length,
-            seq_len_exp_base,
-        );
+        let mut seq_lengths: Vec<usize> =
+            generate_bucket_sizes(seq_bucket_size, max_input_length, seq_len_exp_base);
         if let Some(&last) = seq_lengths.last() {
             if last < max_input_length {
                 seq_lengths.push(max_input_length);
@alvarobartt Ran it - thanks!
Hey @tpendragon |
@alvarobartt I have no idea what's going on, I can't get fmt to throw an error, even in a nix shell. |
What does this PR do?
Without this, batched requests were erroring on CPU because the candle backend was trying to compare the F32 attention mask with the F16 output tensors.
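To make the failure mode concrete, here is a hypothetical candle sketch of the mismatch and the cast that fixes it; the shapes, values, and variable names are made up for illustration and are not the TEI code itself:

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    let (batch_size, max_length) = (2usize, 4usize);

    // F32 additive attention bias: 0.0 for real tokens, f32::MIN for padding
    // (second sequence padded to max_length), shaped like in qwen3.rs.
    let bias_values = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, f32::MIN, f32::MIN];
    let attention_bias =
        Tensor::from_vec(bias_values, (batch_size, 1, 1, max_length), &device)?
            // The fix: cast the bias to the model dtype before combining it
            // with the F16 attention scores.
            .to_dtype(DType::F16)?;

    // F16 attention scores, as produced by a model loaded in float16.
    let scores = Tensor::zeros((batch_size, 1, max_length, max_length), DType::F16, &device)?;

    // Without the .to_dtype(...) above, this returns a dtype-mismatch error.
    let _masked = scores.broadcast_add(&attention_bias)?;
    Ok(())
}
```

With the `.to_dtype(self.dtype)` call added in this PR, the bias always matches the model dtype, so mixed-length batches no longer error when the model is loaded in float16.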
Before submitting
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.