Motivation.
Currently, the guided decoding & logits processor API is incomplete and has several issues. This RFC is intended to bring up the problems and proposed solutions. Some of the issues may have already been addressed, and there are PRs out already.
There are three major issues:
- It is not supported from SamplingParams.
- It is not possible to support batch/async logits processing.
- Upon failures, the engine will die.
Proposed Change.
API
Guided decoding parameters are not supported with SamplingParams. This is addressed by #4130.
Performance
Currently, logits processor APIs are applied row by row, blocking the sampling path (`logits_row = logits_processor(prompt_tokens_ids, ...)`).
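For context, here is a simplified sketch of the current per-row application (illustrative, not the exact vLLM code; the variable names are assumptions):

```python
# Every sequence's logits row is processed serially on the sampling
# critical path; an expensive processor (e.g. FSM-based guided decoding)
# blocks the whole batch.
row_idx = 0
for seq_group in sampling_metadata.seq_groups:
    for seq_id in seq_group.seq_ids:
        logits_row = logits[row_idx]
        for logits_processor in seq_group.sampling_params.logits_processors:
            logits_row = logits_processor(prompt_tokens_ids, logits_row)
        logits[row_idx] = logits_row
        row_idx += 1
```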
Addressing this requires the logits processor to be:
- stateful (to use a tool like Ray or a thread pool). I think PR #5329 ([Core] Fix sharing of stateful logits processors) is likely sufficient.
- async. We'd like to propose a "prepare" API which separates preparing the logits from compute_logits:
```python
from typing import List

import torch

from vllm.sequence import SequenceGroupMetadata

# LogitProcessorConfig is a new config type proposed by this RFC.


class LogitPostProcessor:

    def initialize(self, logit_processor_config: LogitProcessorConfig):
        """Initialize the post processor. A post processor may have state
        such as a thread pool or Ray actors; it should be initialized
        here.
        """
        ...

    def prepare(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata]):
        """Asynchronously prepare logit masks."""
        ...

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Apply the prepared masks to the given logits."""
        ...
```
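As a concrete illustration of how prepare/apply enables async processing, here is a minimal sketch of a thread-pool-backed implementation. The mask format and the `_build_mask` helper are assumptions for illustration, not part of the proposal:

```python
from concurrent.futures import ThreadPoolExecutor

import torch


class ThreadPoolLogitPostProcessor(LogitPostProcessor):
    """Sketch: build per-sequence logit masks on a thread pool in
    prepare(), and block only in apply() when the masks are needed."""

    def initialize(self, logit_processor_config):
        # Stateful resources (here, the thread pool) are created once.
        self._pool = ThreadPoolExecutor(max_workers=8)
        self._futures = []

    def prepare(self, seq_group_metadata_list):
        # Kick off mask construction asynchronously and return immediately,
        # so the work overlaps with the model forward pass.
        self._futures = [
            self._pool.submit(self._build_mask, metadata)
            for metadata in seq_group_metadata_list
        ]

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        # Wait for the masks only now, after the forward pass has finished.
        masks = torch.stack([future.result() for future in self._futures])
        self._futures = []
        # Hypothetical mask format: additive, 0 for allowed tokens and
        # -inf for disallowed tokens.
        return logits + masks

    def _build_mask(self, seq_group_metadata) -> torch.Tensor:
        # Hypothetical helper: derive the mask from the request's
        # guided-decoding schema (e.g. via an FSM step).
        ...
```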
```python
# For each model, we will have:
def compute_logits(...):
    ...

def prepare_logits(seq_group_metadata_list):
    ...
```
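A hedged sketch of how these model hooks might delegate to the post processor (the attribute name `logit_post_processor` is an assumption):

```python
import torch.nn as nn


class MyModelForCausalLM(nn.Module):

    def prepare_logits(self, seq_group_metadata_list):
        # Fire-and-forget: start async mask preparation before the
        # forward pass runs.
        self.logit_post_processor.prepare(seq_group_metadata_list)

    def compute_logits(self, hidden_states, sampling_metadata):
        logits = self.lm_head(hidden_states)
        # Blocks until the prepared masks are ready, then applies them.
        return self.logit_post_processor.apply(logits)
```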
`prepare` and `apply` assume 1:1 calls, i.e., once `prepare` is called, `apply` has to be called before `prepare` is called again. I think that is a safe assumption. Alternatively, we could make `prepare` return a class, but that would make the interface surface larger, so I don't prefer that solution (but I am open to feedback!).
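To make the 1:1 contract explicit, an implementation could guard it with a simple flag (illustrative sketch only):

```python
class CheckedLogitPostProcessor(LogitPostProcessor):
    """Illustrative: assert that prepare() and apply() strictly alternate."""

    def initialize(self, logit_processor_config):
        self._prepared = False

    def prepare(self, seq_group_metadata_list):
        assert not self._prepared, "apply() must be called before prepare() again"
        self._prepared = True
        ...

    def apply(self, logits):
        assert self._prepared, "prepare() must be called before apply()"
        self._prepared = False
        ...
```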
This is an example usage of the API:
```python
# Each model will have a prepare_logits API.
self.model.prepare_logits(seq_group_metadata_list)

hidden_states = model_executable(
    input_ids=input_tokens,
    positions=input_positions,
    kv_caches=kv_caches,
    attn_metadata=attn_metadata,
    **multi_modal_kwargs,
)

# Compute the logits. Logits processors are applied here.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
```
We are also considering upstreaming a Ray-based batch processing implementation using lmformatenforcer.
Failure Handling
When using a stateful logits processor, it is possible for requests to fail. For example, if we use Ray, Ray actors can die, or there could be an issue with a user's schema that cannot be caught ahead of time.
When that happens, we should fail the seq_group immediately. We will introduce a new status `FINISHED_INTERNAL_ERROR = enum.auto()` to `class SequenceStatus(enum.Enum)` (line 42 at commit 246598a).
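A hedged sketch of how the engine could use the new status when mask preparation fails, instead of crashing (the surrounding structure and names like `future` are assumptions):

```python
try:
    # e.g. waiting on a Ray actor or thread-pool future in apply()/prepare().
    mask = future.result()
except Exception as e:
    # Fail only the affected seq_group; the engine keeps serving others.
    for seq in seq_group.get_seqs():
        seq.status = SequenceStatus.FINISHED_INTERNAL_ERROR
    logger.warning("Logit post processing failed for request %s: %s",
                   seq_group.request_id, e)
```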
Feedback Period.
No response
CC List.
Any Other Things.
No response