Commit 63c4224

fix: rotary embedding
1 parent 3b20211 commit 63c4224

File tree

1 file changed (+11, -18 lines)

backends/candle/src/models/modernbert.rs

Lines changed: 11 additions & 18 deletions
```diff
@@ -4,7 +4,7 @@ use crate::layers::{
     apply_rotary, get_cos_sin, get_cublas_lt_wrapper, get_inv_freqs, HiddenAct, LayerNorm, Linear,
 };
 use crate::models::Model;
-use candle::{DType, Device, IndexOp, Module, Result, Shape, Tensor, D};
+use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{Embedding, VarBuilder};
 use serde::Deserialize;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};
@@ -454,7 +454,7 @@ pub struct ModernBertModel {
 
     local_attention: usize,
     rotary_dim: usize,
-    rotary_cache: HashMap<bool, (Tensor, Tensor)>,
+    inv_freqs_cache: HashMap<bool, Tensor>,
     pad_token_id: u32,
     num_attention_heads: usize,
 
@@ -506,7 +506,7 @@ impl ModernBertModel {
         })?;
 
         let rotary_dim = config.hidden_size / config.num_attention_heads;
-        let mut rotary_cache: HashMap<bool, (Tensor, Tensor)> = HashMap::new();
+        let mut inv_freqs_cache: HashMap<bool, Tensor> = HashMap::new();
 
         for use_local_attention in [true, false] {
             let rope_theta = if use_local_attention {
@@ -515,17 +515,9 @@ impl ModernBertModel {
                 config.global_rope_theta
             };
 
-            let max_position_embeddings = if use_local_attention {
-                config.max_position_embeddings
-            } else {
-                config.local_attention
-            };
-
             let inv_freqs = get_inv_freqs(rotary_dim, rope_theta as f32, vb.device(), None)?;
 
-            let (cos, sin) = get_cos_sin(max_position_embeddings, &inv_freqs, vb.dtype(), true)?;
-
-            rotary_cache.insert(use_local_attention, (cos, sin));
+            inv_freqs_cache.insert(use_local_attention, inv_freqs);
         }
 
         Ok(Self {
@@ -536,7 +528,7 @@ impl ModernBertModel {
             classifier,
             local_attention: config.local_attention,
             rotary_dim,
-            rotary_cache,
+            inv_freqs_cache,
             pad_token_id: config.pad_token_id as u32,
             num_attention_heads: config.num_attention_heads,
             device: vb.device().clone(),
@@ -548,18 +540,18 @@ impl ModernBertModel {
     fn get_global_attention_mask(
         &self,
         attention_mask: Option<&Tensor>,
-        input_shape: &Shape,
+        input_shape: &(usize, usize),
     ) -> Result<Tensor> {
         let extended_attention_mask = if let Some(attention_mask) = attention_mask {
             attention_mask.squeeze(2)?
         } else {
-            Tensor::ones(input_shape, DType::F32, &self.device)?
+            Tensor::ones(*input_shape, DType::F32, &self.device)?
         }
         .unsqueeze(1)?
         .unsqueeze(1)?
         .to_dtype(self.dtype)?;
 
-        let (bs, seq_len) = input_shape.dims2()?;
+        let (bs, seq_len) = *input_shape;
         let extended_attention_mask = extended_attention_mask.broadcast_as((
             bs,
             self.num_attention_heads,
@@ -664,7 +656,7 @@ impl ModernBertModel {
             Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?;
 
         let global_attention_mask = self
-            .get_global_attention_mask(attention_mask.as_ref(), input_ids.shape())?
+            .get_global_attention_mask(attention_mask.as_ref(), &shape)?
             .to_dtype(self.dtype)?;
         let silding_attention_mask = self
             .get_silding_window_mask(&global_attention_mask)?
@@ -680,7 +672,8 @@ impl ModernBertModel {
 
         let mut rotary_cache: HashMap<bool, (Tensor, Tensor)> = HashMap::new();
         for use_local_attention in [true, false] {
-            let (cos, sin) = &self.rotary_cache[&use_local_attention];
+            let inv_freq = &self.inv_freqs_cache[&use_local_attention];
+            let (cos, sin) = get_cos_sin(max_length, inv_freq, self.dtype, true)?;
 
             let cos = cos.index_select(&position_ids, 0)?;
             let sin = sin.index_select(&position_ids, 0)?;
```
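The shape of the change: only the inverse frequencies are cached per attention type (local vs. global), and the cos/sin tables are rebuilt at forward time for the batch's `max_length` instead of being precomputed for a fixed length. As a rough illustration of what helpers such as `get_inv_freqs` and `get_cos_sin` compute, here is a minimal standalone sketch of the standard RoPE tables in candle. The function names, signatures, and `main` below are assumptions for the example, not the crate's own code; the real `get_cos_sin` may additionally tile the half-dimension tables up to the full rotary dimension.

```rust
// Minimal RoPE-table sketch (assumed helper names, not the crate's actual code).
use candle::{DType, Device, Result, Tensor};

/// One inverse frequency per rotary pair: inv_freq[i] = 1 / theta^(2i / rotary_dim).
fn inv_freqs(rotary_dim: usize, rope_theta: f32, device: &Device) -> Result<Tensor> {
    let freqs: Vec<f32> = (0..rotary_dim)
        .step_by(2)
        .map(|i| 1.0 / rope_theta.powf(i as f32 / rotary_dim as f32))
        .collect();
    let n = freqs.len();
    // Shape (1, rotary_dim / 2) so it can be used directly in the outer product below.
    Tensor::from_vec(freqs, (1, n), device)
}

/// Cos/sin lookup tables for positions 0..length, rebuilt per batch from the cached inv_freqs.
fn cos_sin(length: usize, inv_freqs: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
    let positions = Tensor::arange(0f32, length as f32, inv_freqs.device())?
        .reshape((length, 1))?;
    // Outer product: (length, 1) x (1, rotary_dim / 2) -> (length, rotary_dim / 2).
    let freqs = positions.matmul(inv_freqs)?;
    Ok((freqs.cos()?.to_dtype(dtype)?, freqs.sin()?.to_dtype(dtype)?))
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Example dimensions in the spirit of the model: rotary_dim = hidden_size / num_heads.
    let inv = inv_freqs(64, 10_000.0, &device)?;
    let (cos, sin) = cos_sin(128, &inv, DType::F32)?;
    println!("cos: {:?}, sin: {:?}", cos.shape(), sin.shape());
    Ok(())
}
```

With that split, a `HashMap<bool, Tensor>` of inverse frequencies is cheap to keep on the model, while the per-batch cos/sin tables only ever cover the positions actually present in the batch.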
