@@ -507,8 +507,8 @@ def __init__(
507
507
dtype : torch .dtype ,
508
508
short_factor : List [float ],
509
509
long_factor : List [float ],
510
- short_mscale : float = 1.1 ,
511
- long_mscale : float = 1.225 ,
510
+ short_mscale : float = 1.0 ,
511
+ long_mscale : float = 1.0 ,
512
512
):
513
513
super ().__init__ ()
514
514
@@ -530,6 +530,16 @@ def __init__(
530
530
self .short_mscale = short_mscale
531
531
self .long_mscale = long_mscale
532
532
533
+ scale = (self .max_position_embeddings /
534
+ self .original_max_position_embeddings )
535
+
536
+ if scale <= 1.0 :
537
+ self .scaling_factor = 1.0
538
+ else :
539
+ self .scaling_factor = math .sqrt (
540
+ 1 + math .log (scale ) /
541
+ math .log (self .original_max_position_embeddings ))
542
+
533
543
short_cache = self ._compute_cos_sin_cache (
534
544
original_max_position_embeddings , short_factor , short_mscale )
535
545
short_cache = short_cache .to (dtype )
@@ -565,8 +575,8 @@ def _compute_cos_sin_cache(
565
575
inv_freq = self ._compute_inv_freq (rescale_factors )
566
576
t = torch .arange (max_position_embeddings , dtype = torch .float )
567
577
freqs = torch .einsum ("i,j -> ij" , t , inv_freq )
568
- cos = freqs .cos () * mscale
569
- sin = freqs .sin () * mscale
578
+ cos = freqs .cos () * mscale * self . scaling_factor
579
+ sin = freqs .sin () * mscale * self . scaling_factor
570
580
cache = torch .cat ((cos , sin ), dim = - 1 )
571
581
return cache
572
582
0 commit comments