
Commit 43f2322 (1 parent: ceccbca)

rollback: rotary embedding
File tree: 3 files changed, +105 -101 lines

  backends/candle/src/layers/mod.rs
  backends/candle/src/layers/rotary.rs
  backends/candle/src/models/modernbert.rs

backends/candle/src/layers/mod.rs (1 addition, 1 deletion)

@@ -11,4 +11,4 @@ pub use layer_norm::{LayerNorm, LayerNormNoBias};
 pub use linear::{HiddenAct, Linear};
 #[allow(unused_imports)]
 pub use rms_norm::RMSNorm;
-pub use rotary::{apply_rotary, get_cos_sin, get_inv_freqs, RopeScaling, RotaryEmbedding};
+pub use rotary::{apply_rotary, get_cos_sin, get_inv_freqs, RopeScaling};

backends/candle/src/layers/rotary.rs (0 additions, 40 deletions)

@@ -1,46 +1,6 @@
 use candle::{DType, Device, Result, Tensor, D};
-use candle_nn::rotary_emb::rope;
 use serde::Deserialize;

-#[derive(Debug, Clone)]
-pub struct RotaryEmbedding {
-    pub cos: Tensor,
-    pub sin: Tensor,
-}
-
-impl RotaryEmbedding {
-    pub fn new(
-        dtype: DType,
-        dim: usize,
-        max_seq_len: usize,
-        rope_theta: f64,
-        device: &Device,
-    ) -> Result<Self> {
-        let inv_freq: Vec<_> = (0..dim)
-            .step_by(2)
-            .map(|i| 1f32 / rope_theta.powf(i as f64 / dim as f64) as f32)
-            .collect();
-        let inv_freq_len = inv_freq.len();
-        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?.to_dtype(dtype)?;
-
-        let t = Tensor::arange(0u32, max_seq_len as u32, device)?
-            .to_dtype(dtype)?
-            .reshape((max_seq_len, 1))?;
-        let freqs = t.matmul(&inv_freq)?;
-
-        Ok(Self {
-            sin: freqs.sin()?,
-            cos: freqs.cos()?,
-        })
-    }
-
-    pub fn apply_rotary_emb_qk(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
-        let q_embed = rope(&q.contiguous()?, &self.cos, &self.sin)?;
-        let k_embed = rope(&k.contiguous()?, &self.cos, &self.sin)?;
-        Ok((q_embed, k_embed))
-    }
-}
-
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct NTKScaling {
     pub factor: f32,
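Note: this rollback drops the RotaryEmbedding wrapper in favor of the free functions the module already exports (get_inv_freqs, get_cos_sin, apply_rotary). For reference, the cos/sin cache that the deleted RotaryEmbedding::new precomputed can be sketched in plain Rust with no candle dependency; this is illustrative only, and the names below are not part of the crate:

    /// Illustrative only: build RoPE inverse frequencies and a (seq_len x dim/2)
    /// cos/sin table the way the deleted RotaryEmbedding::new did, using plain Vecs.
    fn rope_cache(dim: usize, max_seq_len: usize, theta: f64) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
        // inv_freq[i] = 1 / theta^(2i / dim), one entry per pair of channels
        let inv_freq: Vec<f32> = (0..dim)
            .step_by(2)
            .map(|i| 1.0 / theta.powf(i as f64 / dim as f64) as f32)
            .collect();

        // cos[t][i] = cos(t * inv_freq[i]), likewise for sin
        let mut cos = Vec::with_capacity(max_seq_len);
        let mut sin = Vec::with_capacity(max_seq_len);
        for t in 0..max_seq_len {
            cos.push(inv_freq.iter().map(|f| (t as f32 * f).cos()).collect());
            sin.push(inv_freq.iter().map(|f| (t as f32 * f).sin()).collect());
        }
        (cos, sin)
    }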

backends/candle/src/models/modernbert.rs (104 additions, 60 deletions)

@@ -1,4 +1,7 @@
-use crate::layers::{get_cublas_lt_wrapper, HiddenAct, LayerNormNoBias, Linear, RotaryEmbedding};
+use crate::layers::{
+    apply_rotary, get_cos_sin, get_cublas_lt_wrapper, get_inv_freqs, HiddenAct, LayerNormNoBias,
+    Linear,
+};
 use crate::models::Model;
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{Embedding, VarBuilder};
@@ -182,7 +185,7 @@ impl ModernBertAttention {
         &self,
         hidden_states: &Tensor,
         attention_mask: &Tensor,
-        rotary_embed: &RotaryEmbedding,
+        rotary_cache: &(Tensor, Tensor),
     ) -> Result<Tensor> {
         let _enter = self.span.enter();
         let device = hidden_states.device();

@@ -200,7 +203,18 @@ impl ModernBertAttention {
         let key_layer = &qkv[1].contiguous()?;
         let value_layer = &qkv[2];

-        let (query_layer, key_layer) = rotary_embed.apply_rotary_emb_qk(query_layer, key_layer)?;
+        let query_layer = apply_rotary(
+            query_layer,
+            &rotary_cache.0,
+            &rotary_cache.1,
+            self.attention_head_size,
+        )?;
+        let key_layer = apply_rotary(
+            key_layer,
+            &rotary_cache.0,
+            &rotary_cache.1,
+            self.attention_head_size,
+        )?;

         #[allow(unused_variables)]
         let context_layer =
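The attention forward now receives the precomputed (cos, sin) cache and applies it to the query and key projections via apply_rotary together with the head size. The diff does not show apply_rotary's body, so purely as a point of reference, here is the common rotate-half form of rotary application over a single head vector in plain Rust; this is an assumed formulation, not necessarily the crate's exact layout:

    /// Illustrative rotate-half RoPE for one head vector `x` of length `dim`
    /// at a single position, given that position's cos/sin rows of length dim/2.
    fn rotate_half(x: &[f32], cos: &[f32], sin: &[f32]) -> Vec<f32> {
        let half = x.len() / 2;
        let mut out = vec![0.0; x.len()];
        for i in 0..half {
            // Pair channel i with channel i + dim/2 and rotate by the position angle.
            out[i] = x[i] * cos[i] - x[i + half] * sin[i];
            out[i + half] = x[i + half] * cos[i] + x[i] * sin[i];
        }
        out
    }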
@@ -305,7 +319,7 @@ impl ModernBertEncoderLayer {
         &self,
         hidden_states: &Tensor,
         attention_mask: &Tensor,
-        rotary_embed: &RotaryEmbedding,
+        rotary_cache: &(Tensor, Tensor),
     ) -> Result<Tensor> {
         let _enter = self.span.enter();

@@ -319,7 +333,7 @@ impl ModernBertEncoderLayer {

         let attn_outputs = self
             .attn
-            .forward(&attn_norm, attention_mask, rotary_embed)?;
+            .forward(&attn_norm, attention_mask, rotary_cache)?;

         let hidden_states = residual.add(&attn_outputs)?;

@@ -361,8 +375,8 @@ impl ModernBertEncoder {
         hidden_states: &Tensor,
         global_attention_mask: &Tensor,
         local_attention_mask: &Tensor,
-        global_rotaray_emb: &RotaryEmbedding,
-        local_rotaray_emb: &RotaryEmbedding,
+        global_rotaray_cache: &(Tensor, Tensor),
+        local_rotaray_cache: &(Tensor, Tensor),
     ) -> Result<Tensor> {
         let _enter = self.span.enter();

@@ -371,16 +385,13 @@ impl ModernBertEncoder {
         for (index, layer) in self.layers.iter().enumerate() {
             let use_local_attention = index % self.global_attn_every_n_layers != 0;

-            let (attention_mask, rotary_embed) = if use_local_attention {
-                (
-                    &global_attention_mask.broadcast_add(local_attention_mask)?,
-                    local_rotaray_emb,
-                )
+            let (attention_mask, rotary_cache) = if use_local_attention {
+                (local_attention_mask, local_rotaray_cache)
             } else {
-                (global_attention_mask, global_rotaray_emb)
+                (global_attention_mask, global_rotaray_cache)
             };

-            hidden_states = layer.forward(&hidden_states, attention_mask, rotary_embed)?;
+            hidden_states = layer.forward(&hidden_states, attention_mask, rotary_cache)?;
         }

         Ok(hidden_states)
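The encoder loop keeps the `index % global_attn_every_n_layers != 0` rule for choosing local versus global attention; with this rollback the local layers receive the local mask directly rather than `global_attention_mask.broadcast_add(local_attention_mask)`. A tiny standalone sketch of the selection rule, with hypothetical values (6 layers, a global layer every 3rd layer):

    /// Illustrative: which layers use global attention under
    /// `index % global_attn_every_n_layers != 0`. Concrete values are hypothetical.
    fn main() {
        let num_layers = 6;
        let global_attn_every_n_layers = 3;
        for index in 0..num_layers {
            let use_local_attention = index % global_attn_every_n_layers != 0;
            let kind = if use_local_attention { "local (sliding window)" } else { "global" };
            println!("layer {index}: {kind}");
        }
        // Prints: layer 0 global, layers 1-2 local, layer 3 global, layers 4-5 local.
    }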
@@ -456,8 +467,9 @@ pub struct ModernBertModel {
     classifier: Option<Box<dyn ClassificationHead + Send>>,

     local_attention: usize,
-    global_rotary_emb: RotaryEmbedding,
-    local_rotary_emb: RotaryEmbedding,
+    global_inv_freqs: Tensor,
+    local_inv_freqs: Tensor,
+    rotary_dim: usize,
     pad_token_id: u32,
     num_attention_heads: usize,

@@ -508,21 +520,19 @@ impl ModernBertModel {
             )
         })?;

-        let rotary_dim = config.hidden_size / config.num_attention_heads;
+        let attention_head_size = config.hidden_size / config.num_attention_heads;

-        let global_rotary_emb = RotaryEmbedding::new(
-            vb.dtype(),
-            rotary_dim,
-            config.max_position_embeddings,
-            config.global_rope_theta,
+        let global_inv_freqs = get_inv_freqs(
+            attention_head_size,
+            config.global_rope_theta as f32,
             vb.device(),
+            None,
         )?;
-        let local_rotary_emb = RotaryEmbedding::new(
-            vb.dtype(),
-            rotary_dim,
-            config.max_position_embeddings,
-            config.local_rope_theta,
+        let local_inv_freqs = get_inv_freqs(
+            attention_head_size,
+            config.local_rope_theta as f32,
             vb.device(),
+            None,
         )?;

         Ok(Self {
@@ -532,8 +542,9 @@ impl ModernBertModel {
             pool,
             classifier,
             local_attention: config.local_attention,
-            global_rotary_emb,
-            local_rotary_emb,
+            global_inv_freqs,
+            local_inv_freqs,
+            rotary_dim: attention_head_size,
             pad_token_id: config.pad_token_id as u32,
             num_attention_heads: config.num_attention_heads,
             device: vb.device().clone(),
@@ -563,37 +574,31 @@ impl ModernBertModel {
             seq_len,
         ))?;

-        let min_value = match self.dtype {
-            DType::F32 => f32::MIN as f64,
-            _ => -65504.0_f64, // f16 minumum value
-        };
+        Ok(extended_attention_mask)
+    }

-        let inverted_mask = ((1.0 - extended_attention_mask)? * min_value)?;
+    fn get_local_attention_mask(&self, attention_mask: &Tensor) -> Result<Tensor> {
+        let attention_mask = attention_mask.to_dtype(DType::U8)?;

-        inverted_mask.to_dtype(self.dtype)
-    }
+        let mask_shape = attention_mask.shape();
+        let (_, _, seq_len, _) = mask_shape.dims4()?;

-    fn get_local_attention_mask(&self, seq_len: usize) -> Result<Tensor> {
-        let window_size: usize = self.local_attention / 2;
+        let rows = Tensor::arange(0, seq_len as i64, attention_mask.device())?.unsqueeze(0)?;
+        let rows = rows.broadcast_as((seq_len, seq_len))?;

-        let min_value = match self.dtype {
-            DType::F32 => f32::MIN as f64,
-            _ => -65504.0_f64, // f16 minumum value
-        };
+        let distance = (&rows - &rows.t()?)?.abs()?;

-        let mask: Vec<_> = (0..seq_len)
-            .flat_map(|i| {
-                (0..seq_len).map(move |j| {
-                    if (j as i32 - i as i32).abs() > window_size as i32 {
-                        min_value
-                    } else {
-                        0.
-                    }
-                })
-            })
-            .collect();
+        let window_size = (self.local_attention / 2) as i64;
+        let window_mask = distance
+            .le(window_size)?
+            .unsqueeze(0)?
+            .unsqueeze(0)?
+            .broadcast_as(mask_shape)?;
+
+        let zero_tensor = Tensor::zeros_like(&attention_mask)?;
+        let local_attention_mask = attention_mask.where_cond(&window_mask, &zero_tensor)?;

-        Tensor::from_slice(&mask, (seq_len, seq_len), &self.device)?.to_dtype(self.dtype)
+        Ok(local_attention_mask)
     }

     fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
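get_local_attention_mask now derives the sliding-window mask from the already-built attention mask: key positions farther than `local_attention / 2` from the query index are zeroed, and the additive -inf style inversion happens later in forward. The same restriction written over plain Vecs, for illustration only (not the crate's tensor implementation):

    /// Illustrative only: the sliding-window restriction that the rolled-back
    /// get_local_attention_mask builds with tensor ops, over plain Vecs.
    /// `attention_mask[i][j]` is 1 for valid key positions; entries farther than
    /// `local_attention / 2` from the query position are zeroed.
    fn local_attention_mask(attention_mask: &[Vec<u8>], local_attention: usize) -> Vec<Vec<u8>> {
        let window = (local_attention / 2) as i64;
        attention_mask
            .iter()
            .enumerate()
            .map(|(i, row)| {
                row.iter()
                    .enumerate()
                    .map(|(j, &keep)| {
                        let within_window = (i as i64 - j as i64).abs() <= window;
                        if within_window { keep } else { 0 }
                    })
                    .collect()
            })
            .collect()
    }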
@@ -604,7 +609,7 @@ impl ModernBertModel {

         let shape = (batch_size, max_length);

-        let (input_ids, input_lengths, _, attention_mask) = if batch_size > 1 {
+        let (input_ids, input_lengths, position_ids, attention_mask) = if batch_size > 1 {
             let elems = batch_size * max_length;

             let mut input_ids = Vec::with_capacity(elems);
@@ -662,20 +667,59 @@ impl ModernBertModel {
         };

         let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
+        let position_ids = Tensor::from_vec(position_ids, batch_size * max_length, &self.device)?;
         let mut input_lengths =
             Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?;

-        let global_attention_mask =
-            self.get_global_attention_mask(attention_mask.as_ref(), &shape)?;
-        let local_attention_mask = self.get_local_attention_mask(max_length)?;
+        let global_attention_mask = self
+            .get_global_attention_mask(attention_mask.as_ref(), &shape)?
+            .to_dtype(self.dtype)?;
+        let local_attention_mask = self
+            .get_local_attention_mask(&global_attention_mask)?
+            .to_dtype(self.dtype)?;
+
+        let min_value = match self.dtype {
+            DType::F32 => f32::MIN as f64,
+            _ => -65504.0, // f16 minimum value
+        };
+
+        let global_attention_mask = ((1.0 - global_attention_mask)? * min_value)?;
+        let local_attention_mask = ((1.0 - local_attention_mask)? * min_value)?;
+
+        let global_rotary_cache =
+            get_cos_sin(max_length, &self.global_inv_freqs, self.dtype, true)?;
+        let local_rotary_cache = get_cos_sin(max_length, &self.local_inv_freqs, self.dtype, true)?;
+
+        let global_rotary_cache = (
+            global_rotary_cache
+                .0
+                .index_select(&position_ids, 0)?
+                .reshape((batch_size, 1, max_length, self.rotary_dim))?,
+            global_rotary_cache
+                .1
+                .index_select(&position_ids, 0)?
+                .reshape((batch_size, 1, max_length, self.rotary_dim))?,
+        );
+
+        let local_rotary_cache = (
+            local_rotary_cache
+                .0
+                .index_select(&position_ids, 0)?
+                .reshape((batch_size, 1, max_length, self.rotary_dim))?,
+            local_rotary_cache
+                .1
+                .index_select(&position_ids, 0)?
+                .reshape((batch_size, 1, max_length, self.rotary_dim))?,
+        );

         let hidden_states = self.embeddings.forward(&input_ids)?;
+
         let hidden_states = self.encoder.forward(
             &hidden_states,
             &global_attention_mask,
             &local_attention_mask,
-            &self.global_rotary_emb,
-            &self.local_rotary_emb,
+            &global_rotary_cache,
+            &local_rotary_cache,
         )?;
         let outputs = self.final_norm.forward(&hidden_states, None)?;
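In the forward pass above, the rolled-back code builds cos/sin tables of length max_length with get_cos_sin, gathers one row per token via index_select over the flattened position_ids, and reshapes the result to (batch_size, 1, max_length, rotary_dim). A minimal sketch of that gather step on a plain table; the names and shapes here are hypothetical, not the crate's API:

    /// Illustrative only: gather per-token cos (or sin) rows from a
    /// (max_length x rotary_dim) table using flattened position ids, mirroring
    /// the index_select done before reshaping to (batch, 1, max_length, rotary_dim).
    fn gather_rows(table: &[Vec<f32>], position_ids: &[usize]) -> Vec<Vec<f32>> {
        position_ids.iter().map(|&p| table[p].clone()).collect()
    }

    fn main() {
        // Hypothetical shapes: 2 sequences of length 3, rotary_dim = 2.
        let table = vec![vec![1.0, 1.0], vec![0.9, 0.8], vec![0.5, 0.3]];
        let position_ids = vec![0, 1, 2, 0, 1, 2];
        let gathered = gather_rows(&table, &position_ids);
        assert_eq!(gathered.len(), 6); // batch_size * max_length rows
    }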