Skip to content

feat: Implement GTE model to support the non-flash-attn version #446

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/candle/src/layers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ mod layer_norm;
mod linear;
#[allow(dead_code, unused)]
mod rms_norm;
mod rotary;

pub use cublaslt::get_cublas_lt_wrapper;
pub use layer_norm::LayerNorm;
pub use linear::{HiddenAct, Linear};
#[allow(unused_imports)]
pub use rms_norm::RMSNorm;
pub use rotary::{apply_rotary, get_cos_sin, get_inv_freqs, RopeScaling};
73 changes: 73 additions & 0 deletions backends/candle/src/layers/rotary.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use candle::{DType, Device, Result, Tensor, D};
use serde::Deserialize;

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct NTKScaling {
pub factor: f32,
}

#[derive(Debug, Clone, PartialEq, Deserialize)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum RopeScaling {
Ntk(NTKScaling),
}

pub fn get_inv_freqs(
dim: usize,
base: f32,
device: &Device,
rope_scaling: Option<&RopeScaling>,
) -> Result<Tensor> {
let get_inv_freqs_inner = |dim: usize, base: f32, device: &Device| {
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / base.powf(i as f32 / dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
Tensor::from_vec(inv_freq, (1, inv_freq_len), device)
};

if let Some(rope_scaling) = rope_scaling {
match rope_scaling {
RopeScaling::Ntk(ntk_scaling) => {
let inv_freqs = get_inv_freqs_inner(dim, base * ntk_scaling.factor, device)?;
let s = ntk_scaling.factor.powf(2.0 / dim as f32) as f64;
return inv_freqs / s;
}
}
}
get_inv_freqs_inner(dim, base, device)
}

pub fn get_cos_sin(
length: usize,
inv_freqs: &Tensor,
dtype: DType,
repeat_freqs: bool,
) -> Result<(Tensor, Tensor)> {
let t = Tensor::arange(0u32, length as u32, inv_freqs.device())?
.to_dtype(DType::F32)?
.reshape((length, 1))?;
let mut freqs = t.matmul(inv_freqs)?;
if repeat_freqs {
freqs = Tensor::cat(&[&freqs, &freqs], 1)?;
}

let cos = freqs.cos()?.to_dtype(dtype)?;
let sin = freqs.sin()?.to_dtype(dtype)?;
Ok((cos, sin))
}

pub fn apply_rotary(
x: &Tensor,
cos: &Tensor,
sin: &Tensor,
attention_head_size: usize,
) -> Result<Tensor> {
let dim = attention_head_size / 2;
let x1 = x.narrow(D::Minus1, 0, dim)?;
let x2 = x.narrow(D::Minus1, dim, dim)?;
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
let rope = (x.broadcast_mul(cos)? + rotate_x.broadcast_mul(sin)?)?;
Ok(rope)
}
18 changes: 10 additions & 8 deletions backends/candle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::compute_cap::{
compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
};
use crate::models::{
BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, JinaBertModel,
BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel,
JinaCodeBertModel, MistralConfig, Model, NomicBertModel, NomicConfig, Qwen2Config,
};
#[cfg(feature = "cuda")]
Expand Down Expand Up @@ -218,10 +218,10 @@ impl CandleBackend {
"Mistral is only supported on Cuda devices in fp16 with flash attention enabled"
.to_string(),
)),
(Config::Gte(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
"GTE is only supported on Cuda devices in fp16 with flash attention enabled"
.to_string(),
)),
(Config::Gte(config), Device::Cpu | Device::Metal(_)) => {
tracing::info!("Starting GTE model on {:?}", device);
Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
}
(Config::Qwen2(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
"Qwen2 is only supported on Cuda devices in fp16 with flash attention enabled"
.to_string(),
Expand Down Expand Up @@ -349,10 +349,12 @@ impl CandleBackend {
if dtype != DType::F16
|| !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
{
return Err(BackendError::Start("GTE is only supported on Cuda devices in fp16 with flash attention enabled".to_string()));
tracing::info!("Starting GTE model on {:?}", device);
Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
} else {
tracing::info!("Starting FlashGTE model on {:?}", device);
Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
}
tracing::info!("Starting FlashGTE model on {:?}", device);
Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
}
#[cfg(feature = "cuda")]
(Config::Qwen2(config), Device::Cuda(_)) => {
Expand Down
145 changes: 18 additions & 127 deletions backends/candle/src/models/flash_gte.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{HiddenAct, LayerNorm, Linear};
use crate::models::{GTEConfig, Model, NTKScaling, PositionEmbeddingType, RopeScaling};
use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear};
use crate::models::{GTEClassificationHead, GTEConfig, Model, PositionEmbeddingType, GTEMLP};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
use candle_rotary::apply_rotary_inplace;
use text_embeddings_backend_core::{Batch, ModelType, Pool};

struct GTEAttention {
Expand Down Expand Up @@ -72,7 +73,7 @@ impl GTEAttention {
let k = qkv.narrow(1, self.num_attention_heads, self.num_attention_heads)?;
let v = qkv.narrow(1, self.num_attention_heads * 2, self.num_attention_heads)?;

candle_rotary::apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
apply_rotary_inplace(&q, &k, &cos, &sin, true)?;

let attention = flash_attn_varlen(
&q,
Expand All @@ -93,60 +94,7 @@ impl GTEAttention {
}
}

struct GTEMLP {
up_gate_proj: Linear,
down_proj: Linear,

act: HiddenAct,
intermediate_size: usize,

span: tracing::Span,
}

impl GTEMLP {
pub fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
let intermediate_size = config.intermediate_size;

let up_gate_proj_weight = vb
.pp("up_gate_proj")
.get((intermediate_size * 2, config.hidden_size), "weight")?;

let up_gate_proj = Linear::new(up_gate_proj_weight, None, None);

let down_proj_weight = vb
.pp("down_proj")
.get((config.hidden_size, intermediate_size), "weight")?;
let down_proj_bias = vb.pp("down_proj").get(config.hidden_size, "bias")?;
let down_proj = Linear::new(down_proj_weight, Some(down_proj_bias), None);

Ok(Self {
up_gate_proj,
down_proj,
intermediate_size,
act: config.hidden_act.clone(),
span: tracing::span!(tracing::Level::TRACE, "mlp"),
})
}

pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();

let up_gate_states = self.up_gate_proj.forward(hidden_states)?;
let up_states = up_gate_states.narrow(1, 0, self.intermediate_size)?;
let gate_states =
up_gate_states.narrow(1, self.intermediate_size, self.intermediate_size)?;

let gate_states = match self.act {
HiddenAct::Gelu => gate_states.gelu(),
HiddenAct::Relu => gate_states.relu(),
HiddenAct::Swiglu => gate_states.silu(),
}?;
let r = self.down_proj.forward(&(gate_states * up_states)?);
r
}
}

struct GTELayer {
pub struct GTELayer {
attention: GTEAttention,
mlp: GTEMLP,
attention_layer_norm: LayerNorm,
Expand Down Expand Up @@ -198,58 +146,6 @@ impl GTELayer {
}
}

pub struct GTEClassificationHead {
pooler: Option<Linear>,
classifier: Linear,
span: tracing::Span,
}

impl GTEClassificationHead {
#[allow(dead_code)]
pub(crate) fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
let n_classes = match &config.id2label {
None => candle::bail!("`id2label` must be set for classifier models"),
Some(id2label) => id2label.len(),
};

let pooler = if let Ok(pooler_weight) = vb
.pp("pooler.dense")
.get((config.hidden_size, config.hidden_size), "weight")
{
let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?;
Some(Linear::new(pooler_weight, Some(pooler_bias), None))
} else {
None
};

let classifier_weight = vb
.pp("classifier")
.get((n_classes, config.hidden_size), "weight")?;
let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?;
let classifier = Linear::new(classifier_weight, Some(classifier_bias), None);

Ok(Self {
classifier,
pooler,
span: tracing::span!(tracing::Level::TRACE, "classifier"),
})
}

pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();

let mut hidden_states = hidden_states.unsqueeze(1)?;
if let Some(pooler) = self.pooler.as_ref() {
hidden_states = pooler.forward(&hidden_states)?;
hidden_states = hidden_states.tanh()?;
}

let hidden_states = self.classifier.forward(&hidden_states)?;
let hidden_states = hidden_states.squeeze(1)?;
Ok(hidden_states)
}
}

pub struct FlashGTEModel {
word_embeddings: Embedding,
token_type_embeddings: Option<Embedding>,
Expand Down Expand Up @@ -322,24 +218,19 @@ impl FlashGTEModel {
config.layer_norm_eps,
)?;

let inv_freqs = if let Some(RopeScaling::Ntk(NTKScaling { factor })) = config.rope_scaling {
let inv_freqs = candle_rotary::inv_freqs(
layers[0].attention.attention_head_size,
config.rope_theta * factor,
vb.device(),
)?;
let s = factor.powf(2.0 / layers[0].attention.attention_head_size as f32) as f64;
inv_freqs / s
} else {
candle_rotary::inv_freqs(
layers[0].attention.attention_head_size,
config.rope_theta,
vb.device(),
)
}?;

let (cos_cache, sin_cache) =
candle_rotary::cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype())?;
let inv_freqs = get_inv_freqs(
layers[0].attention.attention_head_size,
config.rope_theta,
vb.device(),
config.rope_scaling.as_ref(),
)?;

let (cos_cache, sin_cache) = get_cos_sin(
config.max_position_embeddings,
&inv_freqs,
vb.dtype(),
false,
)?;

Ok(Self {
word_embeddings,
Expand Down
16 changes: 11 additions & 5 deletions backends/candle/src/models/flash_mistral.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{HiddenAct, Linear, RMSNorm};
use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm};
use crate::models::{MistralConfig, Model};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
use candle_rotary::apply_rotary_inplace;
use text_embeddings_backend_core::{Batch, ModelType, Pool};

struct MistralAttention {
Expand Down Expand Up @@ -90,7 +91,7 @@ impl MistralAttention {
self.num_key_value_heads,
)?;

candle_rotary::apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
apply_rotary_inplace(&q, &k, &cos, &sin, true)?;

let attention = flash_attn_varlen(
&q,
Expand Down Expand Up @@ -267,13 +268,18 @@ impl FlashMistralModel {

let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?;

let inv_freqs = candle_rotary::inv_freqs(
let inv_freqs = get_inv_freqs(
layers[0].attention.attention_head_size,
config.rope_theta,
vb.device(),
None,
)?;
let (cos_cache, sin_cache) = get_cos_sin(
config.max_position_embeddings,
&inv_freqs,
vb.dtype(),
false,
)?;
let (cos_cache, sin_cache) =
candle_rotary::cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype())?;

Ok(Self {
embeddings,
Expand Down
14 changes: 8 additions & 6 deletions backends/candle/src/models/flash_nomic.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{LayerNorm, Linear};
use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear};
use crate::models::nomic::{NomicBertEmbeddings, NomicBertGatedMLP};
use crate::models::{Model, NomicConfig};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::VarBuilder;
use candle_rotary::apply_rotary_inplace;
use text_embeddings_backend_core::{Batch, ModelType, Pool};

struct NomicAttention {
Expand Down Expand Up @@ -68,7 +69,7 @@ impl NomicAttention {
let qkv = qkv.reshape(new_qkv_shape.as_slice())?;
let qkv = qkv.chunk(3, 1)?;

candle_rotary::apply_rotary_inplace(&qkv[0], &qkv[1], &cos, &sin, true)?;
apply_rotary_inplace(&qkv[0], &qkv[1], &cos, &sin, true)?;

let attention = flash_attn_varlen(
&qkv[0],
Expand Down Expand Up @@ -221,20 +222,21 @@ impl FlashNomicBertModel {
let encoder = NomicBertEncoder::load(vb.pp("encoder"), config)?;

let rotary_dim = encoder.layers[0].attention.attention_head_size;
let inv_freqs = candle_rotary::inv_freqs(rotary_dim, config.rotary_emb_base, vb.device())?;
let rotary_cache = candle_rotary::cos_sin(config.n_positions, &inv_freqs, vb.dtype())?;
let inv_freqs = get_inv_freqs(rotary_dim, config.rotary_emb_base, vb.device(), None)?;
let rotary_cache = get_cos_sin(config.n_positions, &inv_freqs, vb.dtype(), false)?;

let scaled_rotary_cache = if let Some(scaling_factor) = config.rotary_scaling_factor {
let new_base = (config.rotary_emb_base
* ((scaling_factor * config.n_positions as f32
/ config.max_trained_positions as f32)
- (scaling_factor - 1.0)))
.powi((rotary_dim as f32 / (rotary_dim as f32 - 2.0)) as i32);
let inv_freqs = candle_rotary::inv_freqs(rotary_dim, new_base, vb.device())?;
Some(candle_rotary::cos_sin(
let inv_freqs = get_inv_freqs(rotary_dim, new_base, vb.device(), None)?;
Some(get_cos_sin(
config.n_positions,
&inv_freqs,
vb.dtype(),
false,
)?)
} else {
None
Expand Down
Loading
Loading