vllm/block.py (3 changes: 2 additions & 1 deletion)
@@ -1,5 +1,5 @@
 """Token blocks."""
-from typing import List
+from typing import List, Optional

 from vllm.utils import Device

@@ -25,6 +25,7 @@ def __init__(

         self.token_ids = [_BLANK_TOKEN_ID] * block_size
         self.num_tokens = 0
+        self.block_hash: Optional[int] = None

     def is_empty(self) -> bool:
         return self.num_tokens == 0
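The new block_hash field gives each logical token block a slot to memoize its content hash: it stays None while the block is still being filled and is set at most once, after the block is full and its contents can no longer change. A minimal, self-contained sketch of that lifecycle (illustrative names only, not vLLM's actual classes; in vLLM the hash itself is computed by Sequence.hash_of_block over the whole token prefix, whereas this sketch hashes only the block's own tokens to stay short):

from typing import List, Optional

_BLANK_TOKEN_ID = -1


class BlockSketch:
    """Illustrative fixed-size token block with a lazily cached content hash."""

    def __init__(self, block_size: int) -> None:
        self.block_size = block_size
        self.token_ids: List[int] = [_BLANK_TOKEN_ID] * block_size
        self.num_tokens = 0
        self.block_hash: Optional[int] = None  # None until the block is full

    def is_full(self) -> bool:
        return self.num_tokens == self.block_size

    def append_tokens(self, token_ids: List[int]) -> None:
        assert self.block_hash is None, "A hashed (full) block must not change."
        assert self.num_tokens + len(token_ids) <= self.block_size
        start = self.num_tokens
        self.token_ids[start:start + len(token_ids)] = token_ids
        self.num_tokens += len(token_ids)
        if self.is_full():
            # Contents are frozen from here on, so the hash is stable to cache.
            self.block_hash = hash(tuple(self.token_ids))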
vllm/core/block_manager_v1.py (56 changes: 22 additions & 34 deletions)
@@ -262,8 +262,7 @@ def __init__(
         self.cross_block_tables: Dict[str, BlockTable] = {}

     def _get_seq_num_required_blocks(self, seq: Sequence) -> int:
-        return 0 if seq is None \
-            else len(seq.logical_token_blocks)
+        return 0 if seq is None else len(seq.logical_token_blocks)

     def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
         # FIXME(woosuk): Here we assume that all sequences in the group share
@@ -275,8 +274,8 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
             seq_group.get_seqs(status=SequenceStatus.WAITING)[0])
         cross_num_required_blocks = self._get_seq_num_required_blocks(
             seq_group.get_encoder_seq())
-        num_required_blocks = self_num_required_blocks + \
-            cross_num_required_blocks
+        num_required_blocks = (self_num_required_blocks +
+                               cross_num_required_blocks)

         if self.block_sliding_window is not None:

@@ -293,9 +292,9 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
         else:
             return AllocStatus.LATER

-    def _allocate_sequence(self, \
-                           seq: Sequence, \
-                           ref_count: int, \
+    def _allocate_sequence(self,
+                           seq: Sequence,
+                           ref_count: int,
                            is_encoder_decoder: bool = True) -> BlockTable:
         # Allocate new physical token blocks that will store the prompt tokens.
         num_prompt_blocks = len(seq.logical_token_blocks)
@@ -328,10 +327,8 @@ def allocate(self, seq_group: SequenceGroup) -> None:
         # NOTE: Here we assume that all sequences in the group have the same
         # decoder prompt.
         seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
-        block_table: BlockTable = \
-            self._allocate_sequence(seq,
-                                    seq_group.num_seqs(),
-                                    is_encoder_decoder)
+        block_table: BlockTable = self._allocate_sequence(
+            seq, seq_group.num_seqs(), is_encoder_decoder)

         # Assign the self-attention block tables for each sequence.
         for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
@@ -368,6 +365,7 @@ def _promote_last_block(
         # Compute a new hash for the block so that it can be shared by other
         # Sequences
         new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
+        assert new_hash is not None, "Last block is not full."

         # if new_hash is already in the cached table, then free last_block
         # and return the cached version
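The added assert makes the precondition of _promote_last_block explicit: promotion is only attempted once the sequence's last block is full, so hash_of_block cannot return None at this point. A hedged sketch of the promote-on-full step described by the surrounding comments (the cached_blocks table and the function name below are assumed, illustrative stand-ins, not vLLM's exact API):

from typing import Dict, Optional


def promote_last_block_sketch(new_hash: Optional[int], last_block: object,
                              cached_blocks: Dict[int, object]) -> object:
    """Illustrative only: deduplicate a just-filled block by content hash."""
    # Reached only when the last logical block is full, so its hash is defined.
    assert new_hash is not None, "Last block is not full."
    if new_hash in cached_blocks:
        # Identical content already cached: return the cached block so the
        # caller can free the freshly written copy and share the cached one.
        return cached_blocks[new_hash]
    # First occurrence of this content: register the block under its hash.
    cached_blocks[new_hash] = last_block
    return last_block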
@@ -406,9 +404,7 @@ def _allocate_last_physical_block(
         # content hash.
         if not self.enable_caching:
             return self.gpu_allocator.allocate()
-        block_hash: Optional[int] = None
-        if (self._is_last_block_full(seq)):
-            block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
+        block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
         num_hashed_tokens = seq.num_hashed_tokens_of_block(
             len(seq.logical_token_blocks) - 1)
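Since hash_of_block now returns None for a block that is not yet full, _allocate_last_physical_block no longer needs its own Optional[int] initialization or the explicit _is_last_block_full check. A hedged sketch of the resulting control flow (the two-argument allocate(block_hash, num_hashed_tokens) call is an assumption for illustration; only the zero-argument form is visible in the hunk above):

from typing import Optional


def allocate_last_physical_block_sketch(seq, gpu_allocator, enable_caching: bool):
    """Illustrative only: allocate a physical block for seq's last logical block."""
    if not enable_caching:
        # Without prefix caching there is nothing to hash or deduplicate.
        return gpu_allocator.allocate()
    last_idx = len(seq.logical_token_blocks) - 1
    # None while the last block is still partially filled; a content hash once full.
    block_hash: Optional[int] = seq.hash_of_block(last_idx)
    num_hashed_tokens = seq.num_hashed_tokens_of_block(last_idx)
    # Assumed allocator behavior: a non-None hash lets it return an already
    # cached block, while None forces a fresh, uncached block.
    return gpu_allocator.allocate(block_hash, num_hashed_tokens)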

@@ -553,18 +549,14 @@ def swap_in(self,
         # dict is efficient in lookup `if cpu_block in mapping`
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
-            self.block_tables[seq.seq_id] = \
-                self._swap_block_table(self.block_tables[seq.seq_id],
-                                       self.cpu_allocator,
-                                       self.gpu_allocator,
-                                       mapping)
+            self.block_tables[seq.seq_id] = self._swap_block_table(
+                self.block_tables[seq.seq_id], self.cpu_allocator,
+                self.gpu_allocator, mapping)

         if seq_group.is_encoder_decoder():
-            self.cross_block_tables[request_id] = \
-                self._swap_block_table(self.cross_block_tables[request_id],
-                                       self.cpu_allocator,
-                                       self.gpu_allocator,
-                                       mapping)
+            self.cross_block_tables[request_id] = self._swap_block_table(
+                self.cross_block_tables[request_id], self.cpu_allocator,
+                self.gpu_allocator, mapping)

         return [(cpu_block.block_number, gpu_block.block_number)
                 for cpu_block, gpu_block in mapping.items()]
@@ -580,18 +572,14 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
         # dict is efficient in lookup `if gpu_block in mapping`
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-            self.block_tables[seq.seq_id] = \
-                self._swap_block_table(self.block_tables[seq.seq_id],
-                                       self.gpu_allocator,
-                                       self.cpu_allocator,
-                                       mapping)
+            self.block_tables[seq.seq_id] = self._swap_block_table(
+                self.block_tables[seq.seq_id], self.gpu_allocator,
+                self.cpu_allocator, mapping)

         if seq_group.is_encoder_decoder():
-            self.cross_block_tables[request_id] = \
-                self._swap_block_table(self.cross_block_tables[request_id],
-                                       self.gpu_allocator,
-                                       self.cpu_allocator,
-                                       mapping)
+            self.cross_block_tables[request_id] = self._swap_block_table(
+                self.cross_block_tables[request_id], self.gpu_allocator,
+                self.cpu_allocator, mapping)

         return [(cpu_block.block_number, gpu_block.block_number)
                 for cpu_block, gpu_block in mapping.items()]
vllm/sequence.py (25 changes: 17 additions & 8 deletions)
@@ -269,15 +269,24 @@ def get_output_text_to_return(self, buffer_length: int):
         return self.output_text[:-buffer_length] if truncate else (
             self.output_text)

-    def hash_of_block(self, logical_idx: int) -> int:
-        # TODO This can produce incorrect hash when block size > prompt size
-
-        # Compute the number of tokens in the sequence
+    def hash_of_block(self, logical_idx: int) -> Optional[int]:
+        """Return the hash of the block if it is full."""
         # TODO: The current hashing function is O(L^2). We should optimize
         # this in the future.
-        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
-        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
-        return hash((hashed_tokens, self.lora_int_id))
+        assert logical_idx < len(self.logical_token_blocks), (
+            f"logical_idx={logical_idx} is out of range for "
+            f"logical_token_blocks={len(self.logical_token_blocks)}")
+        block = self.logical_token_blocks[logical_idx]
+        if block.block_hash is not None:
+            return block.block_hash
+        if not block.is_full():
+            return None
+        num_hashed_tokens = self.num_hashed_tokens_of_block(logical_idx)
+        hashed_tokens = self.data.get_prefix_token_ids(num_hashed_tokens)
+        block_hash = hash((hashed_tokens, self.lora_int_id))
+        # Cache the block hash for future use.
+        block.block_hash = block_hash
+        return block_hash

     def num_hashed_tokens_of_block(self, logical_idx: int):
         return logical_idx * self.block_size + self.block_size
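The rewritten hash_of_block now returns the cached hash when one exists, returns None for a block that is not yet full, and otherwise hashes the token prefix together with lora_int_id and stores the result on the block so it is computed at most once. A small, self-contained sketch of the same memoization pattern with a usage check (the classes below are illustrative stand-ins, not vLLM's Sequence):

from typing import List, Optional, Tuple


class _BlockSketch:
    def __init__(self, block_size: int) -> None:
        self.block_size = block_size
        self.num_tokens = 0
        self.block_hash: Optional[int] = None  # lazily cached content hash

    def is_full(self) -> bool:
        return self.num_tokens == self.block_size


class _SequenceSketch:
    def __init__(self, token_ids: List[int], block_size: int,
                 lora_int_id: int = 0) -> None:
        self.block_size = block_size
        self.lora_int_id = lora_int_id
        self.token_ids = token_ids
        # Chop the prompt into fixed-size logical blocks.
        self.blocks: List[_BlockSketch] = []
        for start in range(0, len(token_ids), block_size):
            block = _BlockSketch(block_size)
            block.num_tokens = min(block_size, len(token_ids) - start)
            self.blocks.append(block)

    def hash_of_block(self, logical_idx: int) -> Optional[int]:
        block = self.blocks[logical_idx]
        if block.block_hash is not None:  # already computed: reuse it
            return block.block_hash
        if not block.is_full():           # partial block has no stable hash yet
            return None
        num_hashed = (logical_idx + 1) * self.block_size
        prefix: Tuple[int, ...] = tuple(self.token_ids[:num_hashed])
        block.block_hash = hash((prefix, self.lora_int_id))
        return block.block_hash


# Usage: the second call hits the cache instead of re-hashing the prefix.
seq = _SequenceSketch(list(range(32)), block_size=16)
assert seq.hash_of_block(0) == seq.hash_of_block(0)
assert seq.hash_of_block(1) is not None  # both blocks are full here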
@@ -632,7 +641,7 @@ class SequenceGroupMetadata:
         state: Internal state tied to this sequence group.
         multi_modal_data: Multi modal data.
         encoder_seq_data: Optional sequence data for encoder prompt
-                          (SequenceGroup.encoder_seq). Should be None
+                          (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
         cross_block_table: Optional cross-attention block table associated