@@ -154,15 +154,6 @@ def __init__(
154
154
self .hf_text_config = get_hf_text_config (self .hf_config )
155
155
self .dtype = _get_and_verify_dtype (self .hf_text_config , dtype )
156
156
157
- if (getattr (self .hf_config , "max_position_embeddings" , 0 ) == 131072
158
- and getattr (self .hf_config , "rope_scaling" , None ) is None ):
159
- # Note(simon): this is a special case for a model that doesn't
160
- # supply rope_scaling. We should remove this once the model is
161
- # updated.
162
- self .hf_config .update ({"rope_scaling" : {
163
- "type" : "extended" ,
164
- }})
165
-
166
157
if (not self .disable_sliding_window
167
158
and self .hf_text_config .model_type == "gemma2"
168
159
and self .hf_text_config .sliding_window is not None ):
@@ -1492,24 +1483,32 @@ def _get_and_verify_max_len(
1492
1483
derived_max_model_len = default_max_len
1493
1484
1494
1485
rope_scaling = getattr (hf_config , "rope_scaling" , None )
1495
- # The correct one should be "longrope", kept "su" here
1496
- # to be backward compatible
1497
- if rope_scaling is not None and rope_scaling ["type" ] not in {
1498
- "su" , "longrope" , "extended"
1499
- }:
1500
- if disable_sliding_window :
1501
- # TODO(robertgshaw): Find a model that supports rope_scaling
1502
- # with sliding window to see if this case should be allowed.
1503
- raise NotImplementedError (
1504
- "Disabling sliding window is not supported for models "
1505
- "with rope_scaling. Please raise an issue so we can "
1506
- "investigate." )
1507
- assert "factor" in rope_scaling
1508
- scaling_factor = rope_scaling ["factor" ]
1509
- if rope_scaling ["type" ] == "yarn" :
1510
- derived_max_model_len = rope_scaling [
1511
- "original_max_position_embeddings" ]
1512
- derived_max_model_len *= scaling_factor
1486
+ if rope_scaling is not None :
1487
+ if "type" in rope_scaling :
1488
+ rope_type = rope_scaling ["type" ]
1489
+ elif "rope_type" in rope_scaling :
1490
+ rope_type = rope_scaling ["rope_type" ]
1491
+ else :
1492
+ raise ValueError (
1493
+ "rope_scaling must have a 'type' or 'rope_type' key." )
1494
+
1495
+ # The correct one should be "longrope", kept "su" here
1496
+ # to be backward compatible
1497
+ if rope_type not in ("su" , "longrope" , "llama3" ):
1498
+ if disable_sliding_window :
1499
+ # TODO(robertgshaw): Find a model that supports rope_scaling
1500
+ # with sliding window to see if this case should be allowed.
1501
+ raise NotImplementedError (
1502
+ "Disabling sliding window is not supported for models "
1503
+ "with rope_scaling. Please raise an issue so we can "
1504
+ "investigate." )
1505
+
1506
+ assert "factor" in rope_scaling
1507
+ scaling_factor = rope_scaling ["factor" ]
1508
+ if rope_type == "yarn" :
1509
+ derived_max_model_len = rope_scaling [
1510
+ "original_max_position_embeddings" ]
1511
+ derived_max_model_len *= scaling_factor
1513
1512
1514
1513
# If the user specified a max length, make sure it is smaller than the
1515
1514
# derived length from the HF model config.
0 commit comments