
[RFC]: Improve guided decoding (logit_processor) APIs and performance. #5423

@rkooo567

Description


Motivation.

Currently, the guided decoding & logit processor API is incomplete and has several issues. This RFC is intended to bring up the problems and proposed solutions. Some of the issues may have already been addressed and have PRs out already.

There are three major issues.

  • Guided decoding is not supported from SamplingParams.
  • It is not possible to do batch/async logit processing.
  • Upon failures, the engine dies.

Proposed Change.

API

Guided decoding parameters are not supported with SamplingParams. This is addressed in #4130.
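As a purely hypothetical sketch (the field name guided_json and the schema below are illustrative assumptions, not the confirmed interface; see #4130 for the actual change), usage could look like:

    from vllm import LLM, SamplingParams

    # NOTE: `guided_json` is a hypothetical field used for illustration only;
    # the real parameter names are defined in #4130.
    schema = {
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
    }

    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=64,
        guided_json=schema,  # assumed field, not the confirmed API
    )

    llm = LLM(model="meta-llama/Llama-2-7b-hf")
    outputs = llm.generate(["Answer as JSON: What is the capital of France?"],
                           sampling_params)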

Performance

Currently, logit processors are applied row by row in a blocking manner (logits_row = logits_processor(prompt_tokens_ids, ...)). Instead, we can use parallel processing (e.g., Ray or a thread pool) to improve logit processing performance. We are using this mechanism internally at Anyscale. We'd like to support this feature in OSS, and would like to improve the logit processor API to support 1. async and 2. batching.

This requires the logit processor interface to look like:

from typing import List

import torch

from vllm.sequence import SequenceGroupMetadata


class LogitPostProcessor:

    def initialize(self, logit_processor_config: "LogitProcessorConfig"):
        """Initialize the post processor. A post processor may have state
        such as a thread pool or Ray actors; it should be initialized here.
        """
        ...

    def prepare(
            self,
            seq_group_metadata_list: List[SequenceGroupMetadata]):
        """Asynchronously prepare logit masks."""
        ...

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Apply the prepared masks to the given logits."""
        ...

# For each model, we will have

def compute_logits(self, hidden_states, sampling_metadata):
    ...

def prepare_logits(self, seq_group_metadata_list):
    ...

prepare and apply assume 1:1 calls, i.e., once prepare is called, apply has to be called before prepare is called again. I think that is a safe assumption. Alternatively, we could make prepare return a class, but that would make the interface surface larger, so I don't prefer that solution (but I am open to feedback!).

This is an example usage of the API:

        # each model will have prepare_logits API
        self.model.prepare_logits(seq_group_metadata_list)
        hidden_states = model_executable(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            **multi_modal_kwargs,
        )
        # Compute the logits. logit processors are applied here.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)
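
For illustration, here is a minimal sketch of how the interface could be implemented with a thread pool so that mask preparation overlaps with the model forward pass. The class name ThreadPoolLogitPostProcessor and the helper _compute_mask are hypothetical and only meant to show the prepare/apply contract, not a proposed implementation:

    from concurrent.futures import Future, ThreadPoolExecutor
    from typing import List, Optional

    import torch

    from vllm.sequence import SequenceGroupMetadata


    class ThreadPoolLogitPostProcessor:
        """Hypothetical LogitPostProcessor backed by a thread pool."""

        def initialize(self, logit_processor_config):
            # Stateful resources (thread pool, Ray actors, ...) live here.
            self._pool = ThreadPoolExecutor(max_workers=8)
            self._futures: List[Future] = []

        def prepare(self, seq_group_metadata_list: List[SequenceGroupMetadata]):
            # Kick off mask computation and return immediately; the model
            # forward pass runs while the masks are being prepared.
            self._futures = [
                self._pool.submit(self._compute_mask, metadata)
                for metadata in seq_group_metadata_list
            ]

        def apply(self, logits: torch.Tensor) -> torch.Tensor:
            # Block until all masks are ready, then apply them row by row.
            # For simplicity, this assumes one sequence per sequence group.
            for row, future in enumerate(self._futures):
                mask: Optional[torch.Tensor] = future.result()
                if mask is not None:
                    logits[row] += mask  # e.g., -inf for disallowed tokens
            self._futures = []
            return logits

        def _compute_mask(self, metadata: SequenceGroupMetadata):
            # Hypothetical helper: build a -inf mask from the grammar/FSM
            # state of the sequence group (e.g., via lmformatenforcer).
            ...

Because prepare is called before the forward pass and apply runs inside compute_logits, the mask computation cost can be hidden behind GPU execution.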

We are also considering upstreaming a Ray-based batch processing implementation with lmformatenforcer.

Failure Handling

When using a stateful logit processor, it is possible for requests to fail. For example, if we use Ray, Ray actors can die, or there could be an issue with the user's schema that cannot be caught ahead of time.

When that happens, we should fail the seq_group immediately. We will introduce a new status FINISHED_INTERNAL_ERROR = enum.auto() to class SequenceStatus(enum.Enum). If any logit processor fails, we will mark the relevant seq_group as failed, and the request will be aborted.
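
A minimal sketch of the proposed status (existing enum members are abbreviated; only FINISHED_INTERNAL_ERROR is new in this RFC, and the is_finished handling shown is an assumption about how it would be wired in):

    import enum


    class SequenceStatus(enum.Enum):
        """Status of a sequence (existing members abbreviated)."""
        WAITING = enum.auto()
        RUNNING = enum.auto()
        FINISHED_STOPPED = enum.auto()
        FINISHED_ABORTED = enum.auto()
        # New in this RFC: the sequence finished because a logit processor
        # failed (e.g., a dead Ray actor or an invalid user schema).
        FINISHED_INTERNAL_ERROR = enum.auto()

        @staticmethod
        def is_finished(status: "SequenceStatus") -> bool:
            # FINISHED_INTERNAL_ERROR counts as finished so the request is
            # aborted rather than rescheduled.
            return status in (
                SequenceStatus.FINISHED_STOPPED,
                SequenceStatus.FINISHED_ABORTED,
                SequenceStatus.FINISHED_INTERNAL_ERROR,
            )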

Feedback Period.

No response

CC List.

cc @simon-mo @Yard1

Any Other Things.

No response
