Fixing FlashAttention ModernBert. #560
In the flash ModernBert model (FlashModernBertModel):

@@ -1,14 +1,15 @@
 use std::collections::HashMap;

 use crate::flash_attn::flash_attn_varlen;
-use crate::layers::{apply_rotary, get_cos_sin, get_inv_freqs, LayerNorm, Linear};
+use crate::layers::{get_cos_sin, get_inv_freqs, LayerNormNoBias, Linear};

Review comment: This was already landed in the original PR; it is a good way to handle the fact that the layer norm never has a bias here.
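For readers who have not seen the bias-free variant, below is a minimal sketch of what such a layer can look like in candle. The struct name and the exact implementation are illustrative assumptions, not the crate's actual code in crate::layers.

use candle::{Result, Tensor, D};
use candle_nn::VarBuilder;

// Illustrative sketch only; the real LayerNormNoBias lives in crate::layers.
pub struct LayerNormNoBiasSketch {
    weight: Tensor,
    eps: f64,
}

impl LayerNormNoBiasSketch {
    pub fn load(vb: VarBuilder, hidden_size: usize, eps: f32) -> Result<Self> {
        // The checkpoint only carries a "weight" tensor; there is no "bias" to load.
        let weight = vb.get(hidden_size, "weight")?;
        Ok(Self { weight, eps: eps as f64 })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Normalize over the hidden dimension, then scale; no bias term is added.
        let mean = x.mean_keepdim(D::Minus1)?;
        let x = x.broadcast_sub(&mean)?;
        let var = x.sqr()?.mean_keepdim(D::Minus1)?;
        let x = x.broadcast_div(&var.affine(1.0, self.eps)?.sqrt()?)?;
        x.broadcast_mul(&self.weight)
    }
}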
 use crate::models::modernbert::{
     ClassificationHead, ModernBertClassificationHead, ModernBertConfig, ModernBertEmbeddings,
     ModernBertMLP,
 };
 use crate::models::Model;
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::VarBuilder;
+use candle_rotary::apply_rotary_inplace;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};

 struct ModernBertAttention {

@@ -79,35 +80,34 @@ impl ModernBertAttention {
         new_qkv_shape.pop();
         new_qkv_shape.push(self.num_attention_heads * 3);
         new_qkv_shape.push(self.attention_head_size);
-        let qkv = qkv.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
+        let qkv = qkv.reshape(new_qkv_shape.as_slice())?;

Review comment: No transpose and no contiguous copy needed; taken from the other flash implementations.
-        let qkv = qkv.chunk(3, 1)?;
-        let query_layer = &qkv[0].contiguous()?;
-        let key_layer = &qkv[1].contiguous()?;
-        let value_layer = &qkv[2];
+        // Split qkv tensor
+        let q = qkv.narrow(1, 0, self.num_attention_heads)?;
+        let k = qkv.narrow(1, self.num_attention_heads, self.num_attention_heads)?;
+        let v = qkv.narrow(1, self.num_attention_heads * 2, self.num_attention_heads)?;

-        let query_layer = apply_rotary(query_layer, cos, sin, self.attention_head_size)?;
-        let key_layer = apply_rotary(key_layer, cos, sin, self.attention_head_size)?;
+        apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
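As a standalone illustration of the split above (a hypothetical helper, not the PR's code): narrow slices the fused QKV projection along the head dimension, so q, k and v remain views into the same buffer.

use candle::{Result, Tensor};

// Expects qkv of shape (total_tokens, 3 * num_heads, head_size), as produced by
// reshaping the fused projection; returns (q, k, v), each of shape
// (total_tokens, num_heads, head_size), without copying the underlying data.
fn split_fused_qkv(qkv: &Tensor, num_heads: usize) -> Result<(Tensor, Tensor, Tensor)> {
    let q = qkv.narrow(1, 0, num_heads)?;
    let k = qkv.narrow(1, num_heads, num_heads)?;
    let v = qkv.narrow(1, 2 * num_heads, num_heads)?;
    Ok((q, k, v))
}

Because the slices are taken along the middle dimension, the innermost head_size dimension of each slice stays contiguous, which is typically all the flash kernel needs; hence no transpose or contiguous call before flash_attn_varlen.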
-        let attention_size = if self.use_local_attention {
+        let window_size = if self.use_local_attention {
             Some(self.local_attention)
         } else {
             None
         };

Review comment: Renamed for clarity; in flash attention this is called window_size.
         let attention = flash_attn_varlen(
-            &query_layer,
-            &key_layer,
-            &value_layer,
+            &q,
+            &k,
+            &v,
             None,
             cu_seqlens,
             cu_seqlens,
             max_s,
             max_s,
             self.softmax_scale,
             false,
-            attention_size,
-            attention_size,
+            window_size,
+            window_size,
         )?;
         let attention = attention.flatten_from(candle::D::Minus2)?;

@@ -118,9 +118,9 @@ impl ModernBertAttention {
 }

 struct ModernBertEncoderLayer {
-    attn_norm: Option<LayerNorm>,
+    attn_norm: Option<LayerNormNoBias>,
     attn: ModernBertAttention,
-    mlp_norm: LayerNorm,
+    mlp_norm: LayerNormNoBias,
     mlp: ModernBertMLP,

     span: tracing::Span,

@@ -129,7 +129,7 @@ struct ModernBertEncoderLayer {
 impl ModernBertEncoderLayer {
     pub fn load(vb: VarBuilder, index: usize, config: &ModernBertConfig) -> Result<Self> {
         let attn_norm = if index != 0 {
-            Some(LayerNorm::load(
+            Some(LayerNormNoBias::load(
                 vb.pp("attn_norm"),
                 config.hidden_size,
                 config.norm_eps as f32,

@@ -140,7 +140,7 @@ impl ModernBertEncoderLayer {
         let attn = ModernBertAttention::load(vb.pp("attn"), index, config)?;

-        let mlp_norm = LayerNorm::load(
+        let mlp_norm = LayerNormNoBias::load(
             vb.pp("mlp_norm"),
             config.hidden_size,
             config.norm_eps as f32,

@@ -236,11 +236,10 @@ impl ModernBertEncoder {
 pub struct FlashModernBertModel {
     embeddings: ModernBertEmbeddings,
     encoder: ModernBertEncoder,
-    final_norm: LayerNorm,
+    final_norm: LayerNormNoBias,
     pool: Pool,
     classifier: Option<Box<dyn ClassificationHead + Send>>,

-    rotary_dim: usize,
     rotary_cache: HashMap<bool, (Tensor, Tensor)>,

     device: Device,

@@ -277,13 +276,22 @@ impl FlashModernBertModel {
             }
         };

-        let embeddings = ModernBertEmbeddings::load(vb.pp("model.embeddings"), config)?;
-        let encoder = ModernBertEncoder::load(vb.pp("model.layers"), config)?;
-        let final_norm = LayerNorm::load(
+        let embeddings = ModernBertEmbeddings::load(vb.pp("model.embeddings"), config)
+            .or_else(|_| ModernBertEmbeddings::load(vb.pp("embeddings"), config))?;
+        let encoder = ModernBertEncoder::load(vb.pp("model.layers"), config)
+            .or_else(|_| ModernBertEncoder::load(vb.pp("layers"), config))?;
+        let final_norm = LayerNormNoBias::load(

Review comment: Handle the various weight-name prefixes, like the non-flash modeling code does.
             vb.pp("model.final_norm"),
             config.hidden_size,
             config.norm_eps as f32,
-        )?;
+        )
+        .or_else(|_| {
+            LayerNormNoBias::load(
+                vb.pp("final_norm"),
+                config.hidden_size,
+                config.norm_eps as f32,
+            )
+        })?;
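As an aside, the same try-the-prefixed-name-first pattern could be factored into a small helper. The function below is purely illustrative and hypothetical, not part of the PR:

use candle::Result;

// Try loading under the prefixed tensor path first (e.g. "model.final_norm"),
// then retry with the bare path (e.g. "final_norm") if that fails.
fn load_with_fallback<T>(
    mut load: impl FnMut(&str) -> Result<T>,
    prefixed: &str,
    bare: &str,
) -> Result<T> {
    load(prefixed).or_else(|_| load(bare))
}

Keeping the explicit or_else chains, as the PR does, is arguably just as clear for three call sites.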
         let rotary_dim = config.hidden_size / config.num_attention_heads;
         let mut rotary_cache: HashMap<bool, (Tensor, Tensor)> = HashMap::new();

@@ -295,15 +303,11 @@ impl FlashModernBertModel {
                 config.global_rope_theta
             };

-            let max_position_embeddings = if use_local_attention {
-                config.max_position_embeddings
-            } else {
-                config.local_attention
-            };
+            let max_position_embeddings = config.max_position_embeddings;

Review comment: This should mirror what was there before.
             let inv_freqs = get_inv_freqs(rotary_dim, rope_theta as f32, vb.device(), None)?;

-            let (cos, sin) = get_cos_sin(max_position_embeddings, &inv_freqs, vb.dtype(), true)?;
+            let (cos, sin) = get_cos_sin(max_position_embeddings, &inv_freqs, vb.dtype(), false)?;

Review comment: Very important: in the flash models this flag is false, whereas it is true in the non-flash models. It took me quite a while to spot this difference :(

             rotary_cache.insert(use_local_attention, (cos, sin));
         }
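To make the difference concrete, here is a rough standalone sketch of rotary cos/sin table construction. It assumes (an assumption, not something stated in the diff) that the final boolean controls whether the half-width tables are repeated out to the full rotary width, as an eager apply_rotary path would index them, or left at rotary_dim / 2 for a fused kernel such as apply_rotary_inplace:

use candle::{DType, Device, Result, Tensor};

// Hypothetical standalone helper; the repository's get_cos_sin may differ.
fn cos_sin_tables(max_pos: usize, rotary_dim: usize, theta: f32, repeat: bool) -> Result<(Tensor, Tensor)> {
    let dev = Device::Cpu;
    // inv_freq_i = theta^(-2i / rotary_dim), one entry per rotary pair.
    let inv_freqs: Vec<f32> = (0..rotary_dim / 2)
        .map(|i| 1.0 / theta.powf(2.0 * i as f32 / rotary_dim as f32))
        .collect();
    let inv_freqs = Tensor::from_vec(inv_freqs, rotary_dim / 2, &dev)?.unsqueeze(0)?; // (1, rotary_dim / 2)
    let positions = Tensor::arange(0u32, max_pos as u32, &dev)?
        .to_dtype(DType::F32)?
        .unsqueeze(1)?; // (max_pos, 1)
    let freqs = positions.matmul(&inv_freqs)?; // (max_pos, rotary_dim / 2)
    let freqs = if repeat {
        // Full-width tables, as an eager rotary implementation typically indexes them.
        Tensor::cat(&[&freqs, &freqs], 1)? // (max_pos, rotary_dim)
    } else {
        // Half-width tables, as a fused kernel that pairs dimensions itself typically wants.
        freqs
    };
    Ok((freqs.cos()?, freqs.sin()?))
}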
@@ -314,7 +318,6 @@ impl FlashModernBertModel {
             final_norm,
             pool,
             classifier,
-            rotary_dim,
             rotary_cache,
             device: vb.device().clone(),
             span: tracing::span!(tracing::Level::TRACE, "model"),
@@ -343,9 +346,6 @@ impl FlashModernBertModel {
             let cos = cos.index_select(&position_ids, 0)?;
             let sin = sin.index_select(&position_ids, 0)?;

Review comment: Not necessary with the fast rotary kernel.

-            let cos = cos.reshape((batch_size, 1, max_length, self.rotary_dim))?;
-            let sin = sin.reshape((batch_size, 1, max_length, self.rotary_dim))?;

             rotary_cache.insert(use_local_attention, (cos, sin));
         }
In the non-flash ModernBertModel:

@@ -578,7 +578,10 @@ impl ModernBertModel {
     }

     fn get_local_attention_mask(&self, attention_mask: &Tensor) -> Result<Tensor> {
-        let attention_mask = attention_mask.to_dtype(DType::U8)?;
+        let dev = attention_mask.device();
+        let attention_mask = attention_mask
+            .to_device(&Device::Cpu)?
+            .to_dtype(DType::U8)?;

Review comment: This is needed for the non-flash code path to work on the GPU (for instance when flash attention is unavailable): the window-mask construction cannot be done there directly, so the mask is built on the CPU and moved back to the original device afterwards.

         let mask_shape = attention_mask.shape();
         let (_, _, seq_len, _) = mask_shape.dims4()?;
@@ -597,6 +600,7 @@ impl ModernBertModel {

         let zero_tensor = Tensor::zeros_like(&attention_mask)?;
         let local_attention_mask = attention_mask.where_cond(&window_mask, &zero_tensor)?;
+        let local_attention_mask = local_attention_mask.to_device(dev)?;

         Ok(local_attention_mask)
     }
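For context, below is a self-contained sketch of the kind of banded window mask this method builds, constructed on the CPU and only moved to the target device at the end. Names and shapes are illustrative; the actual method also combines the result with the padding mask via where_cond, as shown above.

use candle::{DType, Device, Result, Tensor};

// Build a (seq_len, seq_len) mask that is 1 where |i - j| <= window / 2 and 0
// elsewhere, then move it to the requested device.
fn banded_window_mask(seq_len: usize, window: usize, target: &Device) -> Result<Tensor> {
    let cpu = Device::Cpu;
    let idx = Tensor::arange(0u32, seq_len as u32, &cpu)?.to_dtype(DType::F32)?;
    // Pairwise |i - j| distances via broadcasting: (seq_len, 1) - (1, seq_len).
    let dist = idx
        .unsqueeze(1)?
        .broadcast_sub(&idx.unsqueeze(0)?)?
        .abs()?;
    let limit = Tensor::full((window / 2) as f32, (seq_len, seq_len), &cpu)?;
    // Comparison ops return a U8 tensor of 0s and 1s.
    let mask = dist.le(&limit)?;
    mask.to_device(target)
}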
In the build script that generates the protobuf bindings:

@@ -1,7 +1,7 @@
 use std::fs;

 fn main() -> Result<(), Box<dyn std::error::Error>> {
-    println!("cargo:rerun-if-changed=../../proto/embed.proto");
+    println!("cargo:rerun-if-changed=../proto/embed.proto");

Review comment: The old path caused too many rebuilds because the file location was incorrect.

     fs::create_dir("src/pb").unwrap_or(());

     let mut config = prost_build::Config::new();
In the Nix dev shell configuration:

@@ -220,6 +220,8 @@
           LD_LIBRARY_PATH = "${pkgs.stdenv.cc.cc.lib}/lib:/run/opengl-driver/lib";
           LIBRARY_PATH = "${pkgs.stdenv.cc.cc.lib}/lib:/run/opengl-driver/lib";
           CUDA_ROOT = "${pkgs.cudaPackages.cudatoolkit}";
+          CANDLE_FLASH_ATTN_BUILD_DIR = "./kernels";
+          CANDLE_LAYER_NORM_BUILD_DIR = "./kernels";

Review comment (on lines +223 to +224): Prevents flash-attention and layer-norm kernel rebuilds in dev shells.

         };
       }
     );
Review discussion:

Reviewer: Do we want to communicate to the user why they are not running with FA enabled? Or is this already done?

Author: I took this code from the previous model. The logs do convey whether it's FlashModernBert or ModernBert.

Reviewer: Ok, that's good. Just not why it was not enabled, correct?

Reviewer: It's a nit, but how about additionally adding an error message or comment here saying that FlashModernBert does not support flash-attn-v1 due to the lack of the attention-windowing feature?

Author: Great comment, I'll change that around.