-use crate::layers::{get_cublas_lt_wrapper, HiddenAct, LayerNormNoBias, Linear, RotaryEmbedding};
+use crate::layers::{
+    apply_rotary, get_cos_sin, get_cublas_lt_wrapper, get_inv_freqs, HiddenAct, LayerNormNoBias,
+    Linear,
+};
use crate::models::Model;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Embedding, VarBuilder};
@@ -182,7 +185,7 @@ impl ModernBertAttention {
        &self,
        hidden_states: &Tensor,
        attention_mask: &Tensor,
-        rotary_embed: &RotaryEmbedding,
+        rotary_cache: &(Tensor, Tensor),
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        let device = hidden_states.device();
@@ -200,7 +203,18 @@ impl ModernBertAttention {
        let key_layer = &qkv[1].contiguous()?;
        let value_layer = &qkv[2];

-        let (query_layer, key_layer) = rotary_embed.apply_rotary_emb_qk(query_layer, key_layer)?;
+        let query_layer = apply_rotary(
+            query_layer,
+            &rotary_cache.0,
+            &rotary_cache.1,
+            self.attention_head_size,
+        )?;
+        let key_layer = apply_rotary(
+            key_layer,
+            &rotary_cache.0,
+            &rotary_cache.1,
+            self.attention_head_size,
+        )?;

        #[allow(unused_variables)]
        let context_layer =
@@ -305,7 +319,7 @@ impl ModernBertEncoderLayer {
        &self,
        hidden_states: &Tensor,
        attention_mask: &Tensor,
-        rotary_embed: &RotaryEmbedding,
+        rotary_cache: &(Tensor, Tensor),
    ) -> Result<Tensor> {
        let _enter = self.span.enter();

@@ -319,7 +333,7 @@ impl ModernBertEncoderLayer {

        let attn_outputs = self
            .attn
-            .forward(&attn_norm, attention_mask, rotary_embed)?;
+            .forward(&attn_norm, attention_mask, rotary_cache)?;

        let hidden_states = residual.add(&attn_outputs)?;

@@ -361,8 +375,8 @@ impl ModernBertEncoder {
        hidden_states: &Tensor,
        global_attention_mask: &Tensor,
        local_attention_mask: &Tensor,
-        global_rotaray_emb: &RotaryEmbedding,
-        local_rotaray_emb: &RotaryEmbedding,
+        global_rotaray_cache: &(Tensor, Tensor),
+        local_rotaray_cache: &(Tensor, Tensor),
    ) -> Result<Tensor> {
        let _enter = self.span.enter();

@@ -371,16 +385,13 @@ impl ModernBertEncoder {
        for (index, layer) in self.layers.iter().enumerate() {
            let use_local_attention = index % self.global_attn_every_n_layers != 0;

-            let (attention_mask, rotary_embed) = if use_local_attention {
-                (
-                    &global_attention_mask.broadcast_add(local_attention_mask)?,
-                    local_rotaray_emb,
-                )
+            let (attention_mask, rotary_cache) = if use_local_attention {
+                (local_attention_mask, local_rotaray_cache)
            } else {
-                (global_attention_mask, global_rotaray_emb)
+                (global_attention_mask, global_rotaray_cache)
            };

-            hidden_states = layer.forward(&hidden_states, attention_mask, rotary_embed)?;
+            hidden_states = layer.forward(&hidden_states, attention_mask, rotary_cache)?;
        }

        Ok(hidden_states)
@@ -456,8 +467,9 @@ pub struct ModernBertModel {
    classifier: Option<Box<dyn ClassificationHead + Send>>,

    local_attention: usize,
-    global_rotary_emb: RotaryEmbedding,
-    local_rotary_emb: RotaryEmbedding,
+    global_inv_freqs: Tensor,
+    local_inv_freqs: Tensor,
+    rotary_dim: usize,
    pad_token_id: u32,
    num_attention_heads: usize,

@@ -508,21 +520,19 @@ impl ModernBertModel {
            )
        })?;

-        let rotary_dim = config.hidden_size / config.num_attention_heads;
+        let attention_head_size = config.hidden_size / config.num_attention_heads;

-        let global_rotary_emb = RotaryEmbedding::new(
-            vb.dtype(),
-            rotary_dim,
-            config.max_position_embeddings,
-            config.global_rope_theta,
+        let global_inv_freqs = get_inv_freqs(
+            attention_head_size,
+            config.global_rope_theta as f32,
            vb.device(),
+            None,
        )?;
-        let local_rotary_emb = RotaryEmbedding::new(
-            vb.dtype(),
-            rotary_dim,
-            config.max_position_embeddings,
-            config.local_rope_theta,
+        let local_inv_freqs = get_inv_freqs(
+            attention_head_size,
+            config.local_rope_theta as f32,
            vb.device(),
+            None,
        )?;

        Ok(Self {
@@ -532,8 +542,9 @@ impl ModernBertModel {
            pool,
            classifier,
            local_attention: config.local_attention,
-            global_rotary_emb,
-            local_rotary_emb,
+            global_inv_freqs,
+            local_inv_freqs,
+            rotary_dim: attention_head_size,
            pad_token_id: config.pad_token_id as u32,
            num_attention_heads: config.num_attention_heads,
            device: vb.device().clone(),
@@ -563,37 +574,31 @@ impl ModernBertModel {
            seq_len,
        ))?;

-        let min_value = match self.dtype {
-            DType::F32 => f32::MIN as f64,
-            _ => -65504.0_f64, // f16 minumum value
-        };
+        Ok(extended_attention_mask)
+    }

-        let inverted_mask = ((1.0 - extended_attention_mask)? * min_value)?;
+    fn get_local_attention_mask(&self, attention_mask: &Tensor) -> Result<Tensor> {
+        let attention_mask = attention_mask.to_dtype(DType::U8)?;

-        inverted_mask.to_dtype(self.dtype)
-    }
+        let mask_shape = attention_mask.shape();
+        let (_, _, seq_len, _) = mask_shape.dims4()?;

-    fn get_local_attention_mask(&self, seq_len: usize) -> Result<Tensor> {
-        let window_size: usize = self.local_attention / 2;
+        let rows = Tensor::arange(0, seq_len as i64, attention_mask.device())?.unsqueeze(0)?;
+        let rows = rows.broadcast_as((seq_len, seq_len))?;

-        let min_value = match self.dtype {
-            DType::F32 => f32::MIN as f64,
-            _ => -65504.0_f64, // f16 minumum value
-        };
+        let distance = (&rows - &rows.t()?)?.abs()?;

-        let mask: Vec<_> = (0..seq_len)
-            .flat_map(|i| {
-                (0..seq_len).map(move |j| {
-                    if (j as i32 - i as i32).abs() > window_size as i32 {
-                        min_value
-                    } else {
-                        0.
-                    }
-                })
-            })
-            .collect();
+        let window_size = (self.local_attention / 2) as i64;
+        let window_mask = distance
+            .le(window_size)?
+            .unsqueeze(0)?
+            .unsqueeze(0)?
+            .broadcast_as(mask_shape)?;
+
+        let zero_tensor = Tensor::zeros_like(&attention_mask)?;
+        let local_attention_mask = attention_mask.where_cond(&window_mask, &zero_tensor)?;

-        Tensor::from_slice(&mask, (seq_len, seq_len), &self.device)?.to_dtype(self.dtype)
+        Ok(local_attention_mask)
    }

    fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
@@ -604,7 +609,7 @@ impl ModernBertModel {

        let shape = (batch_size, max_length);

-        let (input_ids, input_lengths, _, attention_mask) = if batch_size > 1 {
+        let (input_ids, input_lengths, position_ids, attention_mask) = if batch_size > 1 {
            let elems = batch_size * max_length;

            let mut input_ids = Vec::with_capacity(elems);
@@ -662,20 +667,59 @@ impl ModernBertModel {
        };

        let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
+        let position_ids = Tensor::from_vec(position_ids, batch_size * max_length, &self.device)?;
        let mut input_lengths =
            Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?;

-        let global_attention_mask =
-            self.get_global_attention_mask(attention_mask.as_ref(), &shape)?;
-        let local_attention_mask = self.get_local_attention_mask(max_length)?;
+        let global_attention_mask = self
+            .get_global_attention_mask(attention_mask.as_ref(), &shape)?
+            .to_dtype(self.dtype)?;
+        let local_attention_mask = self
+            .get_local_attention_mask(&global_attention_mask)?
+            .to_dtype(self.dtype)?;
+
+        let min_value = match self.dtype {
+            DType::F32 => f32::MIN as f64,
+            _ => -65504.0, // f16 minimum value
+        };
+
+        let global_attention_mask = ((1.0 - global_attention_mask)? * min_value)?;
+        let local_attention_mask = ((1.0 - local_attention_mask)? * min_value)?;
+
+        let global_rotary_cache =
+            get_cos_sin(max_length, &self.global_inv_freqs, self.dtype, true)?;
+        let local_rotary_cache = get_cos_sin(max_length, &self.local_inv_freqs, self.dtype, true)?;
+
+        let global_rotary_cache = (
+            global_rotary_cache
+                .0
+                .index_select(&position_ids, 0)?
+                .reshape((batch_size, 1, max_length, self.rotary_dim))?,
+            global_rotary_cache
+                .1
+                .index_select(&position_ids, 0)?
+                .reshape((batch_size, 1, max_length, self.rotary_dim))?,
+        );
+
+        let local_rotary_cache = (
+            local_rotary_cache
+                .0
+                .index_select(&position_ids, 0)?
+                .reshape((batch_size, 1, max_length, self.rotary_dim))?,
+            local_rotary_cache
+                .1
+                .index_select(&position_ids, 0)?
+                .reshape((batch_size, 1, max_length, self.rotary_dim))?,
+        );

        let hidden_states = self.embeddings.forward(&input_ids)?;
+
        let hidden_states = self.encoder.forward(
            &hidden_states,
            &global_attention_mask,
            &local_attention_mask,
-            &self.global_rotary_emb,
-            &self.local_rotary_emb,
+            &global_rotary_cache,
+            &local_rotary_cache,
        )?;

        let outputs = self.final_norm.forward(&hidden_states, None)?;
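Note on the new rotary path (a minimal sketch, not part of the diff): the change replaces the `RotaryEmbedding` struct with three helpers from `crate::layers` — `get_inv_freqs` builds the inverse frequencies once at load time, `get_cos_sin` precomputes the cos/sin tables for the batch's `max_length`, and `apply_rotary` applies them inside each attention layer. The sketch below mirrors how the diff wires these calls together; the function name, argument names, and tensor shapes are illustrative assumptions, not code from the repository.

```rust
// Hypothetical helper that mirrors the flow introduced by this diff.
// `get_inv_freqs`, `get_cos_sin`, and `apply_rotary` are the crate::layers
// helpers imported at the top of the diff; everything else is assumed.
use crate::layers::{apply_rotary, get_cos_sin, get_inv_freqs};
use candle::{DType, Device, Result, Tensor};

#[allow(clippy::too_many_arguments)]
fn rotate_qk_sketch(
    query: &Tensor,        // assumed shape: (batch, heads, max_length, head_size)
    key: &Tensor,          // same layout as `query`
    position_ids: &Tensor, // flat (batch * max_length,) position indices
    head_size: usize,
    rope_theta: f32,
    batch_size: usize,
    max_length: usize,
    dtype: DType,
    device: &Device,
) -> Result<(Tensor, Tensor)> {
    // Built once per model in the diff; `None` means no rope scaling.
    let inv_freqs = get_inv_freqs(head_size, rope_theta, device, None)?;

    // Built once per forward pass for the batch's max_length.
    let (cos, sin) = get_cos_sin(max_length, &inv_freqs, dtype, true)?;

    // Gather the per-token rows and reshape so the cache broadcasts over
    // the head dimension, exactly as the diff does before the encoder call.
    let cos = cos
        .index_select(position_ids, 0)?
        .reshape((batch_size, 1, max_length, head_size))?;
    let sin = sin
        .index_select(position_ids, 0)?
        .reshape((batch_size, 1, max_length, head_size))?;

    // Applied per layer inside ModernBertAttention::forward.
    let query = apply_rotary(query, &cos, &sin, head_size)?;
    let key = apply_rotary(key, &cos, &sin, head_size)?;
    Ok((query, key))
}
```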