From 9a3676c73bdafd299ad2ee1f93ff91875569ccf3 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 30 Nov 2024 11:52:59 +0900 Subject: [PATCH 1/7] feature: GTEModel --- backends/candle/src/lib.rs | 10 +- backends/candle/src/models/flash_gte.rs | 115 +---- backends/candle/src/models/gte.rs | 567 +++++++++++++++++++++++- backends/candle/src/models/mod.rs | 4 +- backends/candle/src/models/nomic.rs | 27 +- 5 files changed, 593 insertions(+), 130 deletions(-) diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index 944457e7..c8f762dd 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -11,7 +11,7 @@ use crate::compute_cap::{ compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap, }; use crate::models::{ - BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, JinaBertModel, + BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel, JinaCodeBertModel, MistralConfig, Model, NomicBertModel, NomicConfig, Qwen2Config, }; #[cfg(feature = "cuda")] @@ -218,10 +218,10 @@ impl CandleBackend { "Mistral is only supported on Cuda devices in fp16 with flash attention enabled" .to_string(), )), - (Config::Gte(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start( - "GTE is only supported on Cuda devices in fp16 with flash attention enabled" - .to_string(), - )), + (Config::Gte(config), Device::Cpu | Device::Metal(_)) => { + tracing::info!("Starting GTE model on {:?}", device); + Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?)) + } (Config::Qwen2(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start( "Qwen2 is only supported on Cuda devices in fp16 with flash attention enabled" .to_string(), diff --git a/backends/candle/src/models/flash_gte.rs b/backends/candle/src/models/flash_gte.rs index 53e62f6d..13f97730 100644 --- a/backends/candle/src/models/flash_gte.rs +++ b/backends/candle/src/models/flash_gte.rs @@ -1,6 +1,8 @@ use crate::flash_attn::flash_attn_varlen; use crate::layers::{HiddenAct, LayerNorm, Linear}; -use crate::models::{GTEConfig, Model, NTKScaling, PositionEmbeddingType, RopeScaling}; +use crate::models::{ + GTEClassificationHead, GTEConfig, Model, NTKScaling, PositionEmbeddingType, RopeScaling, GTEMLP, +}; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{Embedding, Module, VarBuilder}; use text_embeddings_backend_core::{Batch, ModelType, Pool}; @@ -93,60 +95,7 @@ impl GTEAttention { } } -struct GTEMLP { - up_gate_proj: Linear, - down_proj: Linear, - - act: HiddenAct, - intermediate_size: usize, - - span: tracing::Span, -} - -impl GTEMLP { - pub fn load(vb: VarBuilder, config: >EConfig) -> Result { - let intermediate_size = config.intermediate_size; - - let up_gate_proj_weight = vb - .pp("up_gate_proj") - .get((intermediate_size * 2, config.hidden_size), "weight")?; - - let up_gate_proj = Linear::new(up_gate_proj_weight, None, None); - - let down_proj_weight = vb - .pp("down_proj") - .get((config.hidden_size, intermediate_size), "weight")?; - let down_proj_bias = vb.pp("down_proj").get(config.hidden_size, "bias")?; - let down_proj = Linear::new(down_proj_weight, Some(down_proj_bias), None); - - Ok(Self { - up_gate_proj, - down_proj, - intermediate_size, - act: config.hidden_act.clone(), - span: tracing::span!(tracing::Level::TRACE, "mlp"), - }) - } - - pub fn forward(&self, hidden_states: &Tensor) -> Result { - let _enter = self.span.enter(); - - let up_gate_states = self.up_gate_proj.forward(hidden_states)?; - let up_states = 
up_gate_states.narrow(1, 0, self.intermediate_size)?; - let gate_states = - up_gate_states.narrow(1, self.intermediate_size, self.intermediate_size)?; - - let gate_states = match self.act { - HiddenAct::Gelu => gate_states.gelu(), - HiddenAct::Relu => gate_states.relu(), - HiddenAct::Swiglu => gate_states.silu(), - }?; - let r = self.down_proj.forward(&(gate_states * up_states)?); - r - } -} - -struct GTELayer { +pub struct GTELayer { attention: GTEAttention, mlp: GTEMLP, attention_layer_norm: LayerNorm, @@ -183,9 +132,7 @@ impl GTELayer { max_s: usize, ) -> Result { let _enter = self.span.enter(); - let attn_output = self - .attention - .forward(&hidden_states, cu_seqlens, cos, sin, max_s)?; + let attn_output = self.attention.forward(&hidden_states, cos, sin)?; let normed_attn_res_output = self .attention_layer_norm .forward(&attn_output, Some(hidden_states))?; @@ -198,58 +145,6 @@ impl GTELayer { } } -pub struct GTEClassificationHead { - pooler: Option, - classifier: Linear, - span: tracing::Span, -} - -impl GTEClassificationHead { - #[allow(dead_code)] - pub(crate) fn load(vb: VarBuilder, config: >EConfig) -> Result { - let n_classes = match &config.id2label { - None => candle::bail!("`id2label` must be set for classifier models"), - Some(id2label) => id2label.len(), - }; - - let pooler = if let Ok(pooler_weight) = vb - .pp("pooler.dense") - .get((config.hidden_size, config.hidden_size), "weight") - { - let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?; - Some(Linear::new(pooler_weight, Some(pooler_bias), None)) - } else { - None - }; - - let classifier_weight = vb - .pp("classifier") - .get((n_classes, config.hidden_size), "weight")?; - let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?; - let classifier = Linear::new(classifier_weight, Some(classifier_bias), None); - - Ok(Self { - classifier, - pooler, - span: tracing::span!(tracing::Level::TRACE, "classifier"), - }) - } - - pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result { - let _enter = self.span.enter(); - - let mut hidden_states = hidden_states.unsqueeze(1)?; - if let Some(pooler) = self.pooler.as_ref() { - hidden_states = pooler.forward(&hidden_states)?; - hidden_states = hidden_states.tanh()?; - } - - let hidden_states = self.classifier.forward(&hidden_states)?; - let hidden_states = hidden_states.squeeze(1)?; - Ok(hidden_states) - } -} - pub struct FlashGTEModel { word_embeddings: Embedding, token_type_embeddings: Option, diff --git a/backends/candle/src/models/gte.rs b/backends/candle/src/models/gte.rs index bc4bfdce..ba360eaf 100644 --- a/backends/candle/src/models/gte.rs +++ b/backends/candle/src/models/gte.rs @@ -1,7 +1,10 @@ -use crate::layers::HiddenAct; -use crate::models::PositionEmbeddingType; +use crate::layers::{get_cublas_lt_wrapper, HiddenAct, LayerNorm, Linear}; +use crate::models::{apply_rotary, cos_sin, inv_freqs, Model, PositionEmbeddingType}; +use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn::{Embedding, Module, VarBuilder}; use serde::Deserialize; use std::collections::HashMap; +use text_embeddings_backend_core::{Batch, ModelType, Pool}; #[derive(Debug, Clone, PartialEq, Deserialize)] pub struct NTKScaling { @@ -35,3 +38,563 @@ pub struct GTEConfig { pub logn_attention_clip1: bool, pub id2label: Option>, } + +struct GTEAttention { + qkv_linear: Linear, + o_proj: Linear, + + num_attention_heads: usize, + attention_head_size: usize, + + softmax_scale: f32, + + span: tracing::Span, +} + +impl GTEAttention { + pub fn load(vb: VarBuilder, 
config: >EConfig) -> Result { + let num_attention_heads = config.num_attention_heads; + let attention_head_size = config.hidden_size / config.num_attention_heads; + let hidden_size = config.hidden_size; + + let qkv_weight = vb + .pp("qkv_proj") + .get((hidden_size * 3, hidden_size), "weight")?; + let qkv_bias = vb.pp("qkv_proj").get(hidden_size * 3, "bias")?; + + let qkv_linear = Linear::new(qkv_weight, Some(qkv_bias), None); + + let o_proj_weight = vb.pp("o_proj").get((hidden_size, hidden_size), "weight")?; + let o_proj_bias = vb.pp("o_proj").get(hidden_size, "bias")?; + + let o_proj = Linear::new(o_proj_weight, Some(o_proj_bias), None); + + let softmax_scale = 1. / (attention_head_size as f64).sqrt() as f32; + + Ok(Self { + qkv_linear, + o_proj, + num_attention_heads, + attention_head_size, + softmax_scale, + span: tracing::span!(tracing::Level::TRACE, "attention"), + }) + } + + pub fn forward(&self, hidden_states: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { + let _enter = self.span.enter(); + let device = hidden_states.device(); + + let qkv = self.qkv_linear.forward(hidden_states)?; + + // Reshape to [tokens, heads, head_size] + let mut new_qkv_shape = qkv.dims().to_vec(); + new_qkv_shape.pop(); + new_qkv_shape.push(self.num_attention_heads * 3); + new_qkv_shape.push(self.attention_head_size); + + let qkv = qkv.reshape(new_qkv_shape)?; + + // Split qkv tensor + let q = qkv.narrow(1, 0, self.num_attention_heads)?; + let k = qkv.narrow(1, self.num_attention_heads, self.num_attention_heads)?; + let v = qkv.narrow(1, self.num_attention_heads * 2, self.num_attention_heads)?; + + let q = apply_rotary(&q, cos, sin, self.attention_head_size)?; + let k = apply_rotary(&k, cos, sin, self.attention_head_size)?; + + #[allow(unused_variables)] + let context_layer = if let (Device::Cuda(_), Some(cublaslt)) = + (device, get_cublas_lt_wrapper()) + { + #[cfg(feature = "cuda")] + { + // cuBLASLt batch matmul implementation requires inputs to be dims3 + let (batch_size, _, seq_len, _) = k.shape().dims4()?; + let k = k.flatten(0, 1)?; + let q = q.flatten(0, 1)?; + let v = v.flatten(0, 1)?; + let attention_bias = attention_bias.map(|mask| mask.flatten(0, 1)).transpose()?; + + // If attention_bias is set, we fuse the add by giving it as the output matrix + // and setting beta to 1.0 + let beta = match attention_bias.is_some() { + true => Some(1.0), + false => None, + }; + + // Batch matrix multiplication + // Fuse softmax scale and attention_bias add + let attention_scores = cublaslt.batch_matmul( + &k, + &q, + attention_bias.as_ref(), + Some(self.softmax_scale as f32), + beta, + None, + None, + )?; + let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?; + + let context_layer = cublaslt.batch_matmul( + &v.t()?.contiguous()?, + &attention_probs, + // We save one allocation + Some(&q), + None, + None, + None, + None, + )?; + + // Reshape to dims4 + context_layer.reshape(( + batch_size, + self.num_attention_heads, + seq_len, + self.attention_head_size, + )) + } + #[cfg(not(feature = "cuda"))] + { + candle::bail!("`cuda` feature is not enabled") + } + } else { + let attention_scores = q.matmul(&k.t()?)?; + let attention_scores = (attention_scores * self.softmax_scale as f64)?; + let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?; + attention_probs.matmul(&v.contiguous()?) 
+ }?; + + let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?; + + let hidden_states = self.o_proj.forward(&context_layer)?; + + Ok(hidden_states) + } +} + +#[allow(clippy::upper_case_acronyms)] +pub struct GTEMLP { + up_gate_proj: Linear, + down_proj: Linear, + + act: HiddenAct, + intermediate_size: usize, + + span: tracing::Span, +} + +impl GTEMLP { + pub fn load(vb: VarBuilder, config: >EConfig) -> Result { + let intermediate_size = config.intermediate_size; + + let up_gate_proj_weight = vb + .pp("up_gate_proj") + .get((intermediate_size * 2, config.hidden_size), "weight")?; + + let up_gate_proj = Linear::new(up_gate_proj_weight, None, None); + + let down_proj_weight = vb + .pp("down_proj") + .get((config.hidden_size, intermediate_size), "weight")?; + let down_proj_bias = vb.pp("down_proj").get(config.hidden_size, "bias")?; + let down_proj = Linear::new(down_proj_weight, Some(down_proj_bias), None); + + Ok(Self { + up_gate_proj, + down_proj, + intermediate_size, + act: config.hidden_act.clone(), + span: tracing::span!(tracing::Level::TRACE, "mlp"), + }) + } + + pub fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); + + let up_gate_states = self.up_gate_proj.forward(hidden_states)?; + let up_states = up_gate_states.narrow(1, 0, self.intermediate_size)?; + let gate_states = + up_gate_states.narrow(1, self.intermediate_size, self.intermediate_size)?; + + let gate_states = match self.act { + HiddenAct::Gelu => gate_states.gelu(), + HiddenAct::Relu => gate_states.relu(), + HiddenAct::Swiglu => gate_states.silu(), + }?; + + self.down_proj.forward(&(gate_states * up_states)?) + } +} + +pub struct GTELayer { + attention: GTEAttention, + mlp: GTEMLP, + attention_layer_norm: LayerNorm, + mlp_layer_norm: LayerNorm, + + span: tracing::Span, +} + +impl GTELayer { + pub fn load(vb: VarBuilder, config: >EConfig) -> Result { + let attention = GTEAttention::load(vb.pp("attention"), config)?; + let mlp = GTEMLP::load(vb.pp("mlp"), config)?; + + let attention_layer_norm = + LayerNorm::load(vb.pp("attn_ln"), config.hidden_size, config.layer_norm_eps)?; + let mlp_layer_norm = + LayerNorm::load(vb.pp("mlp_ln"), config.hidden_size, config.layer_norm_eps)?; + + Ok(Self { + attention, + mlp, + attention_layer_norm, + mlp_layer_norm, + span: tracing::span!(tracing::Level::TRACE, "layer"), + }) + } + + pub fn forward(&self, hidden_states: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { + let _enter = self.span.enter(); + let attn_output = self.attention.forward(hidden_states, cos, sin)?; + let normed_attn_res_output = self + .attention_layer_norm + .forward(&attn_output, Some(hidden_states))?; + + let mlp_output = self.mlp.forward(&normed_attn_res_output)?; + let normed_mlp_res_output = self + .mlp_layer_norm + .forward(&mlp_output, Some(&normed_attn_res_output))?; + Ok(normed_mlp_res_output) + } +} + +pub struct GTEClassificationHead { + pooler: Option, + classifier: Linear, + span: tracing::Span, +} + +impl GTEClassificationHead { + #[allow(dead_code)] + pub(crate) fn load(vb: VarBuilder, config: >EConfig) -> Result { + let n_classes = match &config.id2label { + None => candle::bail!("`id2label` must be set for classifier models"), + Some(id2label) => id2label.len(), + }; + + let pooler = if let Ok(pooler_weight) = vb + .pp("pooler.dense") + .get((config.hidden_size, config.hidden_size), "weight") + { + let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?; + Some(Linear::new(pooler_weight, Some(pooler_bias), None)) + } else { + 
None + }; + + let classifier_weight = vb + .pp("classifier") + .get((n_classes, config.hidden_size), "weight")?; + let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?; + let classifier = Linear::new(classifier_weight, Some(classifier_bias), None); + + Ok(Self { + classifier, + pooler, + span: tracing::span!(tracing::Level::TRACE, "classifier"), + }) + } + + pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); + + let mut hidden_states = hidden_states.unsqueeze(1)?; + if let Some(pooler) = self.pooler.as_ref() { + hidden_states = pooler.forward(&hidden_states)?; + hidden_states = hidden_states.tanh()?; + } + + let hidden_states = self.classifier.forward(&hidden_states)?; + let hidden_states = hidden_states.squeeze(1)?; + Ok(hidden_states) + } +} + +pub struct GTEModel { + word_embeddings: Embedding, + token_type_embeddings: Option, + layers: Vec, + embeddings_norm: LayerNorm, + cos_cache: Tensor, + sin_cache: Tensor, + classifier: Option, + pool: Pool, + pub device: Device, + + span: tracing::Span, +} + +impl GTEModel { + pub fn load(vb: VarBuilder, config: >EConfig, model_type: ModelType) -> Result { + match vb.device() { + Device::Cuda(_) => {} + _ => candle::bail!("FlashGTE requires Cuda"), + } + + if vb.dtype() != DType::F16 { + candle::bail!("FlashGTE requires DType::F16") + } + + if config.logn_attention_clip1 { + candle::bail!("`logn_attention_clip1` is not supported"); + } + if config.logn_attention_scale { + candle::bail!("`logn_attention_scale` is not supported"); + } + + if config.position_embedding_type != PositionEmbeddingType::Rope { + candle::bail!("Only `PositionEmbeddingType::Rope` is supported"); + } + + let (pool, classifier) = match model_type { + ModelType::Classifier => { + let pool = Pool::Cls; + + let classifier = GTEClassificationHead::load(vb.clone(), config)?; + (pool, Some(classifier)) + } + ModelType::Embedding(pool) => (pool, None), + }; + + let word_embeddings = Embedding::new( + vb.pp("embeddings.word_embeddings") + .get((config.vocab_size, config.hidden_size), "weight")?, + config.hidden_size, + ); + + let token_type_embeddings = if config.type_vocab_size > 0 { + Some(Embedding::new( + vb.pp("embeddings.token_type_embeddings") + .get((config.type_vocab_size, config.hidden_size), "weight")?, + config.hidden_size, + )) + } else { + None + }; + + let layers = (0..config.num_hidden_layers) + .map(|index| GTELayer::load(vb.pp(format!("encoder.layer.{index}")), config)) + .collect::>>()?; + + let embeddings_norm = LayerNorm::load( + vb.pp("embeddings.LayerNorm"), + config.hidden_size, + config.layer_norm_eps, + )?; + + let inv_freqs = if let Some(RopeScaling::Ntk(NTKScaling { factor })) = config.rope_scaling { + let inv_freqs = inv_freqs( + layers[0].attention.attention_head_size, + config.rope_theta * factor, + vb.device(), + )?; + let s = factor.powf(2.0 / layers[0].attention.attention_head_size as f32) as f64; + inv_freqs / s + } else { + inv_freqs( + layers[0].attention.attention_head_size, + config.rope_theta, + vb.device(), + ) + }?; + + let (cos_cache, sin_cache) = + cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype())?; + + Ok(Self { + word_embeddings, + token_type_embeddings, + layers, + embeddings_norm, + cos_cache, + sin_cache, + classifier, + pool, + device: vb.device().clone(), + span: tracing::span!(tracing::Level::TRACE, "model"), + }) + } + + pub fn forward(&self, batch: Batch) -> Result<(Option, Option)> { + let _enter = self.span.enter(); + + let batch_size = 
batch.cumulative_seq_lengths.len() - 1; + let shape = batch.input_ids.len(); + + // Create Cuda tensors + let input_ids = Tensor::from_vec(batch.input_ids, shape, &self.device)?; + let token_type_ids = Tensor::from_vec(batch.token_type_ids, shape, &self.device)?; + let position_ids = Tensor::from_vec(batch.position_ids, shape, &self.device)?; + let cu_seqlens = Tensor::from_vec( + batch.cumulative_seq_lengths.clone(), + batch_size + 1, + &self.device, + )?; + + let word_embeddings = self.word_embeddings.forward(&input_ids)?; + let token_type_embeddings = self + .token_type_embeddings + .as_ref() + .map(|emb| emb.forward(&token_type_ids)) + .transpose()?; + + let mut hidden_states = self + .embeddings_norm + .forward(&word_embeddings, token_type_embeddings.as_ref())?; + + let cos = self.cos_cache.index_select(&position_ids, 0)?; + let sin = self.sin_cache.index_select(&position_ids, 0)?; + + for layer in &self.layers { + let h = layer.forward(&hidden_states, &cos, &sin)?; + hidden_states = h; + } + + let outputs = hidden_states; + + let has_pooling_requests = !batch.pooled_indices.is_empty(); + let has_raw_requests = !batch.raw_indices.is_empty(); + + let pooled_embeddings = if has_pooling_requests { + match self.pool { + // CLS and LastToken pooling + Pool::Cls | Pool::LastToken => { + if batch_size > 1 { + // Get token indices form cu_seqlens + let mut indices = match self.pool { + Pool::Cls => cu_seqlens.narrow(0, 0, batch_size)?, + Pool::LastToken => { + let end = cu_seqlens.narrow(0, 1, batch_size)?; + (&end - &end.ones_like()?)? + } + _ => unreachable!(), + }; + + // If raw_indices is empty, we don't need to do anything with + // the pooled_indices + if has_raw_requests { + // We need the pooled indices to select the correct cls indices + let pooled_indices = Tensor::from_vec( + batch.pooled_indices.clone(), + batch.pooled_indices.len(), + &self.device, + )?; + + // Only select indices that requires pooling + indices = indices.index_select(&pooled_indices, 0)? + } + + // Select tokens + Some(outputs.index_select(&indices, 0)?) + } else { + Some( + match self.pool { + Pool::Cls => outputs.i(0)?, + Pool::LastToken => { + outputs.i(batch.cumulative_seq_lengths[1] as usize - 1)? + } + _ => unreachable!(), + } + .unsqueeze(0)?, + ) + } + } + // Mean pooling + Pool::Mean => { + if batch_size > 1 { + // for each request that requires pooling + let results: Result> = batch + .pooled_indices + .into_iter() + .map(|i| { + let i = i as usize; + let start = batch.cumulative_seq_lengths[i]; + let len = batch.cumulative_seq_lengths[i + 1] - start; + + // Mean + let embeddings = outputs.narrow(0, start as usize, len as usize)?; + embeddings.sum_keepdim(0)? / (len as f64) + }) + .collect(); + + // Concatenate all results + Some(Tensor::cat(&results?, 0)?) + } else { + Some((outputs.sum_keepdim(0)? / (batch.max_length as f64))?) 
+ } + } + Pool::Splade => { + unreachable!(); + } + } + } else { + None + }; + + let raw_embeddings = if has_raw_requests { + if batch_size > 1 && has_pooling_requests { + // Create indexing vector for the embeddings + let mut final_indices: Vec = Vec::with_capacity(shape); + for i in batch.raw_indices.into_iter() { + let i = i as usize; + // Get start/end token index of this specific member of the batch + let start = batch.cumulative_seq_lengths[i]; + let end = batch.cumulative_seq_lengths[i + 1]; + + for j in start..end { + // Add indices for the tokens of this specific member of the batch + final_indices.push(j); + } + } + + let final_indices_length = final_indices.len(); + let final_indices = + Tensor::from_vec(final_indices, final_indices_length, &self.device)?; + + // Select the tokens with final indices + Some(outputs.index_select(&final_indices, 0)?) + } else { + Some(outputs) + } + } else { + None + }; + + Ok((pooled_embeddings, raw_embeddings)) + } +} + +impl Model for GTEModel { + fn is_padded(&self) -> bool { + false + } + + fn embed(&self, batch: Batch) -> Result<(Option, Option)> { + self.forward(batch) + } + + fn predict(&self, batch: Batch) -> Result { + match &self.classifier { + None => candle::bail!("`predict` is not implemented for this model"), + Some(classifier) => { + let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?; + let pooled_embeddings = + pooled_embeddings.expect("pooled_embeddings is empty. This is a bug."); + classifier.forward(&pooled_embeddings) + } + } + } +} diff --git a/backends/candle/src/models/mod.rs b/backends/candle/src/models/mod.rs index b1e9f937..2c7b2322 100644 --- a/backends/candle/src/models/mod.rs +++ b/backends/candle/src/models/mod.rs @@ -40,11 +40,11 @@ pub use bert::{BertConfig, BertModel, PositionEmbeddingType}; use candle::{Result, Tensor}; pub use distilbert::{DistilBertConfig, DistilBertModel}; #[allow(unused_imports)] -pub use gte::{GTEConfig, NTKScaling, RopeScaling}; +pub use gte::{GTEClassificationHead, GTEConfig, GTEModel, NTKScaling, RopeScaling, GTEMLP}; pub use jina::JinaBertModel; pub use jina_code::JinaCodeBertModel; pub use mistral::MistralConfig; -pub use nomic::{NomicBertModel, NomicConfig}; +pub use nomic::{apply_rotary, cos_sin, inv_freqs, NomicBertModel, NomicConfig}; pub use qwen2::Qwen2Config; use text_embeddings_backend_core::Batch; diff --git a/backends/candle/src/models/nomic.rs b/backends/candle/src/models/nomic.rs index cdaaea92..59c3b881 100644 --- a/backends/candle/src/models/nomic.rs +++ b/backends/candle/src/models/nomic.rs @@ -176,15 +176,6 @@ impl NomicAttention { }) } - fn apply_rotary(&self, x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { - let dim = self.attention_head_size / 2; - let x1 = x.narrow(D::Minus1, 0, dim)?; - let x2 = x.narrow(D::Minus1, dim, dim)?; - let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?; - let rope = (x.broadcast_mul(cos)? 
+ rotate_x.broadcast_mul(sin)?)?; - Ok(rope) - } - pub fn forward( &self, hidden_states: &Tensor, @@ -208,8 +199,8 @@ impl NomicAttention { let key_layer = &qkv[1].contiguous()?; let value_layer = &qkv[2]; - let query_layer = self.apply_rotary(query_layer, cos, sin)?; - let key_layer = self.apply_rotary(key_layer, cos, sin)?; + let query_layer = apply_rotary(query_layer, cos, sin, self.attention_head_size)?; + let key_layer = apply_rotary(key_layer, cos, sin, self.attention_head_size)?; #[allow(unused_variables)] let context_layer = if let (Device::Cuda(_), Some(cublaslt)) = @@ -699,6 +690,20 @@ pub fn cos_sin(length: usize, inv_freqs: &Tensor, dtype: DType) -> Result<(Tenso Ok((cos, sin)) } +pub fn apply_rotary( + x: &Tensor, + cos: &Tensor, + sin: &Tensor, + attention_head_size: usize, +) -> Result { + let dim = attention_head_size / 2; + let x1 = x.narrow(D::Minus1, 0, dim)?; + let x2 = x.narrow(D::Minus1, dim, dim)?; + let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?; + let rope = (x.broadcast_mul(cos)? + rotate_x.broadcast_mul(sin)?)?; + Ok(rope) +} + impl Model for NomicBertModel { fn is_padded(&self) -> bool { false From e0e3bc6ca833fb49e2438f3d6d28bd6805104786 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 30 Nov 2024 12:28:33 +0900 Subject: [PATCH 2/7] update: enable GTEModel for cuda --- backends/candle/src/lib.rs | 3 ++- backends/candle/src/models/gte.rs | 11 +---------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index c8f762dd..c52cd776 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -349,7 +349,8 @@ impl CandleBackend { if dtype != DType::F16 || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1")) { - return Err(BackendError::Start("GTE is only supported on Cuda devices in fp16 with flash attention enabled".to_string())); + tracing::info!("Starting GTE model on {:?}", device); + Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?)) } tracing::info!("Starting FlashGTE model on {:?}", device); Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?)) diff --git a/backends/candle/src/models/gte.rs b/backends/candle/src/models/gte.rs index ba360eaf..c5830c39 100644 --- a/backends/candle/src/models/gte.rs +++ b/backends/candle/src/models/gte.rs @@ -1,6 +1,6 @@ use crate::layers::{get_cublas_lt_wrapper, HiddenAct, LayerNorm, Linear}; use crate::models::{apply_rotary, cos_sin, inv_freqs, Model, PositionEmbeddingType}; -use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle::{Device, IndexOp, Result, Tensor, D}; use candle_nn::{Embedding, Module, VarBuilder}; use serde::Deserialize; use std::collections::HashMap; @@ -339,15 +339,6 @@ pub struct GTEModel { impl GTEModel { pub fn load(vb: VarBuilder, config: >EConfig, model_type: ModelType) -> Result { - match vb.device() { - Device::Cuda(_) => {} - _ => candle::bail!("FlashGTE requires Cuda"), - } - - if vb.dtype() != DType::F16 { - candle::bail!("FlashGTE requires DType::F16") - } - if config.logn_attention_clip1 { candle::bail!("`logn_attention_clip1` is not supported"); } From ebe6ac964e6400d1322a5e6e24f427eab70640b7 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 30 Nov 2024 13:00:12 +0900 Subject: [PATCH 3/7] fix: GTE --- backends/candle/src/models/flash_gte.rs | 6 +- backends/candle/src/models/gte.rs | 111 +++++++++++------------- 2 files changed, 55 insertions(+), 62 deletions(-) diff --git a/backends/candle/src/models/flash_gte.rs b/backends/candle/src/models/flash_gte.rs index 
13f97730..b6632d8b 100644 --- a/backends/candle/src/models/flash_gte.rs +++ b/backends/candle/src/models/flash_gte.rs @@ -1,5 +1,5 @@ use crate::flash_attn::flash_attn_varlen; -use crate::layers::{HiddenAct, LayerNorm, Linear}; +use crate::layers::{LayerNorm, Linear}; use crate::models::{ GTEClassificationHead, GTEConfig, Model, NTKScaling, PositionEmbeddingType, RopeScaling, GTEMLP, }; @@ -132,7 +132,9 @@ impl GTELayer { max_s: usize, ) -> Result { let _enter = self.span.enter(); - let attn_output = self.attention.forward(&hidden_states, cos, sin)?; + let attn_output = self + .attention + .forward(&hidden_states, cu_seqlens, cos, sin, max_s)?; let normed_attn_res_output = self .attention_layer_norm .forward(&attn_output, Some(hidden_states))?; diff --git a/backends/candle/src/models/gte.rs b/backends/candle/src/models/gte.rs index c5830c39..ad085dba 100644 --- a/backends/candle/src/models/gte.rs +++ b/backends/candle/src/models/gte.rs @@ -104,67 +104,58 @@ impl GTEAttention { let k = apply_rotary(&k, cos, sin, self.attention_head_size)?; #[allow(unused_variables)] - let context_layer = if let (Device::Cuda(_), Some(cublaslt)) = - (device, get_cublas_lt_wrapper()) - { - #[cfg(feature = "cuda")] - { - // cuBLASLt batch matmul implementation requires inputs to be dims3 - let (batch_size, _, seq_len, _) = k.shape().dims4()?; - let k = k.flatten(0, 1)?; - let q = q.flatten(0, 1)?; - let v = v.flatten(0, 1)?; - let attention_bias = attention_bias.map(|mask| mask.flatten(0, 1)).transpose()?; - - // If attention_bias is set, we fuse the add by giving it as the output matrix - // and setting beta to 1.0 - let beta = match attention_bias.is_some() { - true => Some(1.0), - false => None, - }; - - // Batch matrix multiplication - // Fuse softmax scale and attention_bias add - let attention_scores = cublaslt.batch_matmul( - &k, - &q, - attention_bias.as_ref(), - Some(self.softmax_scale as f32), - beta, - None, - None, - )?; + let context_layer = + if let (Device::Cuda(_), Some(cublaslt)) = (device, get_cublas_lt_wrapper()) { + #[cfg(feature = "cuda")] + { + // cuBLASLt batch matmul implementation requires inputs to be dims3 + let (batch_size, _, seq_len, _) = k.shape().dims4()?; + let k = k.flatten(0, 1)?; + let q = q.flatten(0, 1)?; + let v = v.flatten(0, 1)?; + + // Batch matrix multiplication + // Fuse softmax scale and attention_bias add + let attention_scores = cublaslt.batch_matmul( + &k, + &q, + attention_bias.as_ref(), + Some(self.softmax_scale as f32), + None, + None, + None, + )?; + let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?; + + let context_layer = cublaslt.batch_matmul( + &v.t()?.contiguous()?, + &attention_probs, + // We save one allocation + Some(&q), + None, + None, + None, + None, + )?; + + // Reshape to dims4 + context_layer.reshape(( + batch_size, + self.num_attention_heads, + seq_len, + self.attention_head_size, + )) + } + #[cfg(not(feature = "cuda"))] + { + candle::bail!("`cuda` feature is not enabled") + } + } else { + let attention_scores = q.matmul(&k.t()?)?; + let attention_scores = (attention_scores * self.softmax_scale as f64)?; let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?; - - let context_layer = cublaslt.batch_matmul( - &v.t()?.contiguous()?, - &attention_probs, - // We save one allocation - Some(&q), - None, - None, - None, - None, - )?; - - // Reshape to dims4 - context_layer.reshape(( - batch_size, - self.num_attention_heads, - seq_len, - self.attention_head_size, - )) - } - #[cfg(not(feature = "cuda"))] - { - 
candle::bail!("`cuda` feature is not enabled") - } - } else { - let attention_scores = q.matmul(&k.t()?)?; - let attention_scores = (attention_scores * self.softmax_scale as f64)?; - let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?; - attention_probs.matmul(&v.contiguous()?) - }?; + attention_probs.matmul(&v.contiguous()?) + }?; let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?; From 9cc25ab1ca85b9f4fab9393da5f21619987cef7a Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 30 Nov 2024 13:03:18 +0900 Subject: [PATCH 4/7] fix: attention_bias --- backends/candle/src/models/gte.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/candle/src/models/gte.rs b/backends/candle/src/models/gte.rs index ad085dba..2a04a256 100644 --- a/backends/candle/src/models/gte.rs +++ b/backends/candle/src/models/gte.rs @@ -119,7 +119,7 @@ impl GTEAttention { let attention_scores = cublaslt.batch_matmul( &k, &q, - attention_bias.as_ref(), + None, Some(self.softmax_scale as f32), None, None, From d5430f5bc81fd721c7e7f75ac601f55a3ad1c5c4 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 30 Nov 2024 13:08:45 +0900 Subject: [PATCH 5/7] fix: GTE model --- backends/candle/src/lib.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index c52cd776..aec84375 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -351,9 +351,10 @@ impl CandleBackend { { tracing::info!("Starting GTE model on {:?}", device); Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?)) + } else { + tracing::info!("Starting FlashGTE model on {:?}", device); + Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?)) } - tracing::info!("Starting FlashGTE model on {:?}", device); - Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?)) } #[cfg(feature = "cuda")] (Config::Qwen2(config), Device::Cuda(_)) => { From 005a69335a6b890ee8dccc45ba3202646f455fb1 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 30 Nov 2024 17:06:53 +0900 Subject: [PATCH 6/7] fix: gte --- backends/candle/src/models/gte.rs | 35 +++++++++++++---------------- backends/candle/src/models/mod.rs | 2 +- backends/candle/src/models/nomic.rs | 1 + 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/backends/candle/src/models/gte.rs b/backends/candle/src/models/gte.rs index 2a04a256..016e5937 100644 --- a/backends/candle/src/models/gte.rs +++ b/backends/candle/src/models/gte.rs @@ -1,6 +1,6 @@ use crate::layers::{get_cublas_lt_wrapper, HiddenAct, LayerNorm, Linear}; -use crate::models::{apply_rotary, cos_sin, inv_freqs, Model, PositionEmbeddingType}; -use candle::{Device, IndexOp, Result, Tensor, D}; +use crate::models::{apply_rotary, inv_freqs, Model, PositionEmbeddingType}; +use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{Embedding, Module, VarBuilder}; use serde::Deserialize; use std::collections::HashMap; @@ -108,12 +108,6 @@ impl GTEAttention { if let (Device::Cuda(_), Some(cublaslt)) = (device, get_cublas_lt_wrapper()) { #[cfg(feature = "cuda")] { - // cuBLASLt batch matmul implementation requires inputs to be dims3 - let (batch_size, _, seq_len, _) = k.shape().dims4()?; - let k = k.flatten(0, 1)?; - let q = q.flatten(0, 1)?; - let v = v.flatten(0, 1)?; - // Batch matrix multiplication // Fuse softmax scale and attention_bias add let attention_scores = cublaslt.batch_matmul( @@ -127,7 +121,7 @@ impl GTEAttention { )?; let attention_probs = 
candle_nn::ops::softmax_last_dim(&attention_scores)?; - let context_layer = cublaslt.batch_matmul( + cublaslt.batch_matmul( &v.t()?.contiguous()?, &attention_probs, // We save one allocation @@ -136,15 +130,7 @@ impl GTEAttention { None, None, None, - )?; - - // Reshape to dims4 - context_layer.reshape(( - batch_size, - self.num_attention_heads, - seq_len, - self.attention_head_size, - )) + ) } #[cfg(not(feature = "cuda"))] { @@ -157,7 +143,7 @@ impl GTEAttention { attention_probs.matmul(&v.contiguous()?) }?; - let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?; + let context_layer = context_layer.flatten_from(D::Minus2)?; let hidden_states = self.o_proj.forward(&context_layer)?; @@ -580,3 +566,14 @@ impl Model for GTEModel { } } } + +fn cos_sin(length: usize, inv_freqs: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> { + let t = Tensor::arange(0u32, length as u32, inv_freqs.device())? + .to_dtype(DType::F32)? + .reshape((length, 1))?; + + let freqs = t.matmul(inv_freqs)?; + let cos = freqs.cos()?.to_dtype(dtype)?; + let sin = freqs.sin()?.to_dtype(dtype)?; + Ok((cos, sin)) +} diff --git a/backends/candle/src/models/mod.rs b/backends/candle/src/models/mod.rs index 2c7b2322..ff10518f 100644 --- a/backends/candle/src/models/mod.rs +++ b/backends/candle/src/models/mod.rs @@ -44,7 +44,7 @@ pub use gte::{GTEClassificationHead, GTEConfig, GTEModel, NTKScaling, RopeScalin pub use jina::JinaBertModel; pub use jina_code::JinaCodeBertModel; pub use mistral::MistralConfig; -pub use nomic::{apply_rotary, cos_sin, inv_freqs, NomicBertModel, NomicConfig}; +pub use nomic::{apply_rotary, inv_freqs, NomicBertModel, NomicConfig}; pub use qwen2::Qwen2Config; use text_embeddings_backend_core::Batch; diff --git a/backends/candle/src/models/nomic.rs b/backends/candle/src/models/nomic.rs index 59c3b881..3fd3f645 100644 --- a/backends/candle/src/models/nomic.rs +++ b/backends/candle/src/models/nomic.rs @@ -708,6 +708,7 @@ impl Model for NomicBertModel { fn is_padded(&self) -> bool { false } + fn embed(&self, batch: Batch) -> Result<(Option, Option)> { self.forward(batch) } From 43ce2c3995c869e9f70a812dde8a3f946ee4ed84 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 30 Nov 2024 18:01:36 +0900 Subject: [PATCH 7/7] fix: GTE --- backends/candle/src/models/gte.rs | 18 +++++------------- backends/candle/src/models/mod.rs | 2 +- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/backends/candle/src/models/gte.rs b/backends/candle/src/models/gte.rs index 016e5937..10f29eb6 100644 --- a/backends/candle/src/models/gte.rs +++ b/backends/candle/src/models/gte.rs @@ -1,6 +1,6 @@ use crate::layers::{get_cublas_lt_wrapper, HiddenAct, LayerNorm, Linear}; -use crate::models::{apply_rotary, inv_freqs, Model, PositionEmbeddingType}; -use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use crate::models::{apply_rotary, cos_sin, inv_freqs, Model, PositionEmbeddingType}; +use candle::{Device, IndexOp, Result, Tensor, D}; use candle_nn::{Embedding, Module, VarBuilder}; use serde::Deserialize; use std::collections::HashMap; @@ -426,6 +426,9 @@ impl GTEModel { let cos = self.cos_cache.index_select(&position_ids, 0)?; let sin = self.sin_cache.index_select(&position_ids, 0)?; + let cos = cos.unsqueeze(1)?; + let sin = sin.unsqueeze(1)?; + for layer in &self.layers { let h = layer.forward(&hidden_states, &cos, &sin)?; hidden_states = h; @@ -566,14 +569,3 @@ impl Model for GTEModel { } } } - -fn cos_sin(length: usize, inv_freqs: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> { - let t = 
Tensor::arange(0u32, length as u32, inv_freqs.device())? - .to_dtype(DType::F32)? - .reshape((length, 1))?; - - let freqs = t.matmul(inv_freqs)?; - let cos = freqs.cos()?.to_dtype(dtype)?; - let sin = freqs.sin()?.to_dtype(dtype)?; - Ok((cos, sin)) -} diff --git a/backends/candle/src/models/mod.rs b/backends/candle/src/models/mod.rs index ff10518f..2c7b2322 100644 --- a/backends/candle/src/models/mod.rs +++ b/backends/candle/src/models/mod.rs @@ -44,7 +44,7 @@ pub use gte::{GTEClassificationHead, GTEConfig, GTEModel, NTKScaling, RopeScalin pub use jina::JinaBertModel; pub use jina_code::JinaCodeBertModel; pub use mistral::MistralConfig; -pub use nomic::{apply_rotary, inv_freqs, NomicBertModel, NomicConfig}; +pub use nomic::{apply_rotary, cos_sin, inv_freqs, NomicBertModel, NomicConfig}; pub use qwen2::Qwen2Config; use text_embeddings_backend_core::Batch;
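
Note for reviewers: the patches above factor `inv_freqs`, `cos_sin`, and `apply_rotary` into shared helpers used by both the Nomic and GTE models. As a reference, below is a minimal standalone sketch of the rotary-position-embedding math those helpers implement, written against plain `Vec<f32>` rather than candle tensors. The function names, the `head_size` of 8, and the `rope_theta` of 10,000 are illustrative only and are not part of the patch.

// Standalone sketch (std-only, no candle types) of the RoPE math behind the
// shared `inv_freqs` / `cos_sin` / `apply_rotary` helpers. Illustrative only.

fn rope_inv_freqs(head_size: usize, rope_theta: f32) -> Vec<f32> {
    // One inverse frequency per rotated pair: theta^(-2i / head_size)
    (0..head_size / 2)
        .map(|i| 1.0 / rope_theta.powf(2.0 * i as f32 / head_size as f32))
        .collect()
}

fn rope_cos_sin(position: usize, inv_freqs: &[f32]) -> (Vec<f32>, Vec<f32>) {
    // Per-position angle for each frequency, as cached by `cos_sin`
    let angles: Vec<f32> = inv_freqs.iter().map(|f| position as f32 * f).collect();
    (
        angles.iter().map(|a| a.cos()).collect(),
        angles.iter().map(|a| a.sin()).collect(),
    )
}

fn rope_apply(x: &mut [f32], cos: &[f32], sin: &[f32]) {
    // Rotate-half: out = x * cos + [-x2, x1] * sin, applied to pairs (i, i + d/2)
    let half = x.len() / 2;
    for i in 0..half {
        let (x1, x2) = (x[i], x[i + half]);
        x[i] = x1 * cos[i] - x2 * sin[i];
        x[i + half] = x2 * cos[i] + x1 * sin[i];
    }
}

fn main() {
    let head_size = 8;
    let inv_freqs = rope_inv_freqs(head_size, 10_000.0);
    let (cos, sin) = rope_cos_sin(3, &inv_freqs); // a token at position 3
    let mut q = vec![1.0_f32; head_size];
    rope_apply(&mut q, &cos, &sin);
    println!("{q:?}");
}

The per-pair form above is equivalent to the tensor version in the diff (`rope = x * cos + [-x2, x1] * sin`), with `cos`/`sin` holding one value per frequency rather than being broadcast over the full head dimension.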