From d35146f678765a0054c9facc214be79d4d79467c Mon Sep 17 00:00:00 2001
From: Chen Zhang
Date: Fri, 25 Apr 2025 07:33:19 -0700
Subject: [PATCH 1/2] remove num_input_tokens from attn_metadata

Signed-off-by: Chen Zhang
---
 vllm/forward_context.py                  | 14 ++++++--------
 vllm/v1/attention/backends/flash_attn.py |  3 ---
 vllm/v1/attention/backends/flashinfer.py |  3 ---
 vllm/v1/attention/backends/mla/common.py |  3 ---
 vllm/v1/worker/gpu_model_runner.py       |  5 +++--
 vllm/v1/worker/tpu_model_runner.py       |  5 ++++-
 6 files changed, 13 insertions(+), 20 deletions(-)

diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 06790d8ee2f8..d7c43b56827d 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -74,15 +74,13 @@ def set_forward_context(attn_metadata: Any,
     if vllm_config.parallel_config.data_parallel_size > 1:
         dp_size = vllm_config.parallel_config.data_parallel_size
         dp_rank = vllm_config.parallel_config.data_parallel_rank
-        if attn_metadata is not None:
-            if hasattr(attn_metadata, "num_prefill_tokens"):
-                # for v0 attention backends
-                batchsize = attn_metadata.num_prefill_tokens + \
-                    attn_metadata.num_decode_tokens
-            else:
-                # for v1 attention backends
-                batchsize = attn_metadata.num_input_tokens
+        if attn_metadata is not None and hasattr(attn_metadata,
+                                                 "num_prefill_tokens"):
+            # for v0 attention backends
+            batchsize = attn_metadata.num_prefill_tokens + \
+                attn_metadata.num_decode_tokens
         else:
+            # for v1 attention backends or no attn_metadata
             batchsize = num_tokens
         num_tokens_across_dp = [0] * dp_size
         num_tokens_across_dp[dp_rank] = batchsize
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 51ae386d3389..7c20f94b915b 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -94,9 +94,6 @@ class FlashAttentionMetadata:
     scheduler_metadata: Optional[torch.Tensor] = None
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
 
-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
-
     # for local attention
     @dataclass
     class LocalAttentionMetadata:
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 17341ecfa4fe..aae170984ab2 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -184,9 +184,6 @@ class FlashInferMetadata:
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
     cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
 
-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
-
     @property
     def query_start_loc(self):
         # The GPUModelRunner expects to be able to access this property.
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index f826f8a21789..75a11bd46920 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -312,9 +312,6 @@ class MLACommonMetadata(Generic[D]):
     num_decode_tokens: int
     num_prefills: int
 
-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
-
     # The dimension of the attention heads
     head_dim: Optional[int] = None
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 86f6a301fbb6..83c6aaa9168d 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1026,7 +1026,6 @@ def execute_model(
         else:
             # Eager mode.
             num_input_tokens = num_scheduled_tokens
-        attn_metadata.num_input_tokens = num_input_tokens
 
         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
@@ -1078,7 +1077,9 @@ def execute_model(
 
         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
-        with set_forward_context(attn_metadata, self.vllm_config):
+        with set_forward_context(attn_metadata,
+                                 self.vllm_config,
+                                 num_tokens=num_input_tokens):
             hidden_states = self.model(
                 input_ids=input_ids,
                 positions=positions,
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index e9cb0dbe8b5e..0728efb168dc 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -771,7 +771,10 @@ def execute_model(
         xm.mark_step()
         num_reqs = self.input_batch.num_reqs
         # Run the decoder
-        with set_forward_context(attn_metadata, self.vllm_config):
+        with set_forward_context(
+                attn_metadata,
+                self.vllm_config,
+                num_tokens=scheduler_output.total_num_scheduled_tokens):
             hidden_states = self.model(
                 input_ids=input_ids,
                 positions=self.position_ids,

From 20d930be75bcfee9b1fe1f4d37af78f57732f864 Mon Sep 17 00:00:00 2001
From: Chen Zhang
Date: Fri, 25 Apr 2025 07:42:33 -0700
Subject: [PATCH 2/2] fix

Signed-off-by: Chen Zhang
---
 vllm/forward_context.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index d7c43b56827d..c75d8f088c5b 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -122,7 +122,7 @@ def set_forward_context(attn_metadata: Any,
                     attn_metadata.num_decode_tokens
             else:
                 # for v1 attention backends
-                batchsize = attn_metadata.num_input_tokens
+                batchsize = num_tokens
             # we use synchronous scheduling right now,
             # adding a sync point here should not affect
             # scheduling of the next batch
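
Note for reviewers: after these two commits, the padded token count is passed to set_forward_context explicitly instead of being stashed on the attention metadata. A minimal sketch of the resulting calling pattern follows; the run_decode_step wrapper and the padded_num_tokens name are illustrative only and not code from this series.

# Illustrative sketch, not part of the diffs above.
from vllm.forward_context import set_forward_context


def run_decode_step(model, vllm_config, attn_metadata, input_ids, positions,
                    padded_num_tokens: int):
    # v1 attention metadata no longer carries num_input_tokens; the padded
    # batch size is handed to the context manager instead, which uses it as
    # this rank's entry in num_tokens_across_dp under data parallelism.
    with set_forward_context(attn_metadata,
                             vllm_config,
                             num_tokens=padded_num_tokens):
        return model(input_ids=input_ids, positions=positions)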