Implement the ModernBert model #459

Merged: 25 commits, Apr 4, 2025
3 changes: 2 additions & 1 deletion README.md
@@ -66,7 +66,7 @@ Ember, GTE and E5. TEI implements many features such as:
#### Text Embeddings

Text Embeddings Inference currently supports Nomic, BERT, CamemBERT, XLM-RoBERTa models with absolute positions, JinaBERT
model with Alibi positions and Mistral, Alibaba GTE, Qwen2 models with Rope positions, and MPNet.
model with Alibi positions and Mistral, Alibaba GTE, Qwen2 models with Rope positions, MPNet, and ModernBERT.

Below are some examples of the currently supported models:

@@ -85,6 +85,7 @@ Below are some examples of the currently supported models:
| N/A | 0.1B | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) |
| N/A | 0.1B | JinaBERT | [jinaai/jina-embeddings-v2-base-code](https://hf.co/jinaai/jina-embeddings-v2-base-code) |
| N/A | 0.1B | MPNet | [sentence-transformers/all-mpnet-base-v2](https://hf.co/sentence-transformers/all-mpnet-base-v2) |
| N/A | 0.4B | ModernBERT | [answerdotai/ModernBERT-large](https://hf.co/answerdotai/ModernBERT-large) |

To explore the list of best performing text embeddings models, visit the
[Massive Text Embedding Benchmark (MTEB) Leaderboard](https://huggingface.co/spaces/mteb/leaderboard).
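With the new README entry in place, querying a TEI server that is serving ModernBERT works like any other supported model. A minimal client sketch, assuming a server is already running `answerdotai/ModernBERT-large` on `localhost:8080` and that `reqwest` (with the `blocking` and `json` features) and `serde_json` are available as dependencies:

```rust
// Hypothetical client sketch: queries the /embed endpoint of a locally running
// TEI instance. The host, port, and model choice are assumptions for illustration.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::blocking::Client::new();
    let embeddings: serde_json::Value = client
        .post("http://localhost:8080/embed")
        .json(&serde_json::json!({ "inputs": "What is deep learning?" }))
        .send()?
        .error_for_status()?
        .json()?;

    // The response is a batch of embedding vectors, one per input string.
    let dim = embeddings[0].as_array().map_or(0, |v| v.len());
    println!("embedding dimension: {dim}");
    Ok(())
}
```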
11 changes: 9 additions & 2 deletions backends/candle/src/flash_attn.rs
@@ -32,14 +32,15 @@ pub(crate) fn flash_attn_varlen(
softmax_scale: f32,
causal: bool,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
) -> Result<Tensor, candle::Error> {
let runtime_compute_cap = get_runtime_compute_cap();

if runtime_compute_cap == 75 {
if alibi_slopes.is_some() {
candle::bail!("Flash attention v1 does not support alibi");
}
if window_size_left.is_some() {
if window_size_left.is_some() | window_size_right.is_some() {
candle::bail!("Flash attention v1 does not support attention windowing");
}

@@ -65,7 +66,13 @@
{
use candle_flash_attn::{flash_attn_varlen_alibi_windowed, flash_attn_varlen_windowed};

let window_size_right = if causal { Some(0) } else { None };
let window_size_right = if causal {
Some(0)
} else if window_size_right.is_some() {
window_size_right
} else {
None
};

let attention = if let Some(alibi_slopes) = alibi_slopes {
flash_attn_varlen_alibi_windowed(
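Distilled into a standalone form, the right-window resolution added above is simply a restatement of the branch in the diff (no extra logic from the PR): causal attention forces the right window to 0 so no token looks ahead, otherwise an explicitly requested right window is passed through, and `None` means unlimited.

```rust
// Standalone restatement of the window_size_right resolution above.
fn resolve_window_size_right(causal: bool, window_size_right: Option<usize>) -> Option<usize> {
    if causal {
        Some(0) // no lookahead for causal attention
    } else {
        window_size_right
    }
}

fn main() {
    assert_eq!(resolve_window_size_right(true, None), Some(0));       // causal decoder
    assert_eq!(resolve_window_size_right(false, Some(64)), Some(64)); // local bidirectional attention
    assert_eq!(resolve_window_size_right(false, None), None);         // full attention
}
```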
79 changes: 79 additions & 0 deletions backends/candle/src/layers/layer_norm.rs
@@ -1,6 +1,84 @@
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::VarBuilder;

#[derive(Debug)]
pub struct LayerNormNoBias {
weight: Tensor,
epsilon: f32,
span: tracing::Span,
}

impl LayerNormNoBias {
pub fn load(vb: VarBuilder, hidden_size: usize, epsilon: f32) -> Result<Self> {
Ok(Self {
weight: vb
.get(hidden_size, "weight")
.or_else(|_| vb.get(hidden_size, "gamma"))?,
epsilon,
span: tracing::span!(tracing::Level::TRACE, "layer-norm-no-bias"),
})
}

pub fn forward(&self, hidden_states: &Tensor, residual: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();

match hidden_states.device() {
Device::Cpu | Device::Metal(_) => {
let mut hidden_states = hidden_states.clone();
if let Some(residual) = residual {
hidden_states = hidden_states.add(residual)?;
}
let hidden_states_dtype = hidden_states.dtype();
let internal_dtype = match hidden_states_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let hidden_size = hidden_states.dim(D::Minus1)?;
let hidden_states = hidden_states.to_dtype(internal_dtype)?;
let mean_hidden_states =
(hidden_states.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let hidden_states = hidden_states.broadcast_sub(&mean_hidden_states)?;
let norm_hidden_states =
(hidden_states.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let hidden_states_normed = hidden_states
.broadcast_div(&(norm_hidden_states + self.epsilon as f64)?.sqrt()?)?;
let hidden_states = hidden_states_normed
.to_dtype(hidden_states_dtype)?
.broadcast_mul(&self.weight)?;

Ok(hidden_states)
}
Device::Cuda(_) => {
#[cfg(feature = "cuda")]
{
use candle_layer_norm::{fused_add_layer_norm, layer_norm};

let original_shape = hidden_states.shape();
let hidden_states = hidden_states.flatten_to(D::Minus2)?;

let result = if let Some(residual) = residual {
let residual = residual.flatten_to(D::Minus2)?;

let (result, _) = fused_add_layer_norm(
&hidden_states,
&residual,
&self.weight,
None,
self.epsilon,
)?;
Ok(result)
} else {
layer_norm(&hidden_states, &self.weight, None, self.epsilon)
}?;
result.reshape(original_shape)
}
#[cfg(not(feature = "cuda"))]
candle::bail!("`cuda` feature is not enabled")
}
}
}
}

#[derive(Debug)]
pub struct LayerNorm {
weight: Tensor,
@@ -49,6 +127,7 @@ impl LayerNorm {
let hidden_states = hidden_states_normed
.to_dtype(hidden_states_dtype)?
.broadcast_mul(&self.weight)?;

hidden_states.broadcast_add(&self.bias)
}
Device::Cuda(_) => {
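A usage sketch for the new weight-only layer norm, assuming code inside this backend crate; the `VarBuilder::zeros` builder, the prefix, and the shapes are illustration only, a real model would map its safetensors weights instead:

```rust
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;

fn demo() -> candle::Result<()> {
    let device = Device::Cpu;
    // Zero-initialised builder purely for illustration.
    let vb = VarBuilder::zeros(DType::F32, &device);
    let norm = LayerNormNoBias::load(vb.pp("layer_norm"), 768, 1e-5)?;

    // Normalise a (batch, hidden) activation, fusing an optional residual add.
    let hidden_states = Tensor::randn(0f32, 1f32, (4, 768), &device)?;
    let residual = Tensor::randn(0f32, 1f32, (4, 768), &device)?;
    let normed = norm.forward(&hidden_states, Some(&residual))?;
    assert_eq!(normed.dims(), &[4, 768]);
    Ok(())
}
```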
2 changes: 1 addition & 1 deletion backends/candle/src/layers/mod.rs
@@ -7,7 +7,7 @@ mod rms_norm;
mod rotary;

pub use cublaslt::get_cublas_lt_wrapper;
pub use layer_norm::LayerNorm;
pub use layer_norm::{LayerNorm, LayerNormNoBias};
pub use linear::{HiddenAct, Linear};
#[allow(unused_imports)]
pub use rms_norm::RMSNorm;
19 changes: 17 additions & 2 deletions backends/candle/src/lib.rs
@@ -12,8 +12,8 @@ use crate::compute_cap::{
};
use crate::models::{
BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel,
JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, NomicBertModel, NomicConfig,
Qwen2Config,
JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, ModernBertConfig,
ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config,
};
#[cfg(feature = "cuda")]
use crate::models::{
@@ -65,6 +65,8 @@ enum Config {
Qwen2(Qwen2Config),
#[serde(rename = "mpnet")]
MPNet(MPNetConfig),
#[serde(rename(deserialize = "modernbert"))]
ModernBert(ModernBertConfig),
}

pub struct CandleBackend {
@@ -235,6 +237,19 @@ impl CandleBackend {
tracing::info!("Starting MPNet model on {:?}", device);
Ok(Box::new(MPNetModel::load(vb, &config, model_type).s()?))
}
(Config::ModernBert(config), _) => match device {
Device::Metal(_) => {
return Err(BackendError::Start(
"ModernBert is not currently supported on MPS device".to_string(),
));
}
_ => {
tracing::info!("Starting ModernBert model on {:?}", device);
Ok(Box::new(
ModernBertModel::load(vb, &config, model_type).s()?,
))
}
},
#[cfg(feature = "cuda")]
(Config::Bert(config), Device::Cuda(_)) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
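For reference, a standalone sketch of how the `modernbert` rename above lets the backend select this branch from a model's `config.json`; the `model_type` tag and the single `hidden_size` field here are assumptions for illustration, not TEI's full `Config` enum:

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct ModernBertConfig {
    hidden_size: usize,
}

// Simplified stand-in for the Config enum: the variant is picked from the
// `model_type` field of the model's config.json.
#[derive(Debug, Deserialize)]
#[serde(tag = "model_type")]
enum Config {
    #[serde(rename = "modernbert")]
    ModernBert(ModernBertConfig),
}

fn main() -> serde_json::Result<()> {
    let raw = r#"{ "model_type": "modernbert", "hidden_size": 1024 }"#;
    let config: Config = serde_json::from_str(raw)?;
    println!("{config:?}"); // Config::ModernBert(ModernBertConfig { hidden_size: 1024 })
    Ok(())
}
```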
1 change: 1 addition & 0 deletions backends/candle/src/models/flash_bert.rs
@@ -104,6 +104,7 @@ impl BertAttention {
self.softmax_scale,
false,
None,
None,
)?;
let attention = attention.flatten_from(candle::D::Minus2)?;

1 change: 1 addition & 0 deletions backends/candle/src/models/flash_distilbert.rs
@@ -85,6 +85,7 @@ impl DistilBertAttention {
self.softmax_scale,
false,
None,
None,
)?;
let attention = attention.flatten_from(candle::D::Minus2)?;

1 change: 1 addition & 0 deletions backends/candle/src/models/flash_gte.rs
@@ -87,6 +87,7 @@ impl GTEAttention {
self.softmax_scale,
false,
None,
None,
)?;
let attention = attention.flatten_from(candle::D::Minus2)?;

1 change: 1 addition & 0 deletions backends/candle/src/models/flash_jina.rs
@@ -106,6 +106,7 @@ impl JinaAttention {
self.softmax_scale,
false,
None,
None,
)?;
let attention = attention.flatten_from(candle::D::Minus2)?;

1 change: 1 addition & 0 deletions backends/candle/src/models/flash_jina_code.rs
@@ -142,6 +142,7 @@ impl JinaCodeAttention {
self.softmax_scale,
false,
None,
None,
)?;
let attention = attention.flatten_from(candle::D::Minus2)?;

1 change: 1 addition & 0 deletions backends/candle/src/models/flash_mistral.rs
@@ -105,6 +105,7 @@ impl MistralAttention {
self.softmax_scale,
true,
self.window_size_left,
None,
)?;
let attention = attention.flatten_from(candle::D::Minus2)?;

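The existing flash models above simply pass `None` for the new right-window argument, so their behaviour is unchanged. As a hypothetical sketch of the other side of the change, a local-attention layer (for example a ModernBERT sliding-window layer, which is not shown in this excerpt) might derive symmetric window arguments from its configured span; the `local_attention` name and the half-span split are assumptions for illustration:

```rust
// Hypothetical helper: turn a configured local-attention span into symmetric
// left/right window arguments for the extended flash_attn_varlen signature.
// None means the layer attends globally, exactly like the callers patched above.
fn local_window_args(local_attention: Option<usize>) -> (Option<usize>, Option<usize>) {
    match local_attention {
        // A total span of `local_attention` tokens, split evenly left and right.
        Some(span) => (Some(span / 2), Some(span / 2)),
        None => (None, None),
    }
}

fn main() {
    assert_eq!(local_window_args(Some(128)), (Some(64), Some(64))); // local layer
    assert_eq!(local_window_args(None), (None, None));              // global layer
}
```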