From 08d1d1c8995e8ce7dd0bede7a9d52ce73f5a4c43 Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date: Tue, 24 Jun 2025 13:29:23 +0200
Subject: [PATCH 1/2] Add `dtype` in `Qwen3Model` and fix casting

---
 backends/candle/src/models/qwen3.rs | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/backends/candle/src/models/qwen3.rs b/backends/candle/src/models/qwen3.rs
index f3f541b0..a864faea 100644
--- a/backends/candle/src/models/qwen3.rs
+++ b/backends/candle/src/models/qwen3.rs
@@ -2,7 +2,7 @@ use crate::layers::{
     apply_rotary, get_cos_sin, get_cublas_lt_wrapper, get_inv_freqs, HiddenAct, Linear, RMSNorm,
 };
 use crate::models::Model;
-use candle::{Device, IndexOp, Result, Tensor, D};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};
@@ -382,10 +382,12 @@ pub struct Qwen3Model {
     rotary_cache: (Tensor, Tensor),
     rotary_dim: usize,
     pool: Pool,
-    pub device: Device,
     num_attention_heads: usize,
     pad_token_id: u32,
 
+    dtype: DType,
+    device: Device,
+
     span: tracing::Span,
 }
 
@@ -435,8 +437,9 @@ impl Qwen3Model {
             rotary_dim,
             pool,
             pad_token_id: config.eos_token_id as u32,
-            device: vb.device().clone(),
             num_attention_heads: config.num_attention_heads,
+            dtype: vb.dtype().clone(),
+            device: vb.device().clone(),
             span: tracing::span!(tracing::Level::TRACE, "model"),
         })
     }
@@ -444,8 +447,6 @@ impl Qwen3Model {
     fn get_causal_attention_bias(&self, attention_bias: Tensor) -> Result<Tensor> {
         let (bs, dim, seq_len, _) = attention_bias.dims4()?;
 
-        let device = attention_bias.device();
-
         let mask: Vec<u8> = (0..seq_len)
             .flat_map(|i| (0..seq_len).map(move |j| (j > i) as u8))
             .collect();
@@ -453,12 +454,13 @@ impl Qwen3Model {
         let causal_mask = Tensor::from_slice(&mask, (seq_len, seq_len), &Device::Cpu)?;
         let causal_mask = causal_mask.expand(&[bs, dim, seq_len, seq_len])?;
 
-        let negatives = Tensor::full(f32::MIN, attention_bias.shape(), &Device::Cpu)?;
-        let zeros = Tensor::zeros_like(&attention_bias)?.to_device(&Device::Cpu)?;
+        let negatives =
+            Tensor::full(f32::MIN, attention_bias.shape(), &Device::Cpu)?.to_dtype(self.dtype)?;
+        let zeros = Tensor::zeros_like(&attention_bias)?.to_dtype(self.dtype)?;
 
         let causal_mask = causal_mask
             .where_cond(&negatives, &zeros)?
-            .to_device(device)?;
+            .to_device(&self.device)?;
 
         attention_bias.broadcast_add(&causal_mask)
     }
@@ -494,7 +496,7 @@ impl Qwen3Model {
                 for _ in 0..padding {
                     input_ids.push(self.pad_token_id);
                     position_ids.push(0);
-                    attention_bias.push(f32::MIN);
+                    attention_bias.push(f32::NEG_INFINITY);
                 }
             }
 
@@ -539,7 +541,7 @@ impl Qwen3Model {
             // Create attention bias for causal masking even for single sequences
             let attention_bias = Tensor::zeros(
                 (1, self.num_attention_heads, seq_len, seq_len),
-                candle::DType::F32,
+                self.dtype,
                 &self.device,
             )?;
 

From 30f9d1f7f609f7ed5be01c8afc59227073612603 Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date: Tue, 24 Jun 2025 13:43:39 +0200
Subject: [PATCH 2/2] Remove `clone` for `dtype`

---
 backends/candle/src/models/qwen3.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backends/candle/src/models/qwen3.rs b/backends/candle/src/models/qwen3.rs
index a864faea..989968d1 100644
--- a/backends/candle/src/models/qwen3.rs
+++ b/backends/candle/src/models/qwen3.rs
@@ -438,7 +438,7 @@ impl Qwen3Model {
             pool,
             pad_token_id: config.eos_token_id as u32,
             num_attention_heads: config.num_attention_heads,
-            dtype: vb.dtype().clone(),
+            dtype: vb.dtype(),
             device: vb.device().clone(),
             span: tracing::span!(tracing::Level::TRACE, "model"),
         })
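Context on the pattern (not part of the patches above): every hunk converges on the same idea, namely that the additive causal-attention bias must be materialized in the model's dtype instead of hard-coded f32, so f16/bf16 checkpoints do not hit a dtype mismatch when the bias is added to activations. A minimal, self-contained sketch of that construction follows. It is an illustration, not the TEI code: the free function causal_attention_bias is hypothetical (TEI keeps this logic inside Qwen3Model::get_causal_attention_bias), and `candle` is the candle-core crate under the alias the `use candle::…` imports above already rely on.

use candle::{DType, Device, Result, Tensor};

/// Sketch of a dtype-aware additive causal mask: a large negative bias above
/// the diagonal (future positions), zero on and below it, created directly in
/// the model's dtype.
fn causal_attention_bias(seq_len: usize, dtype: DType, device: &Device) -> Result<Tensor> {
    // 1 marks masked future positions (j > i), 0 marks visible ones.
    let mask: Vec<u8> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| (j > i) as u8))
        .collect();
    let mask = Tensor::from_slice(&mask, (seq_len, seq_len), &Device::Cpu)?;

    // Cast both `where_cond` branches to the target dtype before selecting.
    // f32::MIN overflows to -inf under an f16/bf16 cast, which lines up with
    // the f32::NEG_INFINITY the patch now pushes for padded positions.
    let negatives = Tensor::full(f32::MIN, (seq_len, seq_len), &Device::Cpu)?.to_dtype(dtype)?;
    let zeros = Tensor::zeros((seq_len, seq_len), dtype, &Device::Cpu)?;

    // Assemble on CPU, then move the finished bias to the model device once.
    mask.where_cond(&negatives, &zeros)?.to_device(device)
}

The follow-up commit is a one-liner because candle's DType derives Copy, so vb.dtype() already returns an owned value and the .clone() was redundant; Device is Clone but not Copy, which is why vb.device().clone() stays.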