From 76ee8948dc7e0f8b8ef41db86d2bdacbdab4967a Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 4 Feb 2025 18:44:19 -0800 Subject: [PATCH] [V1][Misc] Shorten FinishReason enum and use constant strings Small follow-on to https://github.com/vllm-project/vllm/pull/12579 Signed-off-by: Nick Hill --- vllm/v1/engine/__init__.py | 12 +++++++++--- vllm/v1/engine/detokenizer.py | 7 +++---- vllm/v1/metrics/loggers.py | 6 +++--- vllm/v1/metrics/stats.py | 7 +++---- vllm/v1/request.py | 14 +++++++------- 5 files changed, 25 insertions(+), 21 deletions(-) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 6bd548bdcd8e..d5933cac50c2 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -14,11 +14,17 @@ from vllm.multimodal.inputs import PlaceholderRange from vllm.sampling_params import SamplingParams +# These are possible values of RequestOutput.finish_reason, +# so form part of the external API. +FINISH_REASON_STRINGS = ("stop", "length", "abort") -class RequestFinishedReason(enum.IntEnum): + +class FinishReason(enum.IntEnum): """ Reason a request finished - stop, length, or abort. + Int rather than Str for more compact serialization. + stop - a stop string was emitted length - max_tokens was consumed, or max_model_len was reached abort - aborted for another reason @@ -29,7 +35,7 @@ class RequestFinishedReason(enum.IntEnum): ABORT = 2 def __str__(self): - return self.name.lower() + return FINISH_REASON_STRINGS[self.value] @dataclass @@ -62,7 +68,7 @@ class EngineCoreOutput( request_id: str new_token_ids: List[int] finished: bool - finish_reason: Optional[RequestFinishedReason] = None + finish_reason: Optional[FinishReason] = None stop_reason: Union[int, str, None] = None diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 2bce23e68d27..861fcb012c34 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -8,8 +8,7 @@ from vllm.sampling_params import RequestOutputKind from vllm.transformers_utils.detokenizer_utils import ( AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally) -from vllm.v1.engine import (EngineCoreOutput, EngineCoreRequest, - RequestFinishedReason) +from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason logger = init_logger(__name__) @@ -19,7 +18,7 @@ class DetokenizerOutput: output_text: str token_ids: List[int] finished: bool - finish_reason: Optional[RequestFinishedReason] = None + finish_reason: Optional[FinishReason] = None stop_reason: Union[int, str, None] = None @@ -148,7 +147,7 @@ def update_from_output( stop_str, truncate_to = stop if truncate_to != -1: self.output_text = self.output_text[:truncate_to] - finish_reason = RequestFinishedReason.STOP + finish_reason = FinishReason.STOP stop_reason = stop_str # TODO: handle stop_token_ids here too? 
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index b62351a8fd6e..eb1acf584c6b 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -9,7 +9,7 @@ from vllm.config import ModelConfig from vllm.logger import init_logger -from vllm.v1.engine import RequestFinishedReason +from vllm.v1.engine import FinishReason from vllm.v1.metrics.stats import IterationStats, SchedulerStats logger = init_logger(__name__) @@ -117,13 +117,13 @@ def __init__(self, model_config: ModelConfig): documentation="Number of generation tokens processed.", labelnames=labelnames).labels(*labelvalues) - self.counter_request_success: Dict[RequestFinishedReason, + self.counter_request_success: Dict[FinishReason, prometheus_client.Counter] = {} counter_request_success_base = prometheus_client.Counter( name="vllm:request_success_total", documentation="Count of successfully processed requests.", labelnames=labelnames + ["finished_reason"]) - for reason in RequestFinishedReason: + for reason in FinishReason: self.counter_request_success[ reason] = counter_request_success_base.labels(*(labelvalues + [str(reason)])) diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 36c95e07d8a9..e3f1efcc9b1a 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: from vllm.outputs import RequestOutput - from vllm.v1.engine import EngineCoreOutput, RequestFinishedReason + from vllm.v1.engine import EngineCoreOutput, FinishReason @dataclass @@ -32,7 +32,7 @@ class RequestStateStats: class FinishedRequestStats: """Stats associated with a finished request.""" - finish_reason: "RequestFinishedReason" + finish_reason: "FinishReason" num_prompt_tokens: int = 0 num_generation_tokens: int = 0 @@ -74,8 +74,7 @@ def update_from_output(self, output: "EngineCoreOutput", request_state_stats.num_generation_tokens += num_new_generation_tokens request_state_stats.last_token_time = now - def update_from_finished_request(self, - finish_reason: "RequestFinishedReason", + def update_from_finished_request(self, finish_reason: "FinishReason", request_output: "RequestOutput", request_state_stats: RequestStateStats): self.finished_requests.append( diff --git a/vllm/v1/request.py b/vllm/v1/request.py index eb9bf99b406f..89b39ea615d2 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -6,7 +6,7 @@ from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.sequence import RequestMetrics -from vllm.v1.engine import EngineCoreRequest, RequestFinishedReason +from vllm.v1.engine import EngineCoreRequest, FinishReason from vllm.v1.utils import ConstantList if TYPE_CHECKING: @@ -109,7 +109,7 @@ def num_output_tokens(self) -> int: def is_finished(self) -> bool: return RequestStatus.is_finished(self.status) - def get_finished_reason(self) -> Union[RequestFinishedReason, None]: + def get_finished_reason(self) -> Union[FinishReason, None]: return RequestStatus.get_finished_reason(self.status) def has_encoder_inputs(self) -> bool: @@ -150,7 +150,7 @@ def is_finished(status: "RequestStatus") -> bool: @staticmethod def get_finished_reason( - status: "RequestStatus") -> Union[RequestFinishedReason, None]: + status: "RequestStatus") -> Union[FinishReason, None]: return _FINISHED_REASON_MAP.get(status) @@ -159,8 +159,8 @@ def get_finished_reason( # are longer than the model's length cap. Therefore, the stop # reason should also be "length" as in OpenAI API. 
_FINISHED_REASON_MAP = { - RequestStatus.FINISHED_STOPPED: RequestFinishedReason.STOP, - RequestStatus.FINISHED_LENGTH_CAPPED: RequestFinishedReason.LENGTH, - RequestStatus.FINISHED_ABORTED: RequestFinishedReason.ABORT, - RequestStatus.FINISHED_IGNORED: RequestFinishedReason.LENGTH, + RequestStatus.FINISHED_STOPPED: FinishReason.STOP, + RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH, + RequestStatus.FINISHED_ABORTED: FinishReason.ABORT, + RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH, }
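
For illustration, a minimal self-contained sketch (not part of the patch) of the behaviour
this change relies on: FinishReason is an IntEnum, so it serializes as a small integer,
while str() yields the stable external strings exposed via RequestOutput.finish_reason.
The definitions below simply copy the ones added in vllm/v1/engine/__init__.py above.

    import enum

    FINISH_REASON_STRINGS = ("stop", "length", "abort")

    class FinishReason(enum.IntEnum):
        STOP = 0
        LENGTH = 1
        ABORT = 2

        def __str__(self):
            # Map the compact int value back to the external API string.
            return FINISH_REASON_STRINGS[self.value]

    # Compact wire form vs. external string form.
    assert int(FinishReason.LENGTH) == 1
    assert str(FinishReason.LENGTH) == "length"
    assert [str(r) for r in FinishReason] == ["stop", "length", "abort"]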