4 changes: 2 additions & 2 deletions Dockerfile.rocm
@@ -40,7 +40,7 @@ WORKDIR ${COMMON_WORKDIR}
# -----------------------
# hipBLASLt build stages
FROM base AS build_hipblaslt
ARG HIPBLASLT_BRANCH="6f65c6e"
ARG HIPBLASLT_BRANCH="8b71e7a8d26ba95774fdc372883ee0be57af3d28"
RUN git clone https://github.com/ROCm/hipBLASLt \
&& cd hipBLASLt \
&& git checkout ${HIPBLASLT_BRANCH} \
@@ -70,7 +70,7 @@ FROM export_rccl_${BUILD_RCCL} AS export_rccl
# -----------------------
# flash attn build stages
FROM base AS build_flash_attn
ARG FA_BRANCH="ae7928c"
ARG FA_BRANCH="23a2b1c2f21de2289db83de7d42e125586368e66"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
RUN git clone ${FA_REPO} \
1 change: 1 addition & 0 deletions setup_cython.py
@@ -24,6 +24,7 @@
"vllm/model_executor/layers/sampler.py",
"vllm/sampling_params.py",
"vllm/utils.py",
"vllm/block.py",
]


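For context, adding "vllm/block.py" to this list marks one more module for ahead-of-time compilation. Below is a minimal sketch of how such a list is typically passed to Cython; the setup() call is illustrative and assumed, not the exact contents of vLLM's setup_cython.py.

from setuptools import setup
from Cython.Build import cythonize

# Modules compiled ahead of time (mirrors the list in the diff above).
compiled_modules = [
    "vllm/model_executor/layers/sampler.py",
    "vllm/sampling_params.py",
    "vllm/utils.py",
    "vllm/block.py",
]

setup(
    ext_modules=cythonize(
        compiled_modules,
        compiler_directives={"language_level": "3"},
    ),
)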
54 changes: 39 additions & 15 deletions vllm/config.py
@@ -11,7 +11,8 @@
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron
from vllm.utils import (get_cpu_memory, is_cpu, is_hip, is_neuron,
print_warning_once)

if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@@ -133,6 +134,17 @@ def __init__(
code_revision, rope_scaling)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

if (not self.disable_sliding_window
and self.hf_text_config.model_type == "gemma2"
and self.hf_text_config.sliding_window is not None):
print_warning_once(
"Gemma 2 uses sliding window attention for every odd layer, "
"which is currently not supported by vLLM. Disabling sliding "
"window and capping the max length to the sliding window size "
f"({self.hf_text_config.sliding_window}).")
self.disable_sliding_window = True

self.max_model_len = _get_and_verify_max_len(
hf_config=self.hf_text_config,
max_model_len=max_model_len,
@@ -1224,20 +1236,32 @@ def _get_and_verify_max_len(
derived_max_model_len = default_max_len

rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None and rope_scaling["type"] != "su":
if disable_sliding_window:
# TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed.
raise NotImplementedError(
"Disabling sliding window is not supported for models "
"with rope_scaling. Please raise an issue so we can "
"investigate.")
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "yarn":
derived_max_model_len = rope_scaling[
"original_max_position_embeddings"]
derived_max_model_len *= scaling_factor
if rope_scaling is not None:
if "type" in rope_scaling:
rope_type = rope_scaling["type"]
elif "rope_type" in rope_scaling:
rope_type = rope_scaling["rope_type"]
else:
raise ValueError(
"rope_scaling must have a 'type' or 'rope_type' key.")

# The correct one should be "longrope", kept "su" here
# to be backward compatible
if rope_type not in ("su", "longrope", "llama3"):
if disable_sliding_window:
# TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed.
raise NotImplementedError(
"Disabling sliding window is not supported for models "
"with rope_scaling. Please raise an issue so we can "
"investigate.")

assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_type == "yarn":
derived_max_model_len = rope_scaling[
"original_max_position_embeddings"]
derived_max_model_len *= scaling_factor

# If the user specified a max length, make sure it is smaller than the
# derived length from the HF model config.
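For readability, here is the rope_scaling handling above pulled out into a hypothetical helper; the function name and the example values are illustrative (vLLM keeps this logic inline in _get_and_verify_max_len).

def _resolve_rope_type(rope_scaling: dict) -> str:
    # HF configs are inconsistent: older ones store the scaling method
    # under "type", newer ones (e.g. Llama 3) under "rope_type".
    if "type" in rope_scaling:
        return rope_scaling["type"]
    if "rope_type" in rope_scaling:
        return rope_scaling["rope_type"]
    raise ValueError("rope_scaling must have a 'type' or 'rope_type' key.")

# Example: YaRN scales the original context length by "factor".
rope_scaling = {
    "rope_type": "yarn",
    "factor": 4.0,
    "original_max_position_embeddings": 8192,
}
if _resolve_rope_type(rope_scaling) == "yarn":
    derived_max_model_len = (rope_scaling["original_max_position_embeddings"]
                             * rope_scaling["factor"])  # 8192 * 4.0 = 32768.0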
12 changes: 7 additions & 5 deletions vllm/core/block_manager_v1.py
@@ -612,12 +612,14 @@ def _free_block_table(self, block_table: BlockTable) -> None:
self.cpu_allocator.free(block)

def free(self, seq: Sequence) -> None:
if seq.seq_id not in self.block_tables:
# Already freed or haven't been scheduled yet.
return
block_table = self.block_tables[seq.seq_id]
seq_id = seq.seq_id
block_table = self.block_tables.pop(seq_id,[])
#if seq.seq_id not in self.block_tables:
# # Already freed or haven't been scheduled yet.
# return
#block_table = self.block_tables[seq.seq_id]
self._free_block_table(block_table)
del self.block_tables[seq.seq_id]
#del self.block_tables[seq.seq_id]

def free_cross(self, seq_group: SequenceGroup) -> None:
if seq_group.request_id not in self.cross_block_tables:
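The change above replaces the membership check and separate delete with a single dict.pop carrying a default, which keeps free() idempotent. A stand-alone sketch of the pattern, with made-up block names rather than the real allocator, follows.

block_tables = {"seq-0": ["block-a", "block-b"]}

def free(seq_id: str) -> None:
    # pop() with a default covers both "already freed" and "never scheduled"
    # in one step: no KeyError, no early return, no trailing del.
    block_table = block_tables.pop(seq_id, [])
    for block in block_table:
        print(f"releasing {block}")  # the real code hands it back to the allocator

free("seq-0")  # releases block-a and block-b
free("seq-0")  # second call is a harmless no-op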
22 changes: 13 additions & 9 deletions vllm/engine/llm_engine.py
@@ -681,12 +681,13 @@ def _process_model_outputs(
"""

now = time.time()

# Organize outputs by [sequence group][step] instead of
# [step][sequence group].
output_by_sequence_group = create_output_by_sequence_group(
output, num_seq_groups=len(scheduled_seq_groups))

seq_groups = [scheduled_seq_group.seq_group for scheduled_seq_group in scheduled_seq_groups]

# Update the scheduled sequence groups with the model outputs.
for scheduled_seq_group, outputs, seq_group_meta in zip(
scheduled_seq_groups, output_by_sequence_group,
@@ -708,14 +709,17 @@
# Create the outputs.
request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = []
for scheduled_seq_group in scheduled_seq_groups:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(seq_group)
request_outputs.append(request_output)
for seq_group in ignored_seq_groups:
request_output = RequestOutputFactory.create(seq_group)
request_outputs.append(request_output)
[seq_group.maybe_set_first_token_time(now) for seq_group in seq_groups]
request_outputs = [RequestOutputFactory.create(seq_group) for seq_group in seq_groups]
#for scheduled_seq_group in scheduled_seq_groups:
# seq_group = scheduled_seq_group.seq_group
# seq_group.maybe_set_first_token_time(now)
# request_output = RequestOutputFactory.create(seq_group)
# request_outputs.append(request_output)
request_outputs.extend([RequestOutputFactory.create(seq_group) for seq_group in ignored_seq_groups])
#for seq_group in ignored_seq_groups:
# request_output = RequestOutputFactory.create(seq_group)
# request_outputs.append(request_output)
return request_outputs

def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
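A compact sketch of the reshaped output construction above, using stand-in objects in place of the vLLM types; note that the side-effecting comprehension over maybe_set_first_token_time in the diff could equally be a plain loop, which avoids building a throwaway list.

import time

class _SeqGroup:
    # Minimal stand-in for vLLM's SequenceGroup (illustrative only).
    def __init__(self, name: str):
        self.name = name
        self.first_token_time = None

    def maybe_set_first_token_time(self, now: float) -> None:
        if self.first_token_time is None:
            self.first_token_time = now

def create_output(seq_group: _SeqGroup) -> str:
    # Stand-in for RequestOutputFactory.create().
    return f"output({seq_group.name})"

scheduled_groups = [_SeqGroup("a"), _SeqGroup("b")]
ignored_groups = [_SeqGroup("c")]

now = time.time()
for seq_group in scheduled_groups:               # plain loop instead of a
    seq_group.maybe_set_first_token_time(now)    # side-effecting comprehension
request_outputs = [create_output(sg) for sg in scheduled_groups]
request_outputs.extend(create_output(sg) for sg in ignored_groups)
print(request_outputs)  # ['output(a)', 'output(b)', 'output(c)']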
24 changes: 24 additions & 0 deletions vllm/engine/output_processor/single_step.py
@@ -74,6 +74,30 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput) -> None:
# Process samples
samples = outputs.samples
if len(samples)==1:
#if there's only 1 sample, it has to be from 1 running seq in seq group
parent_seq = next(iter(seq_group.seqs_dict.values()))
child_sample = samples[0]
if not seq_group.sampling_params.use_beam_search:
#fastpath
parent_seq.append_token_id(child_sample.output_token,
child_sample.logprobs)
if self.detokenizer and seq_group.sampling_params.detokenize:
new_char_count = self.detokenizer.decode_sequence_inplace(
parent_seq, seq_group.sampling_params)
else:
new_char_count = 0

stopped = self.stop_checker.maybe_stop_sequence(
parent_seq,
new_char_count,
seq_group.sampling_params,
lora_req=seq_group.lora_request,
)
#if parent_seq.is_finished():
if stopped:
self.scheduler.free_seq(parent_seq)
return
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict: Dict[int, List[SequenceOutput]] = {
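A simplified, runnable sketch of the fast path added above; the objects are SimpleNamespace stand-ins, and the real path also detokenizes and runs the stop checker before returning.

from types import SimpleNamespace

def try_fast_path(seq_group, samples) -> bool:
    # With exactly one sample and no beam search, append the token to the
    # lone running sequence and skip the parent/child fork bookkeeping.
    if len(samples) == 1 and not seq_group.sampling_params.use_beam_search:
        parent_seq = next(iter(seq_group.seqs_dict.values()))
        sample = samples[0]
        parent_seq.tokens.append(sample.output_token)  # append_token_id in vLLM
        return True   # handled here
    return False      # fall through to the general handling

parent = SimpleNamespace(tokens=[1, 2, 3])
group = SimpleNamespace(
    seqs_dict={"seq-0": parent},
    sampling_params=SimpleNamespace(use_beam_search=False),
)
print(try_fast_path(group, [SimpleNamespace(output_token=4)]), parent.tokens)
# True [1, 2, 3, 4]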
24 changes: 13 additions & 11 deletions vllm/engine/output_processor/stop_checker.py
@@ -33,7 +33,7 @@ def maybe_stop_sequence(
new_char_count: int,
sampling_params: SamplingParams,
lora_req: Optional[LoRARequest] = None,
) -> None:
) -> bool:
"""Stop the finished sequences.

new_char_count is the number of chars added to the
@@ -42,49 +42,51 @@

# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if seq.get_output_len() < sampling_params.min_tokens:
return
outlen = seq.get_output_len()
if outlen < sampling_params.min_tokens:
return False

last_token_id = seq.get_last_token_id()
# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == seq.eos_token_id):
and last_token_id == seq.eos_token_id):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if new_char_count and (
not sampling_params.include_stop_str_in_output):
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
return
return True

# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.get_last_token_id()
if last_token_id in sampling_params.stop_token_ids:
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return
return True

# Check if any stop strings are matched.
stop_str = self._check_stop_strings(seq, new_char_count,
sampling_params)
if stop_str is not None:
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
return True

# Check if the sequence has reached max_model_len.
if seq.get_len() > self._get_max_model_len(lora_req):
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
return True

# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
if outlen == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
return True
return False

@staticmethod
def _check_stop_strings(seq: Sequence, new_char_count: int,
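Because maybe_stop_sequence now reports whether it finished the sequence, a caller can branch on the return value instead of re-checking seq.is_finished(); the fast path in single_step.py above relies on this. A small runnable sketch of that contract, with lambda stand-ins for the checker and scheduler:

from types import SimpleNamespace

def handle_step(stop_checker, scheduler, seq, new_char_count, sampling_params) -> str:
    # The checker sets the FINISHED_* status itself; the boolean only tells
    # the caller whether to release the sequence's blocks right away.
    if stop_checker.maybe_stop_sequence(seq, new_char_count, sampling_params):
        scheduler.free_seq(seq)
        return "finished"
    return "running"

stop_checker = SimpleNamespace(maybe_stop_sequence=lambda seq, n, p: True)
scheduler = SimpleNamespace(free_seq=lambda seq: print("freed", seq))
print(handle_step(stop_checker, scheduler, "seq-0", 1, None))  # freed seq-0, then finished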