diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 06790d8ee2f8..c75d8f088c5b 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -74,15 +74,13 @@ def set_forward_context(attn_metadata: Any,
     if vllm_config.parallel_config.data_parallel_size > 1:
         dp_size = vllm_config.parallel_config.data_parallel_size
         dp_rank = vllm_config.parallel_config.data_parallel_rank
-        if attn_metadata is not None:
-            if hasattr(attn_metadata, "num_prefill_tokens"):
-                # for v0 attention backends
-                batchsize = attn_metadata.num_prefill_tokens + \
-                    attn_metadata.num_decode_tokens
-            else:
-                # for v1 attention backends
-                batchsize = attn_metadata.num_input_tokens
+        if attn_metadata is not None and hasattr(attn_metadata,
+                                                 "num_prefill_tokens"):
+            # for v0 attention backends
+            batchsize = attn_metadata.num_prefill_tokens + \
+                attn_metadata.num_decode_tokens
         else:
+            # for v1 attention backends or no attn_metadata
             batchsize = num_tokens
         num_tokens_across_dp = [0] * dp_size
         num_tokens_across_dp[dp_rank] = batchsize
@@ -124,7 +122,7 @@ def set_forward_context(attn_metadata: Any,
                     attn_metadata.num_decode_tokens
             else:
                 # for v1 attention backends
-                batchsize = attn_metadata.num_input_tokens
+                batchsize = num_tokens
             # we use synchronous scheduling right now,
             # adding a sync point here should not affect
             # scheduling of the next batch
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 41bb9aba2995..217dcd7c33ac 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -94,9 +94,6 @@ class FlashAttentionMetadata:
     scheduler_metadata: Optional[torch.Tensor] = None
     prefix_scheduler_metadata: Optional[torch.Tensor] = None

-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
-
     # for local attention
     @dataclass
     class LocalAttentionMetadata:
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index bce446bd2b82..6e964b471fae 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -183,9 +183,6 @@ class FlashInferMetadata:
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
     cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None

-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
-
     @property
     def query_start_loc(self):
         # The GPUModelRunner expects to be able to access this property.
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index b032006d1ad1..fd3be901f4c3 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -312,9 +312,6 @@ class MLACommonMetadata(Generic[D]):
     num_decode_tokens: int
     num_prefills: int

-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
-
     # The dimension of the attention heads
     head_dim: Optional[int] = None

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index e3d8b94fe9d7..4711beadbd9f 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1036,7 +1036,6 @@ def execute_model(
             num_input_tokens = round_up(num_scheduled_tokens, tp_size)
         else:
             num_input_tokens = num_scheduled_tokens
-        attn_metadata.num_input_tokens = num_input_tokens

         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
@@ -1088,7 +1087,9 @@ def execute_model(

         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
-        with set_forward_context(attn_metadata, self.vllm_config):
+        with set_forward_context(attn_metadata,
+                                 self.vllm_config,
+                                 num_tokens=num_input_tokens):
             output = self.model(
                 input_ids=input_ids,
                 positions=positions,
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 67f8af29db0e..d716542f7898 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -769,7 +769,10 @@ def execute_model(
             xm.mark_step()
         num_reqs = self.input_batch.num_reqs
         # Run the decoder
-        with set_forward_context(attn_metadata, self.vllm_config):
+        with set_forward_context(
+                attn_metadata,
+                self.vllm_config,
+                num_tokens=scheduler_output.total_num_scheduled_tokens):
             hidden_states = self.model(
                 input_ids=input_ids,
                 positions=self.position_ids,
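
For reference, the batch-size selection that this diff installs in set_forward_context can be exercised in isolation. The sketch below is a simplified, standalone rendition of that logic (plain Python, no vLLM imports); the names FakeV0Metadata and pick_batchsize are illustrative only and do not exist in vLLM. It shows the intended behavior: v0-style metadata still supplies its own prefill/decode token counts, while v1 backends and the no-metadata case rely on the num_tokens value now passed explicitly by the model runners.

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeV0Metadata:
    # Hypothetical stand-in for a v0 attention metadata object, which
    # exposes its own prefill/decode token counts.
    num_prefill_tokens: int
    num_decode_tokens: int


def pick_batchsize(attn_metadata: Optional[Any], num_tokens: int) -> int:
    # Simplified mirror of the selection logic after this diff.
    if attn_metadata is not None and hasattr(attn_metadata,
                                             "num_prefill_tokens"):
        # v0 attention backends report their own token counts.
        return attn_metadata.num_prefill_tokens + \
            attn_metadata.num_decode_tokens
    # v1 attention backends (or no attn_metadata): use the value the
    # model runner passes as num_tokens.
    return num_tokens


if __name__ == "__main__":
    assert pick_batchsize(FakeV0Metadata(5, 3), num_tokens=16) == 8
    assert pick_batchsize(None, num_tokens=16) == 16      # no metadata
    assert pick_batchsize(object(), num_tokens=16) == 16  # v1-style metadata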