@@ -4,7 +4,7 @@ use crate::layers::{
     apply_rotary, get_cos_sin, get_cublas_lt_wrapper, get_inv_freqs, HiddenAct, LayerNorm, Linear,
 };
 use crate::models::Model;
-use candle::{DType, Device, IndexOp, Module, Result, Shape, Tensor, D};
+use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{Embedding, VarBuilder};
 use serde::Deserialize;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};
@@ -454,7 +454,7 @@ pub struct ModernBertModel {
 
     local_attention: usize,
     rotary_dim: usize,
-    rotary_cache: HashMap<bool, (Tensor, Tensor)>,
+    inv_freqs_cache: HashMap<bool, Tensor>,
     pad_token_id: u32,
     num_attention_heads: usize,
 
@@ -506,7 +506,7 @@ impl ModernBertModel {
         })?;
 
         let rotary_dim = config.hidden_size / config.num_attention_heads;
-        let mut rotary_cache: HashMap<bool, (Tensor, Tensor)> = HashMap::new();
+        let mut inv_freqs_cache: HashMap<bool, Tensor> = HashMap::new();
 
         for use_local_attention in [true, false] {
             let rope_theta = if use_local_attention {
@@ -515,17 +515,9 @@ impl ModernBertModel {
                 config.global_rope_theta
             };
 
-            let max_position_embeddings = if use_local_attention {
-                config.max_position_embeddings
-            } else {
-                config.local_attention
-            };
-
             let inv_freqs = get_inv_freqs(rotary_dim, rope_theta as f32, vb.device(), None)?;
 
-            let (cos, sin) = get_cos_sin(max_position_embeddings, &inv_freqs, vb.dtype(), true)?;
-
-            rotary_cache.insert(use_local_attention, (cos, sin));
+            inv_freqs_cache.insert(use_local_attention, inv_freqs);
         }
 
         Ok(Self {
@@ -536,7 +528,7 @@ impl ModernBertModel {
             classifier,
             local_attention: config.local_attention,
             rotary_dim,
-            rotary_cache,
+            inv_freqs_cache,
             pad_token_id: config.pad_token_id as u32,
             num_attention_heads: config.num_attention_heads,
             device: vb.device().clone(),
@@ -548,18 +540,18 @@ impl ModernBertModel {
     fn get_global_attention_mask(
         &self,
         attention_mask: Option<&Tensor>,
-        input_shape: &Shape,
+        input_shape: &(usize, usize),
     ) -> Result<Tensor> {
         let extended_attention_mask = if let Some(attention_mask) = attention_mask {
             attention_mask.squeeze(2)?
         } else {
-            Tensor::ones(input_shape, DType::F32, &self.device)?
+            Tensor::ones(*input_shape, DType::F32, &self.device)?
         }
         .unsqueeze(1)?
         .unsqueeze(1)?
         .to_dtype(self.dtype)?;
 
-        let (bs, seq_len) = input_shape.dims2()?;
+        let (bs, seq_len) = *input_shape;
         let extended_attention_mask = extended_attention_mask.broadcast_as((
             bs,
             self.num_attention_heads,
@@ -664,7 +656,7 @@ impl ModernBertModel {
             Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?;
 
         let global_attention_mask = self
-            .get_global_attention_mask(attention_mask.as_ref(), input_ids.shape())?
+            .get_global_attention_mask(attention_mask.as_ref(), &shape)?
            .to_dtype(self.dtype)?;
         let silding_attention_mask = self
             .get_silding_window_mask(&global_attention_mask)?
@@ -680,7 +672,8 @@ impl ModernBertModel {
 
         let mut rotary_cache: HashMap<bool, (Tensor, Tensor)> = HashMap::new();
         for use_local_attention in [true, false] {
-            let (cos, sin) = &self.rotary_cache[&use_local_attention];
+            let inv_freq = &self.inv_freqs_cache[&use_local_attention];
+            let (cos, sin) = get_cos_sin(max_length, inv_freq, self.dtype, true)?;
 
             let cos = cos.index_select(&position_ids, 0)?;
             let sin = sin.index_select(&position_ids, 0)?;
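For reference, a minimal sketch of the idea behind this diff: instead of caching precomputed (cos, sin) tables per attention type at load time, only the inverse frequencies are cached, and the tables are rebuilt in the forward pass for whatever max_length the current batch needs. The helper below is hypothetical (the real code calls get_cos_sin from crate::layers, and this sketch ignores that function's final boolean flag); it assumes inv_freqs has shape (rotary_dim / 2, 1) in F32.

use candle::{DType, Device, Result, Tensor};

// Hypothetical stand-in for crate::layers::get_cos_sin, shown only to illustrate
// why caching inv_freqs is enough: cos/sin can be recomputed cheaply per batch.
fn cos_sin_from_inv_freqs(
    inv_freqs: &Tensor, // assumed shape: (rotary_dim / 2, 1), F32
    max_length: usize,
    dtype: DType,
    device: &Device,
) -> Result<(Tensor, Tensor)> {
    // Positions 0..max_length as a (max_length, 1) column vector.
    let positions = Tensor::arange(0u32, max_length as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((max_length, 1))?;
    // Outer product (max_length, 1) x (1, rotary_dim / 2): one angle per position/frequency.
    let freqs = positions.matmul(&inv_freqs.t()?)?;
    Ok((freqs.cos()?.to_dtype(dtype)?, freqs.sin()?.to_dtype(dtype)?))
}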