Fixing FlashAttention ModernBert. #560

Merged
4 commits merged on Apr 5, 2025
26 changes: 24 additions & 2 deletions backends/candle/src/lib.rs
@@ -18,7 +18,8 @@ use crate::models::{
#[cfg(feature = "cuda")]
use crate::models::{
FlashBertModel, FlashDistilBertModel, FlashGTEModel, FlashJinaBertModel,
FlashJinaCodeBertModel, FlashMistralModel, FlashNomicBertModel, FlashQwen2Model,
FlashJinaCodeBertModel, FlashMistralModel, FlashModernBertModel, FlashNomicBertModel,
FlashQwen2Model,
};
use anyhow::Context;
use candle::{DType, Device};
@@ -276,7 +277,7 @@ impl CandleBackend {
tracing::info!("Starting MPNet model on {:?}", device);
Ok(Box::new(MPNetModel::load(vb, &config, model_type).s()?))
}
(Config::ModernBert(config), _) => match device {
(Config::ModernBert(config), Device::Cpu | Device::Metal(_)) => match device {
Device::Metal(_) => {
return Err(BackendError::Start(
"ModernBert is not currently supported on MPS device".to_string(),
@@ -357,6 +358,27 @@ impl CandleBackend {
}
}
#[cfg(feature = "cuda")]
(Config::ModernBert(config), Device::Cuda(_)) => {
if cfg!(feature = "flash-attn")
&& dtype == DType::F16
// Allow disabling because of flash attention v1 precision problems
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
ivarflakstad (Member), Apr 4, 2025:
Do we want to communicate to the user why they are not running with FA enabled? Or is this already done?

Collaborator (Author):
I took this code from the previous model. The logs do convey whether it's FlashModernBert or ModernBert.

Member:
Ok, that's good. Just not why it was not enabled, correct?

Contributor:
It's a nit. How about additionally adding an error message or comment here saying that FlashModernBert does not support flash-attn-v1 due to its lack of an attention-windowing feature?

Collaborator (Author):
Great comment, I'll change that around.

{
tracing::info!("Starting FlashModernBert model on {:?}", device);
Ok(Box::new(
FlashModernBertModel::load(vb, &config, model_type).s()?,
))
} else {
#[cfg(feature = "flash-attn-v1")]
tracing::warn!("Flash attention V1 cannot be used with ModernBert because it lacks windowing support.");
tracing::info!("Starting ModernBert model on {:?}", device);
Ok(Box::new(
ModernBertModel::load(vb, &config, model_type).s()?,
))
}
}
#[cfg(feature = "cuda")]
(Config::DistilBert(config), Device::Cuda(_)) => {
if cfg!(feature = "flash-attn")
&& dtype == DType::F16
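Both CUDA arms above use the same runtime gate before picking the flash variant. A condensed standalone sketch of that logic (the free function and its name are illustrative, not part of the crate; the semantics follow the diff):

```rust
use candle::DType;

// Condensed sketch of the gate used in the CUDA arms above: the flash model is
// selected only when the flash-attn feature is compiled in, the dtype is f16,
// and the user has not opted out via USE_FLASH_ATTENTION=false.
fn flash_attention_enabled(dtype: DType) -> bool {
    let opted_in = std::env::var("USE_FLASH_ATTENTION")
        .unwrap_or_else(|_| "True".to_string())
        .to_lowercase()
        == "true";
    cfg!(feature = "flash-attn") && dtype == DType::F16 && opted_in
}
```

With this gate, setting USE_FLASH_ATTENTION=false falls back to the regular ModernBert path even when the flash kernels are compiled in.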
68 changes: 34 additions & 34 deletions backends/candle/src/models/flash_modernbert.rs
@@ -1,14 +1,15 @@
use std::collections::HashMap;

use crate::flash_attn::flash_attn_varlen;
use crate::layers::{apply_rotary, get_cos_sin, get_inv_freqs, LayerNorm, Linear};
use crate::layers::{get_cos_sin, get_inv_freqs, LayerNormNoBias, Linear};
Collaborator (Author):
This already landed in the original PR; it's a good way to handle the fact that the layer norm never has any bias here.

use crate::models::modernbert::{
ClassificationHead, ModernBertClassificationHead, ModernBertConfig, ModernBertEmbeddings,
ModernBertMLP,
};
use crate::models::Model;
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::VarBuilder;
use candle_rotary::apply_rotary_inplace;
use text_embeddings_backend_core::{Batch, ModelType, Pool};

struct ModernBertAttention {
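The LayerNorm to LayerNormNoBias switch in the imports above reflects that ModernBert's norms have a learned scale but no bias. As a rough illustration only (this is not the crate's actual implementation; the helper name and the f32 upcast are assumptions), a bias-free layer-norm forward looks roughly like this:

```rust
use candle::{DType, Result, Tensor, D};

// Illustration of a layer norm with a learned scale and no bias term.
// `weight` has shape [hidden]; `x` has shape [..., hidden].
fn layer_norm_no_bias(x: &Tensor, weight: &Tensor, eps: f64) -> Result<Tensor> {
    let dtype = x.dtype();
    let x = x.to_dtype(DType::F32)?;
    let mean = x.mean_keepdim(D::Minus1)?;
    let x = x.broadcast_sub(&mean)?;
    let var = x.sqr()?.mean_keepdim(D::Minus1)?;
    let x = x.broadcast_div(&(var + eps)?.sqrt()?)?;
    // Scale only: there is no bias addition, which is what "NoBias" encodes.
    x.to_dtype(dtype)?.broadcast_mul(weight)
}
```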
@@ -79,35 +80,34 @@ impl ModernBertAttention {
new_qkv_shape.pop();
new_qkv_shape.push(self.num_attention_heads * 3);
new_qkv_shape.push(self.attention_head_size);
let qkv = qkv.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
let qkv = qkv.reshape(new_qkv_shape.as_slice())?;
Collaborator (Author):
No transpose, no contiguous; taken from other flash implementations.


let qkv = qkv.chunk(3, 1)?;
let query_layer = &qkv[0].contiguous()?;
let key_layer = &qkv[1].contiguous()?;
let value_layer = &qkv[2];
// Split qkv tensor
let q = qkv.narrow(1, 0, self.num_attention_heads)?;
let k = qkv.narrow(1, self.num_attention_heads, self.num_attention_heads)?;
let v = qkv.narrow(1, self.num_attention_heads * 2, self.num_attention_heads)?;

let query_layer = apply_rotary(query_layer, cos, sin, self.attention_head_size)?;
let key_layer = apply_rotary(key_layer, cos, sin, self.attention_head_size)?;
apply_rotary_inplace(&q, &k, &cos, &sin, true)?;

let attention_size = if self.use_local_attention {
let window_size = if self.use_local_attention {
Collaborator (Author):
Renamed for clarity; in flash attention this is called window_size.

Some(self.local_attention)
} else {
None
};

let attention = flash_attn_varlen(
&query_layer,
&key_layer,
&value_layer,
&q,
&k,
&v,
None,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
self.softmax_scale,
false,
attention_size,
attention_size,
window_size,
window_size,
)?;
let attention = attention.flatten_from(candle::D::Minus2)?;
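The reshaped qkv tensor is split with narrow along the head dimension, so no transpose or contiguous copy is needed before flash_attn_varlen. A minimal standalone sketch of that split (the helper name is illustrative; it assumes qkv has already been reshaped to [total_tokens, 3 * num_heads, head_size]):

```rust
use candle::{Result, Tensor};

// Split a fused QKV tensor of shape [total_tokens, 3 * num_heads, head_size]
// into query, key, and value slices without transposing or copying.
fn split_qkv(qkv: &Tensor, num_heads: usize) -> Result<(Tensor, Tensor, Tensor)> {
    let q = qkv.narrow(1, 0, num_heads)?; // heads [0, H)
    let k = qkv.narrow(1, num_heads, num_heads)?; // heads [H, 2H)
    let v = qkv.narrow(1, 2 * num_heads, num_heads)?; // heads [2H, 3H)
    Ok((q, k, v))
}
```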

@@ -118,9 +118,9 @@ impl ModernBertAttention {
}

struct ModernBertEncoderLayer {
attn_norm: Option<LayerNorm>,
attn_norm: Option<LayerNormNoBias>,
attn: ModernBertAttention,
mlp_norm: LayerNorm,
mlp_norm: LayerNormNoBias,
mlp: ModernBertMLP,

span: tracing::Span,
@@ -129,7 +129,7 @@ struct ModernBertEncoderLayer {
impl ModernBertEncoderLayer {
pub fn load(vb: VarBuilder, index: usize, config: &ModernBertConfig) -> Result<Self> {
let attn_norm = if index != 0 {
Some(LayerNorm::load(
Some(LayerNormNoBias::load(
vb.pp("attn_norm"),
config.hidden_size,
config.norm_eps as f32,
@@ -140,7 +140,7 @@ impl ModernBertEncoderLayer {

let attn = ModernBertAttention::load(vb.pp("attn"), index, config)?;

let mlp_norm = LayerNorm::load(
let mlp_norm = LayerNormNoBias::load(
vb.pp("mlp_norm"),
config.hidden_size,
config.norm_eps as f32,
@@ -236,11 +236,10 @@ impl ModernBertEncoder {
pub struct FlashModernBertModel {
embeddings: ModernBertEmbeddings,
encoder: ModernBertEncoder,
final_norm: LayerNorm,
final_norm: LayerNormNoBias,
pool: Pool,
classifier: Option<Box<dyn ClassificationHead + Send>>,

rotary_dim: usize,
rotary_cache: HashMap<bool, (Tensor, Tensor)>,

device: Device,
@@ -277,13 +276,22 @@ impl FlashModernBertModel {
}
};

let embeddings = ModernBertEmbeddings::load(vb.pp("model.embeddings"), config)?;
let encoder = ModernBertEncoder::load(vb.pp("model.layers"), config)?;
let final_norm = LayerNorm::load(
let embeddings = ModernBertEmbeddings::load(vb.pp("model.embeddings"), config)
.or_else(|_| ModernBertEmbeddings::load(vb.pp("embeddings"), config))?;
let encoder = ModernBertEncoder::load(vb.pp("model.layers"), config)
.or_else(|_| ModernBertEncoder::load(vb.pp("layers"), config))?;
let final_norm = LayerNormNoBias::load(
Collaborator (Author):
Handle the various weight names, like the non-flash modeling does.

vb.pp("model.final_norm"),
config.hidden_size,
config.norm_eps as f32,
)?;
)
.or_else(|_| {
LayerNormNoBias::load(
vb.pp("final_norm"),
config.hidden_size,
config.norm_eps as f32,
)
})?;

let rotary_dim = config.hidden_size / config.num_attention_heads;
let mut rotary_cache: HashMap<bool, (Tensor, Tensor)> = HashMap::new();
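The or_else chains above try the "model."-prefixed weight names first and fall back to the bare names, so checkpoints exported with either layout load. A sketch of the same pattern pulled out into a helper (the function is illustrative, not part of the PR):

```rust
use candle::Result;
use candle_nn::VarBuilder;

use crate::layers::LayerNormNoBias;
use crate::models::modernbert::ModernBertConfig;

// Try the "model."-prefixed tensor names first, then the un-prefixed ones.
fn load_final_norm(vb: VarBuilder, config: &ModernBertConfig) -> Result<LayerNormNoBias> {
    LayerNormNoBias::load(
        vb.pp("model.final_norm"),
        config.hidden_size,
        config.norm_eps as f32,
    )
    .or_else(|_| {
        LayerNormNoBias::load(
            vb.pp("final_norm"),
            config.hidden_size,
            config.norm_eps as f32,
        )
    })
}
```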
@@ -295,15 +303,11 @@ impl FlashModernBertModel {
config.global_rope_theta
};

let max_position_embeddings = if use_local_attention {
config.max_position_embeddings
} else {
config.local_attention
};
let max_position_embeddings = config.max_position_embeddings;
Collaborator (Author):
This should model what was there before.


let inv_freqs = get_inv_freqs(rotary_dim, rope_theta as f32, vb.device(), None)?;

let (cos, sin) = get_cos_sin(max_position_embeddings, &inv_freqs, vb.dtype(), true)?;
let (cos, sin) = get_cos_sin(max_position_embeddings, &inv_freqs, vb.dtype(), false)?;
Collaborator (Author):
Very important: in flash models this is false, while it's true in non-flash models. Took me quite a bit to spot this difference :(


rotary_cache.insert(use_local_attention, (cos, sin));
}
@@ -314,7 +318,6 @@ impl FlashModernBertModel {
final_norm,
pool,
classifier,
rotary_dim,
rotary_cache,
device: vb.device().clone(),
span: tracing::span!(tracing::Level::TRACE, "model"),
@@ -343,9 +346,6 @@ impl FlashModernBertModel {
let cos = cos.index_select(&position_ids, 0)?;
let sin = sin.index_select(&position_ids, 0)?;

Collaborator (Author):
Not necessary with the fast rotary kernel.

let cos = cos.reshape((batch_size, 1, max_length, self.rotary_dim))?;
let sin = sin.reshape((batch_size, 1, max_length, self.rotary_dim))?;

rotary_cache.insert(use_local_attention, (cos, sin));
}

6 changes: 5 additions & 1 deletion backends/candle/src/models/modernbert.rs
@@ -578,7 +578,10 @@ impl ModernBertModel {
}

fn get_local_attention_mask(&self, attention_mask: &Tensor) -> Result<Tensor> {
let attention_mask = attention_mask.to_dtype(DType::U8)?;
let dev = attention_mask.device();
let attention_mask = attention_mask
Collaborator (Author):
This lets the non-flash code work on GPU (when flash isn't available, for instance). We can't do the .abs() operation there (it's not in candle), so computing the mask on CPU is acceptable for now (GPU + non-flash should be fairly rare nowadays).

.to_device(&Device::Cpu)?
.to_dtype(DType::U8)?;

let mask_shape = attention_mask.shape();
let (_, _, seq_len, _) = mask_shape.dims4()?;
@@ -597,6 +600,7 @@

let zero_tensor = Tensor::zeros_like(&attention_mask)?;
let local_attention_mask = attention_mask.where_cond(&window_mask, &zero_tensor)?;
let local_attention_mask = local_attention_mask.to_device(dev)?;

Ok(local_attention_mask)
}
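The device round-trip in get_local_attention_mask follows a simple pattern: move to CPU where the needed ops exist, build the mask, then move the result back. A standalone sketch (the helper name is illustrative and the windowing step is elided):

```rust
use candle::{DType, Device, Result, Tensor};

// Move the mask to CPU, compute there, and return it on the original device.
fn mask_on_cpu(attention_mask: &Tensor) -> Result<Tensor> {
    let dev = attention_mask.device().clone();
    let mask = attention_mask
        .to_device(&Device::Cpu)?
        .to_dtype(DType::U8)?;
    // ... windowed-mask construction runs on CPU here ...
    mask.to_device(&dev)
}
```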
2 changes: 1 addition & 1 deletion backends/grpc-client/build.rs
@@ -1,7 +1,7 @@
use std::fs;

fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("cargo:rerun-if-changed=../../proto/embed.proto");
println!("cargo:rerun-if-changed=../proto/embed.proto");
Collaborator (Author):
This causes too many rebuilds because the file location is incorrect.

fs::create_dir("src/pb").unwrap_or(());

let mut config = prost_build::Config::new();
2 changes: 2 additions & 0 deletions flake.nix
@@ -220,6 +220,8 @@
LD_LIBRARY_PATH = "${pkgs.stdenv.cc.cc.lib}/lib:/run/opengl-driver/lib";
LIBRARY_PATH = "${pkgs.stdenv.cc.cc.lib}/lib:/run/opengl-driver/lib";
CUDA_ROOT = "${pkgs.cudaPackages.cudatoolkit}";
CANDLE_FLASH_ATTN_BUILD_DIR = "./kernels";
CANDLE_LAYER_NORM_BUILD_DIR = "./kernels";
Collaborator (Author), on lines +223 to +224:
Prevents flash/layer-norm rebuilds in dev shells.

};
}
);