diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 1848fd1de5cf..ca577a6721fe 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -184,6 +184,9 @@ class SamplingParams( allowed_token_ids: If provided, the engine will construct a logits processor which only retains scores for the given token ids. Defaults to None. + extra_args: Arbitrary additional args, that can be used by custom + sampling implementations. Not used by any in-tree sampling + implementations. """ n: int = 1 @@ -227,6 +230,7 @@ class SamplingParams( guided_decoding: Optional[GuidedDecodingParams] = None logit_bias: Optional[dict[int, float]] = None allowed_token_ids: Optional[list[int]] = None + extra_args: Optional[dict[str, Any]] = None @staticmethod def from_optional( @@ -259,6 +263,7 @@ def from_optional( guided_decoding: Optional[GuidedDecodingParams] = None, logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None, allowed_token_ids: Optional[list[int]] = None, + extra_args: Optional[dict[str, Any]] = None, ) -> "SamplingParams": if logit_bias is not None: # Convert token_id to integer @@ -300,6 +305,7 @@ def from_optional( guided_decoding=guided_decoding, logit_bias=logit_bias, allowed_token_ids=allowed_token_ids, + extra_args=extra_args, ) def __post_init__(self) -> None: @@ -509,7 +515,8 @@ def __repr__(self) -> str: "spaces_between_special_tokens=" f"{self.spaces_between_special_tokens}, " f"truncate_prompt_tokens={self.truncate_prompt_tokens}, " - f"guided_decoding={self.guided_decoding})") + f"guided_decoding={self.guided_decoding}, " + f"extra_args={self.extra_args})") class BeamSearchParams(