Fixing FlashAttention ModernBert. #560

Merged: 4 commits into main, Apr 5, 2025

Conversation

@Narsil (Collaborator) commented on Apr 4, 2025:

What does this PR do?

Fixes ModernBert to use flash attention when possible.

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Narsil (Collaborator, Author) commented on Apr 4, 2025:

@kozistr I've implemented the flash attention version.

I tried to keep things to a minimum here, but I think we can still vastly improve the various caching.
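
To make the caching remark concrete, below is a speculative sketch (not code from this PR; every name in it is invented for the example) of one kind of caching: memoizing rotary cos/sin tables per sequence length so they are not rebuilt on every forward pass.

use std::collections::HashMap;

struct RotaryCache {
    inv_freqs: Vec<f32>,
    // seq_len -> (cos, sin) tables, flattened row-major as [seq_len * half_dim].
    tables: HashMap<usize, (Vec<f32>, Vec<f32>)>,
}

impl RotaryCache {
    fn new(inv_freqs: Vec<f32>) -> Self {
        Self { inv_freqs, tables: HashMap::new() }
    }

    /// Returns the cached tables for `seq_len`, computing them only on first use.
    fn get(&mut self, seq_len: usize) -> &(Vec<f32>, Vec<f32>) {
        let inv_freqs = &self.inv_freqs;
        self.tables.entry(seq_len).or_insert_with(|| {
            let mut cos = Vec::with_capacity(seq_len * inv_freqs.len());
            let mut sin = Vec::with_capacity(seq_len * inv_freqs.len());
            for pos in 0..seq_len {
                for &f in inv_freqs {
                    let angle = pos as f32 * f;
                    cos.push(angle.cos());
                    sin.push(angle.sin());
                }
            }
            (cos, sin)
        })
    }
}

fn main() {
    let mut cache = RotaryCache::new(vec![1.0, 0.1, 0.01]);
    let first = cache.get(16).0.len();
    // The second call with the same length is served from the cache.
    let second = cache.get(16).0.len();
    println!("cos entries: {first} (cached lookup: {second})");
}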

@@ -1,14 +1,15 @@
use std::collections::HashMap;

use crate::flash_attn::flash_attn_varlen;
use crate::layers::{apply_rotary, get_cos_sin, get_inv_freqs, LayerNorm, Linear};
use crate::layers::{get_cos_sin, get_inv_freqs, LayerNormNoBias, Linear};
This already landed in the original PR; it's a good way to handle the fact that the layer norm never has a bias here.
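
For context, a minimal sketch of what a bias-free layer norm computes, assuming candle-core tensors; the function below is illustrative (simplified, no dtype up-casting) and is not the crate's actual LayerNormNoBias.

use candle_core::{DType, Device, Result, Tensor, D};

/// Bias-free layer norm: normalize over the last dim, then scale by `weight`.
fn layer_norm_no_bias(x: &Tensor, weight: &Tensor, eps: f64) -> Result<Tensor> {
    let mean = x.mean_keepdim(D::Minus1)?;
    let centered = x.broadcast_sub(&mean)?;
    let var = centered.sqr()?.mean_keepdim(D::Minus1)?;
    let std = (var + eps)?.sqrt()?;
    // Scale only: ModernBERT's norms carry a weight but no bias term.
    centered.broadcast_div(&std)?.broadcast_mul(weight)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Tensor::randn(0f32, 1f32, (2, 4, 8), &dev)?;
    let w = Tensor::ones(8, DType::F32, &dev)?;
    let y = layer_norm_no_bias(&x, &w, 1e-5)?;
    println!("normalized shape: {:?}", y.shape());
    Ok(())
}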

@@ -79,35 +80,34 @@ impl ModernBertAttention {
new_qkv_shape.pop();
new_qkv_shape.push(self.num_attention_heads * 3);
new_qkv_shape.push(self.attention_head_size);
let qkv = qkv.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
let qkv = qkv.reshape(new_qkv_shape.as_slice())?;
No transpose, no contiguous, taken from other flash implementations.
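
A rough sketch of the reshape-without-transpose idea using candle on CPU; it assumes the packed projection lays out all query heads, then key heads, then value heads, and the names and shapes are illustrative rather than the PR's exact code (the real model hands the resulting chunks to flash_attn_varlen on GPU).

use candle_core::{Device, Result, Tensor};

/// Split a packed [total_tokens, 3 * hidden] projection into q, k, v of shape
/// [total_tokens, num_heads, head_size] each.
fn split_qkv(qkv: &Tensor, num_heads: usize, head_size: usize) -> Result<(Tensor, Tensor, Tensor)> {
    let total_tokens = qkv.dim(0)?;
    // The reshape is a metadata-only change: no transpose, no .contiguous() copy.
    let qkv = qkv.reshape((total_tokens, 3 * num_heads, head_size))?;
    let chunks = qkv.chunk(3, 1)?; // q heads, then k heads, then v heads
    Ok((chunks[0].clone(), chunks[1].clone(), chunks[2].clone()))
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (total_tokens, num_heads, head_size) = (5, 4, 8);
    let qkv = Tensor::randn(0f32, 1f32, (total_tokens, 3 * num_heads * head_size), &dev)?;
    let (q, k, v) = split_qkv(&qkv, num_heads, head_size)?;
    // Each of q, k, v is [total_tokens, num_heads, head_size].
    println!("{:?} {:?} {:?}", q.shape(), k.shape(), v.shape());
    Ok(())
}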


let attention_size = if self.use_local_attention {
let window_size = if self.use_local_attention {
Renamed for clarity: in flash attention, this is called window_size.
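
Purely to illustrate the naming, a sketch of deriving a flash-style window size per layer; the config fields and the symmetric left/right split are assumptions based on how ModernBERT alternates local and global attention, not the PR's code.

// All names below are illustrative; `local_attention` is the total sliding-window
// span and a global-attention layer occurs every `global_attn_every_n_layers`.
struct Config {
    local_attention: usize,
    global_attn_every_n_layers: usize,
}

/// `None` means a global-attention layer; `Some((left, right))` is the
/// flash-attention style window for local layers.
fn window_size(cfg: &Config, layer_index: usize) -> Option<(usize, usize)> {
    let use_local_attention = layer_index % cfg.global_attn_every_n_layers != 0;
    if use_local_attention {
        Some((cfg.local_attention / 2, cfg.local_attention / 2))
    } else {
        None
    }
}

fn main() {
    let cfg = Config { local_attention: 128, global_attn_every_n_layers: 3 };
    for layer in 0..6 {
        println!("layer {layer}: window = {:?}", window_size(&cfg, layer));
    }
}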

.or_else(|_| ModernBertEmbeddings::load(vb.pp("embeddings"), config))?;
let encoder = ModernBertEncoder::load(vb.pp("model.layers"), config)
.or_else(|_| ModernBertEncoder::load(vb.pp("layers"), config))?;
let final_norm = LayerNormNoBias::load(
Handle the various weight prefixes, like the non-flash modeling code does.

} else {
config.local_attention
};
let max_position_embeddings = config.max_position_embeddings;
This should mirror what was there before.


let inv_freqs = get_inv_freqs(rotary_dim, rope_theta as f32, vb.device(), None)?;

let (cos, sin) = get_cos_sin(max_position_embeddings, &inv_freqs, vb.dtype(), true)?;
let (cos, sin) = get_cos_sin(max_position_embeddings, &inv_freqs, vb.dtype(), false)?;
Very important: in flash models this is false, while it is true in non-flash models.

It took me quite a while to spot this difference :(
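
A toy, pure-Rust sketch of why the flag matters, assuming the boolean controls whether the half-width cos/sin tables are expanded to the full head dimension (the real get_cos_sin lives in the crate's layers module and may differ in detail).

fn cos_sin_tables(seq_len: usize, inv_freqs: &[f32], repeat: bool) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let mut cos = Vec::with_capacity(seq_len);
    let mut sin = Vec::with_capacity(seq_len);
    for pos in 0..seq_len {
        let (mut c_row, mut s_row) = (Vec::new(), Vec::new());
        for &f in inv_freqs {
            let angle = pos as f32 * f;
            // With `repeat`, each frequency is written twice so the row spans the
            // full head_dim; without it, the row stays at head_dim / 2.
            let reps = if repeat { 2 } else { 1 };
            for _ in 0..reps {
                c_row.push(angle.cos());
                s_row.push(angle.sin());
            }
        }
        cos.push(c_row);
        sin.push(s_row);
    }
    (cos, sin)
}

fn main() {
    let inv_freqs = [1.0_f32, 0.1, 0.01, 0.001]; // head_dim = 8 in this toy example
    let (flash, _) = cos_sin_tables(4, &inv_freqs, false); // flash path: half width
    let (eager, _) = cos_sin_tables(4, &inv_freqs, true); // non-flash path: full width
    println!("flash width = {}, eager width = {}", flash[0].len(), eager[0].len());
}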

@@ -343,9 +346,6 @@ impl FlashModernBertModel {
let cos = cos.index_select(&position_ids, 0)?;
let sin = sin.index_select(&position_ids, 0)?;

Not necessary with the fast rotary kernel.

@@ -578,7 +578,10 @@ impl ModernBertModel {
}

fn get_local_attention_mask(&self, attention_mask: &Tensor) -> Result<Tensor> {
let attention_mask = attention_mask.to_dtype(DType::U8)?;
let dev = attention_mask.device();
let attention_mask = attention_mask
This is needed for the non-flash code to work on GPU (when flash attention isn't available, for instance).

We then cannot use the .abs() operation (it's not in candle), so the mask is built on CPU instead. That's acceptable for now, since GPU + non-flash should be fairly rare nowadays.
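
A minimal sketch of the CPU-side band mask, assuming a symmetric window and using two signed comparisons instead of a tensor .abs(); the function name and the 1-means-attend convention are illustrative.

fn local_attention_mask(seq_len: usize, window: usize) -> Vec<Vec<u8>> {
    let half = (window / 2) as i64;
    (0..seq_len as i64)
        .map(|i| {
            (0..seq_len as i64)
                // |i - j| <= half, written as two comparisons so no abs() is needed.
                .map(|j| u8::from(i - j <= half && j - i <= half))
                .collect()
        })
        .collect()
}

fn main() {
    // 1 = may attend, 0 = masked out; the band would then be moved to the
    // device and combined with the padding mask.
    for row in local_attention_mask(6, 2) {
        println!("{:?}", row);
    }
}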

@@ -1,7 +1,7 @@
use std::fs;

fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("cargo:rerun-if-changed=../../proto/embed.proto");
println!("cargo:rerun-if-changed=../proto/embed.proto");
This causes too many rebuilds because the file location is incorrect.
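
For reference, cargo treats a rerun-if-changed path that does not exist as always changed, so the build script reruns on every build, which matches the rebuild problem described above. A minimal build.rs sketch follows; the path simply mirrors the second line in the hunk above, and whether that relative location is right depends on where the crate sits in the workspace.

// build.rs sketch: the path is relative to this crate's root and must point at
// a file that actually exists, otherwise every `cargo build` recompiles the crate.
fn main() {
    println!("cargo:rerun-if-changed=../proto/embed.proto");
}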

Comment on lines +223 to +224
CANDLE_FLASH_ATTN_BUILD_DIR = "./kernels";
CANDLE_LAYER_NORM_BUILD_DIR = "./kernels";
Prevents flash-attention / layer-norm kernel rebuilds in dev shells.

@ivarflakstad (Member) previously approved these changes on Apr 4, 2025 and left a comment:
I haven't tested these changes myself, but everything LGTM 👌

&& dtype == DType::F16
// Allow disabling because of flash attention v1 precision problems
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"

@ivarflakstad (Member) commented on Apr 4, 2025:
Do we want to communicate to the user why they are not running with FA enabled?
Or is this already done?

@Narsil (Collaborator, Author):
I took this code from the previous model. The logs do convey whether it's FlashModernBert or ModernBert.

@ivarflakstad (Member):
Ok, that's good. They just don't say why it was not enabled, correct?

@kozistr (Contributor):
It's a nit, but how about additionally adding an error message or comment here noting that FlashModernBert does not support flash-attn v1 due to the lack of the attention windowing feature?

@Narsil (Collaborator, Author):
Great comment, I'll change that around.
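
A hypothetical sketch of the kind of message being discussed; the function, its input, and the wording are invented for illustration, and the real check lives in the model-selection code quoted above.

fn check_flash_support(is_flash_v1: bool) -> Result<(), String> {
    if is_flash_v1 {
        // ModernBERT needs windowed (local) attention, which flash-attn v1 lacks.
        return Err(
            "FlashModernBert requires flash-attention v2: v1 does not support attention windowing"
                .to_string(),
        );
    }
    Ok(())
}

fn main() {
    match check_flash_support(true) {
        Ok(()) => println!("flash attention enabled"),
        Err(reason) => eprintln!("falling back to the non-flash ModernBert: {reason}"),
    }
}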

@kozistr (Contributor) left a comment:
Thanks for your work! I left some minor comments; everything else looks great to me :)

@kozistr (Contributor) left a comment:
LGTM!

@Narsil merged commit f99ce07 into main on Apr 5, 2025; 14 checks passed.
@Narsil deleted the fix_modernbert_flash branch on Apr 5, 2025 at 09:23.