Fixing FlashAttention ModernBert. #560
Conversation
@kozistr I've implemented the flash attention version. I tried to keep the changes to a minimum here, but I think we can still vastly improve the various caching.
@@ -1,14 +1,15 @@
 use std::collections::HashMap;

 use crate::flash_attn::flash_attn_varlen;
-use crate::layers::{apply_rotary, get_cos_sin, get_inv_freqs, LayerNorm, Linear};
+use crate::layers::{get_cos_sin, get_inv_freqs, LayerNormNoBias, Linear};
This was already landed in the original PR; it's a good way to handle the fact that the layer norm never has a bias here.
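For reference, a minimal sketch of the math a bias-free layer norm performs (plain Rust, illustrative only, not this repo's LayerNormNoBias): the row is normalized and rescaled by the learned weight, and no bias term is added afterwards.

```rust
/// Illustrative only: normalize one row and scale by `weight`, with no bias term.
fn layer_norm_no_bias(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n;
    let inv_std = 1.0 / (var + eps).sqrt();
    x.iter()
        .zip(weight)
        .map(|(v, w)| (v - mean) * inv_std * w)
        .collect()
}
```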
@@ -79,35 +80,34 @@ impl ModernBertAttention {
 new_qkv_shape.pop();
 new_qkv_shape.push(self.num_attention_heads * 3);
 new_qkv_shape.push(self.attention_head_size);
-let qkv = qkv.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
+let qkv = qkv.reshape(new_qkv_shape.as_slice())?;
No transpose, no contiguous copy; taken from the other flash implementations.
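Roughly, the idea looks like the sketch below (not the exact PR code; the crate alias `candle` and the helper name are assumptions): with the flash layout the reshaped qkv stays [total_tokens, 3 * num_heads, head_size], so q, k and v are plain slices along dim 1 and nothing needs to be transposed or made contiguous before the varlen flash kernel.

```rust
use candle::{Result, Tensor};

/// Hypothetical helper: slice q, k, v out of a
/// [total_tokens, 3 * num_heads, head_size] qkv tensor along dim 1.
fn split_qkv(qkv: &Tensor, num_heads: usize) -> Result<(Tensor, Tensor, Tensor)> {
    let q = qkv.narrow(1, 0, num_heads)?;
    let k = qkv.narrow(1, num_heads, num_heads)?;
    let v = qkv.narrow(1, 2 * num_heads, num_heads)?;
    Ok((q, k, v))
}
```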
-let attention_size = if self.use_local_attention {
+let window_size = if self.use_local_attention {
Rename for clarity: in flash attention, this is called window_size.
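As a sketch of what the rename maps to (the helper name and the symmetric split are assumptions, not the PR's code): flash attention expresses a local window as a (left, right) pair of token counts around each query, so ModernBert's total local_attention span becomes half the span on each side.

```rust
/// Hypothetical helper: map ModernBert's total local-attention span to the
/// symmetric (left, right) window that flash attention expects.
fn to_flash_window(local_attention: usize) -> (usize, usize) {
    (local_attention / 2, local_attention / 2)
}
```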
    .or_else(|_| ModernBertEmbeddings::load(vb.pp("embeddings"), config))?;
let encoder = ModernBertEncoder::load(vb.pp("model.layers"), config)
    .or_else(|_| ModernBertEncoder::load(vb.pp("layers"), config))?;
let final_norm = LayerNormNoBias::load(
Handle the various weight prefixes, like the non-flash modeling code does.
} else {
    config.local_attention
};
let max_position_embeddings = config.max_position_embeddings;
This should model what was there before.
let inv_freqs = get_inv_freqs(rotary_dim, rope_theta as f32, vb.device(), None)?;

-let (cos, sin) = get_cos_sin(max_position_embeddings, &inv_freqs, vb.dtype(), true)?;
+let (cos, sin) = get_cos_sin(max_position_embeddings, &inv_freqs, vb.dtype(), false)?;
Very important: in flash models this flag is false, while it's true in non-flash models.
It took me quite a bit to spot this difference :(
@@ -343,9 +346,6 @@ impl FlashModernBertModel {
 let cos = cos.index_select(&position_ids, 0)?;
 let sin = sin.index_select(&position_ids, 0)?;
Not necessary with the fast rotary kernel.
@@ -578,7 +578,10 @@ impl ModernBertModel {
 }

 fn get_local_attention_mask(&self, attention_mask: &Tensor) -> Result<Tensor> {
-    let attention_mask = attention_mask.to_dtype(DType::U8)?;
+    let dev = attention_mask.device();
+    let attention_mask = attention_mask
This is needed for the non-flash code to work on GPU (for instance when flash attention isn't available). In that case we cannot use the .abs() operation (it's not available in candle here), so the mask is built on the CPU instead. That's acceptable for now, since GPU + non-flash should be fairly rare nowadays.
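A minimal sketch of the windowed-mask logic built on the CPU (illustrative only, assuming the `candle` crate alias used in this repo; not the PR's exact code). A pair (i, j) may attend iff |i - j| <= window / 2, written here without an abs tensor op:

```rust
use candle::{Device, Result, Tensor};

/// Illustrative only: build a [seq_len, seq_len] sliding-window mask on the CPU,
/// then upload it to `device`; 1 means position j is visible from position i.
fn local_attention_mask(seq_len: usize, window: usize, device: &Device) -> Result<Tensor> {
    let half = (window / 2) as i64;
    let mut mask = vec![0u8; seq_len * seq_len];
    for i in 0..seq_len {
        for j in 0..seq_len {
            let d = i as i64 - j as i64;
            // |i - j| <= half, written without abs
            if -half <= d && d <= half {
                mask[i * seq_len + j] = 1;
            }
        }
    }
    Tensor::from_vec(mask, (seq_len, seq_len), device)
}
```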
@@ -1,7 +1,7 @@
 use std::fs;

 fn main() -> Result<(), Box<dyn std::error::Error>> {
-    println!("cargo:rerun-if-changed=../../proto/embed.proto");
+    println!("cargo:rerun-if-changed=../proto/embed.proto");
This caused too many rebuilds because the file location was incorrect: when the watched path doesn't exist, Cargo reruns the build script on every build.
CANDLE_FLASH_ATTN_BUILD_DIR = "./kernels";
CANDLE_LAYER_NORM_BUILD_DIR = "./kernels";
Prevents flash-attention/layer-norm kernel rebuilds in dev shells.
I haven't tested these changes myself, but everything LGTM 👌
&& dtype == DType::F16
// Allow disabling because of flash attention v1 precision problems
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
Do we want to communicate to the user why they are not running with FA enabled?
Or is this already done?
I took this code from the previous model. The logs do convey whether it's FlashModernBert or ModernBert.
Ok, that's good. Just not why it was not enabled, correct?
It's a nit. How about additionally adding an error message or comment here saying that FlashModernBert does not support flash-attn-v1 due to the lack of the attention windowing feature?
Great comment, I'll change that around.
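Something along these lines, perhaps (a rough sketch only; the feature name is an assumption, and `tracing` is already used elsewhere in the repo):

```rust
// Hypothetical: explain why the flash path is skipped instead of silently
// falling back to the non-flash ModernBert.
if cfg!(feature = "flash-attn-v1") {
    tracing::warn!(
        "FlashModernBert is not supported with flash-attn v1 (no attention windowing support); \
         falling back to the ModernBert implementation."
    );
}
```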
Thanks for your work! I left some minor comments; everything else looks great to me :)
Co-authored-by: Hyeongchan Kim <[email protected]>
LGTM!
What does this PR do?
Fixes ModernBert to use flash attention when possible.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.