diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs
index 944457e7..aec84375 100644
--- a/backends/candle/src/lib.rs
+++ b/backends/candle/src/lib.rs
@@ -11,7 +11,7 @@ use crate::compute_cap::{
     compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
 };
 use crate::models::{
-    BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, JinaBertModel,
+    BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel,
     JinaCodeBertModel, MistralConfig, Model, NomicBertModel, NomicConfig, Qwen2Config,
 };
 #[cfg(feature = "cuda")]
@@ -218,10 +218,10 @@ impl CandleBackend {
                 "Mistral is only supported on Cuda devices in fp16 with flash attention enabled"
                     .to_string(),
             )),
-            (Config::Gte(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
-                "GTE is only supported on Cuda devices in fp16 with flash attention enabled"
-                    .to_string(),
-            )),
+            (Config::Gte(config), Device::Cpu | Device::Metal(_)) => {
+                tracing::info!("Starting GTE model on {:?}", device);
+                Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
+            }
             (Config::Qwen2(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
                 "Qwen2 is only supported on Cuda devices in fp16 with flash attention enabled"
                     .to_string(),
@@ -349,10 +349,12 @@ impl CandleBackend {
                 if dtype != DType::F16
                     || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
                 {
-                    return Err(BackendError::Start("GTE is only supported on Cuda devices in fp16 with flash attention enabled".to_string()));
+                    tracing::info!("Starting GTE model on {:?}", device);
+                    Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
+                } else {
+                    tracing::info!("Starting FlashGTE model on {:?}", device);
+                    Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
                 }
-                tracing::info!("Starting FlashGTE model on {:?}", device);
-                Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
             }
             #[cfg(feature = "cuda")]
             (Config::Qwen2(config), Device::Cuda(_)) => {
diff --git a/backends/candle/src/models/flash_gte.rs b/backends/candle/src/models/flash_gte.rs
index 53e62f6d..b6632d8b 100644
--- a/backends/candle/src/models/flash_gte.rs
+++ b/backends/candle/src/models/flash_gte.rs
@@ -1,6 +1,8 @@
 use crate::flash_attn::flash_attn_varlen;
-use crate::layers::{HiddenAct, LayerNorm, Linear};
-use crate::models::{GTEConfig, Model, NTKScaling, PositionEmbeddingType, RopeScaling};
+use crate::layers::{LayerNorm, Linear};
+use crate::models::{
+    GTEClassificationHead, GTEConfig, Model, NTKScaling, PositionEmbeddingType, RopeScaling, GTEMLP,
+};
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{Embedding, Module, VarBuilder};
 use text_embeddings_backend_core::{Batch, ModelType, Pool};
@@ -93,60 +95,7 @@ impl GTEAttention {
     }
 }
 
-struct GTEMLP {
-    up_gate_proj: Linear,
-    down_proj: Linear,
-
-    act: HiddenAct,
-    intermediate_size: usize,
-
-    span: tracing::Span,
-}
-
-impl GTEMLP {
-    pub fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
-        let intermediate_size = config.intermediate_size;
-
-        let up_gate_proj_weight = vb
-            .pp("up_gate_proj")
-            .get((intermediate_size * 2, config.hidden_size), "weight")?;
-
-        let up_gate_proj = Linear::new(up_gate_proj_weight, None, None);
-
-        let down_proj_weight = vb
-            .pp("down_proj")
-            .get((config.hidden_size, intermediate_size), "weight")?;
-        let down_proj_bias = vb.pp("down_proj").get(config.hidden_size, "bias")?;
-        let down_proj = Linear::new(down_proj_weight, Some(down_proj_bias), None);
-
-        Ok(Self {
-            up_gate_proj,
-            down_proj,
-            intermediate_size,
-            act: config.hidden_act.clone(),
-            span: tracing::span!(tracing::Level::TRACE, "mlp"),
-        })
-    }
-
-    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-
-        let up_gate_states = self.up_gate_proj.forward(hidden_states)?;
-        let up_states = up_gate_states.narrow(1, 0, self.intermediate_size)?;
-        let gate_states =
-            up_gate_states.narrow(1, self.intermediate_size, self.intermediate_size)?;
-
-        let gate_states = match self.act {
-            HiddenAct::Gelu => gate_states.gelu(),
-            HiddenAct::Relu => gate_states.relu(),
-            HiddenAct::Swiglu => gate_states.silu(),
-        }?;
-        let r = self.down_proj.forward(&(gate_states * up_states)?);
-        r
-    }
-}
-
-struct GTELayer {
+pub struct GTELayer {
     attention: GTEAttention,
     mlp: GTEMLP,
     attention_layer_norm: LayerNorm,
@@ -198,58 +147,6 @@ impl GTELayer {
     }
 }
 
-pub struct GTEClassificationHead {
-    pooler: Option<Linear>,
-    classifier: Linear,
-    span: tracing::Span,
-}
-
-impl GTEClassificationHead {
-    #[allow(dead_code)]
-    pub(crate) fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
-        let n_classes = match &config.id2label {
-            None => candle::bail!("`id2label` must be set for classifier models"),
-            Some(id2label) => id2label.len(),
-        };
-
-        let pooler = if let Ok(pooler_weight) = vb
-            .pp("pooler.dense")
-            .get((config.hidden_size, config.hidden_size), "weight")
-        {
-            let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?;
-            Some(Linear::new(pooler_weight, Some(pooler_bias), None))
-        } else {
-            None
-        };
-
-        let classifier_weight = vb
-            .pp("classifier")
-            .get((n_classes, config.hidden_size), "weight")?;
-        let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?;
-        let classifier = Linear::new(classifier_weight, Some(classifier_bias), None);
-
-        Ok(Self {
-            classifier,
-            pooler,
-            span: tracing::span!(tracing::Level::TRACE, "classifier"),
-        })
-    }
-
-    pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-
-        let mut hidden_states = hidden_states.unsqueeze(1)?;
-        if let Some(pooler) = self.pooler.as_ref() {
-            hidden_states = pooler.forward(&hidden_states)?;
-            hidden_states = hidden_states.tanh()?;
-        }
-
-        let hidden_states = self.classifier.forward(&hidden_states)?;
-        let hidden_states = hidden_states.squeeze(1)?;
-        Ok(hidden_states)
-    }
-}
-
 pub struct FlashGTEModel {
     word_embeddings: Embedding,
     token_type_embeddings: Option<Embedding>,
diff --git a/backends/candle/src/models/gte.rs b/backends/candle/src/models/gte.rs
index bc4bfdce..10f29eb6 100644
--- a/backends/candle/src/models/gte.rs
+++ b/backends/candle/src/models/gte.rs
@@ -1,7 +1,10 @@
-use crate::layers::HiddenAct;
-use crate::models::PositionEmbeddingType;
+use crate::layers::{get_cublas_lt_wrapper, HiddenAct, LayerNorm, Linear};
+use crate::models::{apply_rotary, cos_sin, inv_freqs, Model, PositionEmbeddingType};
+use candle::{Device, IndexOp, Result, Tensor, D};
+use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
 use std::collections::HashMap;
+use text_embeddings_backend_core::{Batch, ModelType, Pool};
 
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct NTKScaling {
@@ -35,3 +38,534 @@ pub struct GTEConfig {
     pub logn_attention_clip1: bool,
     pub id2label: Option<HashMap<String, String>>,
 }
+
+struct GTEAttention {
+    qkv_linear: Linear,
+    o_proj: Linear,
+
+    num_attention_heads: usize,
+    attention_head_size: usize,
+
+    softmax_scale: f32,
+
+    span: tracing::Span,
+}
+
+impl GTEAttention {
+    pub fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
+        let num_attention_heads = config.num_attention_heads;
+        let attention_head_size = config.hidden_size / config.num_attention_heads;
+        let hidden_size = config.hidden_size;
+
+        let qkv_weight = vb
+            .pp("qkv_proj")
+            .get((hidden_size * 3, hidden_size), "weight")?;
+        let qkv_bias = vb.pp("qkv_proj").get(hidden_size * 3, "bias")?;
+
+        let qkv_linear = Linear::new(qkv_weight, Some(qkv_bias), None);
+
+        let o_proj_weight = vb.pp("o_proj").get((hidden_size, hidden_size), "weight")?;
+        let o_proj_bias = vb.pp("o_proj").get(hidden_size, "bias")?;
+
+        let o_proj = Linear::new(o_proj_weight, Some(o_proj_bias), None);
+
+        let softmax_scale = 1. / (attention_head_size as f64).sqrt() as f32;
+
+        Ok(Self {
+            qkv_linear,
+            o_proj,
+            num_attention_heads,
+            attention_head_size,
+            softmax_scale,
+            span: tracing::span!(tracing::Level::TRACE, "attention"),
+        })
+    }
+
+    pub fn forward(&self, hidden_states: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let device = hidden_states.device();
+
+        let qkv = self.qkv_linear.forward(hidden_states)?;
+
+        // Reshape to [tokens, heads, head_size]
+        let mut new_qkv_shape = qkv.dims().to_vec();
+        new_qkv_shape.pop();
+        new_qkv_shape.push(self.num_attention_heads * 3);
+        new_qkv_shape.push(self.attention_head_size);
+
+        let qkv = qkv.reshape(new_qkv_shape)?;
+
+        // Split qkv tensor
+        let q = qkv.narrow(1, 0, self.num_attention_heads)?;
+        let k = qkv.narrow(1, self.num_attention_heads, self.num_attention_heads)?;
+        let v = qkv.narrow(1, self.num_attention_heads * 2, self.num_attention_heads)?;
+
+        let q = apply_rotary(&q, cos, sin, self.attention_head_size)?;
+        let k = apply_rotary(&k, cos, sin, self.attention_head_size)?;
+
+        #[allow(unused_variables)]
+        let context_layer =
+            if let (Device::Cuda(_), Some(cublaslt)) = (device, get_cublas_lt_wrapper()) {
+                #[cfg(feature = "cuda")]
+                {
+                    // Batch matrix multiplication
+                    // Fuse softmax scale and attention_bias add
+                    let attention_scores = cublaslt.batch_matmul(
+                        &k,
+                        &q,
+                        None,
+                        Some(self.softmax_scale as f32),
+                        None,
+                        None,
+                        None,
+                    )?;
+                    let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+
+                    cublaslt.batch_matmul(
+                        &v.t()?.contiguous()?,
+                        &attention_probs,
+                        // We save one allocation
+                        Some(&q),
+                        None,
+                        None,
+                        None,
+                        None,
+                    )
+                }
+                #[cfg(not(feature = "cuda"))]
+                {
+                    candle::bail!("`cuda` feature is not enabled")
+                }
+            } else {
+                let attention_scores = q.matmul(&k.t()?)?;
+                let attention_scores = (attention_scores * self.softmax_scale as f64)?;
+                let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+                attention_probs.matmul(&v.contiguous()?)
+            }?;
+
+        let context_layer = context_layer.flatten_from(D::Minus2)?;
+
+        let hidden_states = self.o_proj.forward(&context_layer)?;
+
+        Ok(hidden_states)
+    }
+}
+
+#[allow(clippy::upper_case_acronyms)]
+pub struct GTEMLP {
+    up_gate_proj: Linear,
+    down_proj: Linear,
+
+    act: HiddenAct,
+    intermediate_size: usize,
+
+    span: tracing::Span,
+}
+
+impl GTEMLP {
+    pub fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
+        let intermediate_size = config.intermediate_size;
+
+        let up_gate_proj_weight = vb
+            .pp("up_gate_proj")
+            .get((intermediate_size * 2, config.hidden_size), "weight")?;
+
+        let up_gate_proj = Linear::new(up_gate_proj_weight, None, None);
+
+        let down_proj_weight = vb
+            .pp("down_proj")
+            .get((config.hidden_size, intermediate_size), "weight")?;
+        let down_proj_bias = vb.pp("down_proj").get(config.hidden_size, "bias")?;
+        let down_proj = Linear::new(down_proj_weight, Some(down_proj_bias), None);
+
+        Ok(Self {
+            up_gate_proj,
+            down_proj,
+            intermediate_size,
+            act: config.hidden_act.clone(),
+            span: tracing::span!(tracing::Level::TRACE, "mlp"),
+        })
+    }
+
+    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        let up_gate_states = self.up_gate_proj.forward(hidden_states)?;
+        let up_states = up_gate_states.narrow(1, 0, self.intermediate_size)?;
+        let gate_states =
+            up_gate_states.narrow(1, self.intermediate_size, self.intermediate_size)?;
+
+        let gate_states = match self.act {
+            HiddenAct::Gelu => gate_states.gelu(),
+            HiddenAct::Relu => gate_states.relu(),
+            HiddenAct::Swiglu => gate_states.silu(),
+        }?;
+
+        self.down_proj.forward(&(gate_states * up_states)?)
+    }
+}
+
+pub struct GTELayer {
+    attention: GTEAttention,
+    mlp: GTEMLP,
+    attention_layer_norm: LayerNorm,
+    mlp_layer_norm: LayerNorm,
+
+    span: tracing::Span,
+}
+
+impl GTELayer {
+    pub fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
+        let attention = GTEAttention::load(vb.pp("attention"), config)?;
+        let mlp = GTEMLP::load(vb.pp("mlp"), config)?;
+
+        let attention_layer_norm =
+            LayerNorm::load(vb.pp("attn_ln"), config.hidden_size, config.layer_norm_eps)?;
+        let mlp_layer_norm =
+            LayerNorm::load(vb.pp("mlp_ln"), config.hidden_size, config.layer_norm_eps)?;
+
+        Ok(Self {
+            attention,
+            mlp,
+            attention_layer_norm,
+            mlp_layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "layer"),
+        })
+    }
+
+    pub fn forward(&self, hidden_states: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let attn_output = self.attention.forward(hidden_states, cos, sin)?;
+        let normed_attn_res_output = self
+            .attention_layer_norm
+            .forward(&attn_output, Some(hidden_states))?;
+
+        let mlp_output = self.mlp.forward(&normed_attn_res_output)?;
+        let normed_mlp_res_output = self
+            .mlp_layer_norm
+            .forward(&mlp_output, Some(&normed_attn_res_output))?;
+        Ok(normed_mlp_res_output)
+    }
+}
+
+pub struct GTEClassificationHead {
+    pooler: Option<Linear>,
+    classifier: Linear,
+    span: tracing::Span,
+}
+
+impl GTEClassificationHead {
+    #[allow(dead_code)]
+    pub(crate) fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
+        let n_classes = match &config.id2label {
+            None => candle::bail!("`id2label` must be set for classifier models"),
+            Some(id2label) => id2label.len(),
+        };
+
+        let pooler = if let Ok(pooler_weight) = vb
+            .pp("pooler.dense")
+            .get((config.hidden_size, config.hidden_size), "weight")
+        {
+            let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?;
+            Some(Linear::new(pooler_weight, Some(pooler_bias), None))
+        } else {
+            None
+        };
+
+        let classifier_weight = vb
+            .pp("classifier")
+            .get((n_classes, config.hidden_size), "weight")?;
+        let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?;
+        let classifier = Linear::new(classifier_weight, Some(classifier_bias), None);
+
+        Ok(Self {
+            classifier,
+            pooler,
+            span: tracing::span!(tracing::Level::TRACE, "classifier"),
+        })
+    }
+
+    pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        let mut hidden_states = hidden_states.unsqueeze(1)?;
+        if let Some(pooler) = self.pooler.as_ref() {
+            hidden_states = pooler.forward(&hidden_states)?;
+            hidden_states = hidden_states.tanh()?;
+        }
+
+        let hidden_states = self.classifier.forward(&hidden_states)?;
+        let hidden_states = hidden_states.squeeze(1)?;
+        Ok(hidden_states)
+    }
+}
+
+pub struct GTEModel {
+    word_embeddings: Embedding,
+    token_type_embeddings: Option<Embedding>,
+    layers: Vec<GTELayer>,
+    embeddings_norm: LayerNorm,
+    cos_cache: Tensor,
+    sin_cache: Tensor,
+    classifier: Option<GTEClassificationHead>,
+    pool: Pool,
+    pub device: Device,
+
+    span: tracing::Span,
+}
+
+impl GTEModel {
+    pub fn load(vb: VarBuilder, config: &GTEConfig, model_type: ModelType) -> Result<Self> {
+        if config.logn_attention_clip1 {
+            candle::bail!("`logn_attention_clip1` is not supported");
+        }
+        if config.logn_attention_scale {
+            candle::bail!("`logn_attention_scale` is not supported");
+        }
+
+        if config.position_embedding_type != PositionEmbeddingType::Rope {
+            candle::bail!("Only `PositionEmbeddingType::Rope` is supported");
+        }
+
+        let (pool, classifier) = match model_type {
+            ModelType::Classifier => {
+                let pool = Pool::Cls;
+
+                let classifier = GTEClassificationHead::load(vb.clone(), config)?;
+                (pool, Some(classifier))
+            }
+            ModelType::Embedding(pool) => (pool, None),
+        };
+
+        let word_embeddings = Embedding::new(
+            vb.pp("embeddings.word_embeddings")
+                .get((config.vocab_size, config.hidden_size), "weight")?,
+            config.hidden_size,
+        );
+
+        let token_type_embeddings = if config.type_vocab_size > 0 {
+            Some(Embedding::new(
+                vb.pp("embeddings.token_type_embeddings")
+                    .get((config.type_vocab_size, config.hidden_size), "weight")?,
+                config.hidden_size,
+            ))
+        } else {
+            None
+        };
+
+        let layers = (0..config.num_hidden_layers)
+            .map(|index| GTELayer::load(vb.pp(format!("encoder.layer.{index}")), config))
+            .collect::<Result<Vec<GTELayer>>>()?;
+
+        let embeddings_norm = LayerNorm::load(
+            vb.pp("embeddings.LayerNorm"),
+            config.hidden_size,
+            config.layer_norm_eps,
+        )?;
+
+        let inv_freqs = if let Some(RopeScaling::Ntk(NTKScaling { factor })) = config.rope_scaling {
+            let inv_freqs = inv_freqs(
+                layers[0].attention.attention_head_size,
+                config.rope_theta * factor,
+                vb.device(),
+            )?;
+            let s = factor.powf(2.0 / layers[0].attention.attention_head_size as f32) as f64;
+            inv_freqs / s
+        } else {
+            inv_freqs(
+                layers[0].attention.attention_head_size,
+                config.rope_theta,
+                vb.device(),
+            )
+        }?;
+
+        let (cos_cache, sin_cache) =
+            cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype())?;
+
+        Ok(Self {
+            word_embeddings,
+            token_type_embeddings,
+            layers,
+            embeddings_norm,
+            cos_cache,
+            sin_cache,
+            classifier,
+            pool,
+            device: vb.device().clone(),
+            span: tracing::span!(tracing::Level::TRACE, "model"),
+        })
+    }
+
+    pub fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
+        let _enter = self.span.enter();
+
+        let batch_size = batch.cumulative_seq_lengths.len() - 1;
+        let shape = batch.input_ids.len();
+
+        // Create Cuda tensors
+        let input_ids = Tensor::from_vec(batch.input_ids, shape, &self.device)?;
+        let token_type_ids = Tensor::from_vec(batch.token_type_ids, shape, &self.device)?;
+        let position_ids = Tensor::from_vec(batch.position_ids, shape, &self.device)?;
+        let cu_seqlens = Tensor::from_vec(
+            batch.cumulative_seq_lengths.clone(),
+            batch_size + 1,
+            &self.device,
+        )?;
+
+        let word_embeddings = self.word_embeddings.forward(&input_ids)?;
+        let token_type_embeddings = self
+            .token_type_embeddings
+            .as_ref()
+            .map(|emb| emb.forward(&token_type_ids))
+            .transpose()?;
+
+        let mut hidden_states = self
+            .embeddings_norm
+            .forward(&word_embeddings, token_type_embeddings.as_ref())?;
+
+        let cos = self.cos_cache.index_select(&position_ids, 0)?;
+        let sin = self.sin_cache.index_select(&position_ids, 0)?;
+
+        let cos = cos.unsqueeze(1)?;
+        let sin = sin.unsqueeze(1)?;
+
+        for layer in &self.layers {
+            let h = layer.forward(&hidden_states, &cos, &sin)?;
+            hidden_states = h;
+        }
+
+        let outputs = hidden_states;
+
+        let has_pooling_requests = !batch.pooled_indices.is_empty();
+        let has_raw_requests = !batch.raw_indices.is_empty();
+
+        let pooled_embeddings = if has_pooling_requests {
+            match self.pool {
+                // CLS and LastToken pooling
+                Pool::Cls | Pool::LastToken => {
+                    if batch_size > 1 {
+                        // Get token indices form cu_seqlens
+                        let mut indices = match self.pool {
+                            Pool::Cls => cu_seqlens.narrow(0, 0, batch_size)?,
+                            Pool::LastToken => {
+                                let end = cu_seqlens.narrow(0, 1, batch_size)?;
+                                (&end - &end.ones_like()?)?
+                            }
+                            _ => unreachable!(),
+                        };
+
+                        // If raw_indices is empty, we don't need to do anything with
+                        // the pooled_indices
+                        if has_raw_requests {
+                            // We need the pooled indices to select the correct cls indices
+                            let pooled_indices = Tensor::from_vec(
+                                batch.pooled_indices.clone(),
+                                batch.pooled_indices.len(),
+                                &self.device,
+                            )?;
+
+                            // Only select indices that requires pooling
+                            indices = indices.index_select(&pooled_indices, 0)?
+                        }
+
+                        // Select tokens
+                        Some(outputs.index_select(&indices, 0)?)
+                    } else {
+                        Some(
+                            match self.pool {
+                                Pool::Cls => outputs.i(0)?,
+                                Pool::LastToken => {
+                                    outputs.i(batch.cumulative_seq_lengths[1] as usize - 1)?
+                                }
+                                _ => unreachable!(),
+                            }
+                            .unsqueeze(0)?,
+                        )
+                    }
+                }
+                // Mean pooling
+                Pool::Mean => {
+                    if batch_size > 1 {
+                        // for each request that requires pooling
+                        let results: Result<Vec<Tensor>> = batch
+                            .pooled_indices
+                            .into_iter()
+                            .map(|i| {
+                                let i = i as usize;
+                                let start = batch.cumulative_seq_lengths[i];
+                                let len = batch.cumulative_seq_lengths[i + 1] - start;
+
+                                // Mean
+                                let embeddings = outputs.narrow(0, start as usize, len as usize)?;
+                                embeddings.sum_keepdim(0)? / (len as f64)
+                            })
+                            .collect();
+
+                        // Concatenate all results
+                        Some(Tensor::cat(&results?, 0)?)
+                    } else {
+                        Some((outputs.sum_keepdim(0)? / (batch.max_length as f64))?)
+                    }
+                }
+                Pool::Splade => {
+                    unreachable!();
+                }
+            }
+        } else {
+            None
+        };
+
+        let raw_embeddings = if has_raw_requests {
+            if batch_size > 1 && has_pooling_requests {
+                // Create indexing vector for the embeddings
+                let mut final_indices: Vec<u32> = Vec::with_capacity(shape);
+                for i in batch.raw_indices.into_iter() {
+                    let i = i as usize;
+                    // Get start/end token index of this specific member of the batch
+                    let start = batch.cumulative_seq_lengths[i];
+                    let end = batch.cumulative_seq_lengths[i + 1];
+
+                    for j in start..end {
+                        // Add indices for the tokens of this specific member of the batch
+                        final_indices.push(j);
+                    }
+                }
+
+                let final_indices_length = final_indices.len();
+                let final_indices =
+                    Tensor::from_vec(final_indices, final_indices_length, &self.device)?;
+
+                // Select the tokens with final indices
+                Some(outputs.index_select(&final_indices, 0)?)
+            } else {
+                Some(outputs)
+            }
+        } else {
+            None
+        };
+
+        Ok((pooled_embeddings, raw_embeddings))
+    }
+}
+
+impl Model for GTEModel {
+    fn is_padded(&self) -> bool {
+        false
+    }
+
+    fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
+        self.forward(batch)
+    }
+
+    fn predict(&self, batch: Batch) -> Result<Tensor> {
+        match &self.classifier {
+            None => candle::bail!("`predict` is not implemented for this model"),
+            Some(classifier) => {
+                let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?;
+                let pooled_embeddings =
+                    pooled_embeddings.expect("pooled_embeddings is empty. This is a bug.");
+                classifier.forward(&pooled_embeddings)
+            }
+        }
+    }
+}
diff --git a/backends/candle/src/models/mod.rs b/backends/candle/src/models/mod.rs
index b1e9f937..2c7b2322 100644
--- a/backends/candle/src/models/mod.rs
+++ b/backends/candle/src/models/mod.rs
@@ -40,11 +40,11 @@ pub use bert::{BertConfig, BertModel, PositionEmbeddingType};
 use candle::{Result, Tensor};
 pub use distilbert::{DistilBertConfig, DistilBertModel};
 #[allow(unused_imports)]
-pub use gte::{GTEConfig, NTKScaling, RopeScaling};
+pub use gte::{GTEClassificationHead, GTEConfig, GTEModel, NTKScaling, RopeScaling, GTEMLP};
 pub use jina::JinaBertModel;
 pub use jina_code::JinaCodeBertModel;
 pub use mistral::MistralConfig;
-pub use nomic::{NomicBertModel, NomicConfig};
+pub use nomic::{apply_rotary, cos_sin, inv_freqs, NomicBertModel, NomicConfig};
 pub use qwen2::Qwen2Config;
 use text_embeddings_backend_core::Batch;
 
diff --git a/backends/candle/src/models/nomic.rs b/backends/candle/src/models/nomic.rs
index cdaaea92..3fd3f645 100644
--- a/backends/candle/src/models/nomic.rs
+++ b/backends/candle/src/models/nomic.rs
@@ -176,15 +176,6 @@ impl NomicAttention {
         })
     }
 
-    fn apply_rotary(&self, x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
-        let dim = self.attention_head_size / 2;
-        let x1 = x.narrow(D::Minus1, 0, dim)?;
-        let x2 = x.narrow(D::Minus1, dim, dim)?;
-        let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
-        let rope = (x.broadcast_mul(cos)? + rotate_x.broadcast_mul(sin)?)?;
-        Ok(rope)
-    }
-
     pub fn forward(
         &self,
         hidden_states: &Tensor,
@@ -208,8 +199,8 @@ impl NomicAttention {
         let key_layer = &qkv[1].contiguous()?;
         let value_layer = &qkv[2];
 
-        let query_layer = self.apply_rotary(query_layer, cos, sin)?;
-        let key_layer = self.apply_rotary(key_layer, cos, sin)?;
+        let query_layer = apply_rotary(query_layer, cos, sin, self.attention_head_size)?;
+        let key_layer = apply_rotary(key_layer, cos, sin, self.attention_head_size)?;
 
         #[allow(unused_variables)]
         let context_layer = if let (Device::Cuda(_), Some(cublaslt)) =
@@ -699,10 +690,25 @@ pub fn cos_sin(length: usize, inv_freqs: &Tensor, dtype: DType) -> Result<(Tenso
     Ok((cos, sin))
 }
 
+pub fn apply_rotary(
+    x: &Tensor,
+    cos: &Tensor,
+    sin: &Tensor,
+    attention_head_size: usize,
+) -> Result<Tensor> {
+    let dim = attention_head_size / 2;
+    let x1 = x.narrow(D::Minus1, 0, dim)?;
+    let x2 = x.narrow(D::Minus1, dim, dim)?;
+    let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
+    let rope = (x.broadcast_mul(cos)? + rotate_x.broadcast_mul(sin)?)?;
+    Ok(rope)
+}
+
 impl Model for NomicBertModel {
     fn is_padded(&self) -> bool {
         false
     }
+
     fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
         self.forward(batch)
     }
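The patch moves the rotate-half RoPE helper out of `NomicAttention` into a free `apply_rotary(x, cos, sin, attention_head_size)` function so both `NomicBertModel` and the new CPU/Metal `GTEModel` can share it. The following is a minimal standalone sketch of that rotation, not part of the patch: it assumes the `candle-core` crate used directly (the backend re-exports it as `candle`), and the `rotate_half_rope` name, tensor sizes, and constant cos/sin values are illustrative placeholders rather than a real position-indexed cache.

```rust
use candle_core::{Device, Result, Tensor, D};

// Rotate-half RoPE on a [tokens, heads, head_size] tensor, mirroring the shared
// `apply_rotary` helper: split the head dimension, build [-x2, x1], then combine
// as x * cos + rotate_half(x) * sin.
fn rotate_half_rope(x: &Tensor, cos: &Tensor, sin: &Tensor, head_size: usize) -> Result<Tensor> {
    let dim = head_size / 2;
    let x1 = x.narrow(D::Minus1, 0, dim)?;
    let x2 = x.narrow(D::Minus1, dim, dim)?;
    let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
    x.broadcast_mul(cos)? + rotate_x.broadcast_mul(sin)?
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let (tokens, heads, head_size) = (4, 2, 8);
    let x = Tensor::randn(0f32, 1f32, (tokens, heads, head_size), &device)?;
    // Shaped [tokens, 1, head_size] so it broadcasts over the head axis,
    // like the model's `cos.unsqueeze(1)?` / `sin.unsqueeze(1)?`.
    let cos = Tensor::full(0.5f32, (tokens, 1, head_size), &device)?;
    let sin = Tensor::full(0.1f32, (tokens, 1, head_size), &device)?;
    let rotated = rotate_half_rope(&x, &cos, &sin, head_size)?;
    println!("rotated shape: {:?}", rotated.dims());
    Ok(())
}
```

Passing the head size as an argument, instead of reading it from `self` as the old `NomicAttention` method did, is what lets the GTE and Nomic attention layers call the same function with their own head dimensions.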