Skip to content

Commit a112a84

Browse files
authored
[BugFix] Fix RoPE error in Llama 3.1 (#6693)
1 parent 461089a commit a112a84

File tree

2 files changed

+30
-30
lines changed

2 files changed

+30
-30
lines changed

vllm/config.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -154,15 +154,6 @@ def __init__(
154154
self.hf_text_config = get_hf_text_config(self.hf_config)
155155
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
156156

157-
if (getattr(self.hf_config, "max_position_embeddings", 0) == 131072
158-
and getattr(self.hf_config, "rope_scaling", None) is None):
159-
# Note(simon): this is a special case for a model that doesn't
160-
# supply rope_scaling. We should remove this once the model is
161-
# updated.
162-
self.hf_config.update({"rope_scaling": {
163-
"type": "extended",
164-
}})
165-
166157
if (not self.disable_sliding_window
167158
and self.hf_text_config.model_type == "gemma2"
168159
and self.hf_text_config.sliding_window is not None):
@@ -1492,24 +1483,32 @@ def _get_and_verify_max_len(
14921483
derived_max_model_len = default_max_len
14931484

14941485
rope_scaling = getattr(hf_config, "rope_scaling", None)
1495-
# The correct one should be "longrope", kept "su" here
1496-
# to be backward compatible
1497-
if rope_scaling is not None and rope_scaling["type"] not in {
1498-
"su", "longrope", "extended"
1499-
}:
1500-
if disable_sliding_window:
1501-
# TODO(robertgshaw): Find a model that supports rope_scaling
1502-
# with sliding window to see if this case should be allowed.
1503-
raise NotImplementedError(
1504-
"Disabling sliding window is not supported for models "
1505-
"with rope_scaling. Please raise an issue so we can "
1506-
"investigate.")
1507-
assert "factor" in rope_scaling
1508-
scaling_factor = rope_scaling["factor"]
1509-
if rope_scaling["type"] == "yarn":
1510-
derived_max_model_len = rope_scaling[
1511-
"original_max_position_embeddings"]
1512-
derived_max_model_len *= scaling_factor
1486+
if rope_scaling is not None:
1487+
if "type" in rope_scaling:
1488+
rope_type = rope_scaling["type"]
1489+
elif "rope_type" in rope_scaling:
1490+
rope_type = rope_scaling["rope_type"]
1491+
else:
1492+
raise ValueError(
1493+
"rope_scaling must have a 'type' or 'rope_type' key.")
1494+
1495+
# The correct one should be "longrope", kept "su" here
1496+
# to be backward compatible
1497+
if rope_type not in ("su", "longrope", "llama3"):
1498+
if disable_sliding_window:
1499+
# TODO(robertgshaw): Find a model that supports rope_scaling
1500+
# with sliding window to see if this case should be allowed.
1501+
raise NotImplementedError(
1502+
"Disabling sliding window is not supported for models "
1503+
"with rope_scaling. Please raise an issue so we can "
1504+
"investigate.")
1505+
1506+
assert "factor" in rope_scaling
1507+
scaling_factor = rope_scaling["factor"]
1508+
if rope_type == "yarn":
1509+
derived_max_model_len = rope_scaling[
1510+
"original_max_position_embeddings"]
1511+
derived_max_model_len *= scaling_factor
15131512

15141513
# If the user specified a max length, make sure it is smaller than the
15151514
# derived length from the HF model config.

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -794,12 +794,13 @@ def get_rope(
794794
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
795795
is_neox_style, dtype)
796796
else:
797-
scaling_type = rope_scaling["type"]
797+
scaling_type = rope_scaling[
798+
"type"] if "type" in rope_scaling else rope_scaling["rope_type"]
798799
# The correct one should be "longrope" but keep "su" here
799800
# for backward compatible
800-
if scaling_type not in {"su", "longrope", "extended"}:
801+
if scaling_type not in {"su", "longrope", "llama3"}:
801802
scaling_factor = rope_scaling["factor"]
802-
if scaling_type == "extended":
803+
if scaling_type == "llama3":
803804
rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
804805
max_position, base,
805806
is_neox_style, dtype)

0 commit comments

Comments (0)