Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,3 +626,33 @@ def test_reset_prefix_cache():
assert manager.reset_prefix_cache()
assert not manager.cached_block_hash_to_block
assert all([blk.block_hash is None for blk in manager.block_pool])


def test_uncache_blocks():
    """Blocks cached because of speculative tokens are uncached once the
    accepted-token count shows they are no longer full."""
    manager = KVCacheManager(
        block_size=16,
        num_gpu_blocks=10,
        max_model_len=8192,
        sliding_window=None,
        enable_caching=True,
        num_preallocate_tokens=0,
    )

    request = make_request("0", list(range(30)))
    allocated = manager.allocate_slots(request, 30, [])
    # 30 tokens with block_size=16 span two blocks; only block 0 is full,
    # so exactly one block hash is cached.
    assert [blk.block_id for blk in allocated] == [0, 1]
    assert len(manager.cached_block_hash_to_block) == 1

    request.num_computed_tokens = 30

    # Simulate speculative tokens: five more tokens fill block 1,
    # which therefore also gets cached.
    for _ in range(5):
        request.append_output_token_ids(8)
    manager.append_slots(request, 5)
    assert len(manager.cached_block_hash_to_block) == 2

    # After sampling, assume only one speculative token was accepted:
    # block 1 is no longer full and must be uncached.
    request.num_computed_tokens = 31
    assert manager.uncache_blocks(request) == 1
    assert len(manager.cached_block_hash_to_block) == 1
33 changes: 31 additions & 2 deletions vllm/v1/core/kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,29 @@ def free(self, request: Request) -> None:
if block.ref_cnt == 0:
self.free_block_queue.append(block)

def uncache_blocks(self, request: Request) -> int:
"""Uncache the blocks that are no longer full based on the
num_computed_tokens in the given request. This happens when
the blocks were full and cached due to speculative tokens, but the
speculative tokens are not accepted.

Args:
request: The request.

Returns:
The number of uncached blocks.
"""
blocks = self.req_to_blocks[request.request_id]
num_computed_tokens = request.num_computed_tokens
num_full_blocks = num_computed_tokens // self.block_size
num_uncached_blocks = 0
for block in blocks[num_full_blocks:]:
# If the block is not cached, the following blocks are not cached.
if not self._maybe_evict_cached_block(block):
Copy link
Collaborator

@LiuXiaoxuanPKU LiuXiaoxuanPKU Jan 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: what's the case a full block is not cached ?

Copy link
Collaborator Author

@comaniac comaniac Jan 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't check if this block is full, so this happens when we encounter the first partial block. For example, your block table is [Full, Full, Full (but need uncache), Partial, Empty (pre-allocated)], then we will break at Partial.

break
num_uncached_blocks += 1
return num_uncached_blocks

def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
flows to invalid prefix caching after the weights are updated,
Expand Down Expand Up @@ -386,21 +409,24 @@ def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:

# If the block is cached, evict it.
if self.enable_caching:
self._evict_cached_block(curr_block)
self._maybe_evict_cached_block(curr_block)

curr_block.incr_ref()
ret.append(curr_block)
idx += 1

return ret

def _evict_cached_block(self, block: KVCacheBlock) -> None:
def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
"""
If a block is cached in `cached_block_hash_to_block`, we reset its hash
metadata and evict it from the cache.

Args:
block: The block to evict.

Returns:
True if the block is evicted, False otherwise.
"""
block_hash = block.block_hash
if block_hash and block_hash in self.cached_block_hash_to_block:
Expand All @@ -410,6 +436,9 @@ def _evict_cached_block(self, block: KVCacheBlock) -> None:
if len(self.cached_block_hash_to_block[block_hash]) == 0:
del self.cached_block_hash_to_block[block_hash]

return True
return False

def _get_cached_block(self,
block_hash: BlockHashType) -> Optional[KVCacheBlock]:
"""Get a cached block by the block hash, or None if cache miss.
Expand Down
Loading