From 501724d65b8b4a32e0f8f0dd1761ce1a8dd365b5 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 21 Jan 2025 10:47:00 -0800 Subject: [PATCH 1/6] wip Signed-off-by: Cody Yu --- vllm/core/block/cpu_gpu_block_allocator.py | 5 +++++ vllm/core/block/interfaces.py | 10 ++++++++++ vllm/core/block/naive_block.py | 4 ++++ vllm/core/block/prefix_caching_block.py | 16 +++++++++++++++- vllm/core/block_manager.py | 3 +++ vllm/core/interfaces.py | 5 +++++ vllm/core/placeholder_block_space_manager.py | 3 +++ vllm/core/scheduler.py | 3 +++ vllm/engine/llm_engine.py | 5 +++++ vllm/entrypoints/llm.py | 3 +++ vllm/v1/core/kv_cache_manager.py | 17 +++++++++++++++++ vllm/v1/core/scheduler.py | 3 +++ vllm/v1/engine/core.py | 3 +++ vllm/v1/engine/core_client.py | 5 ++++- vllm/v1/engine/llm_engine.py | 3 +++ 15 files changed, 86 insertions(+), 2 deletions(-) diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 3a57487a6cd8..bbc67b3b37de 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -339,6 +339,11 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float: assert device in self._allocators return self._allocators[device].get_prefix_cache_hit_rate() + def reset_prefix_cache(self): + """Reset prefix cache for all devices.""" + for allocator in self._allocators.values(): + allocator.reset_prefix_cache() + def get_and_reset_swaps(self) -> List[Tuple[int, int]]: """Returns and clears the mapping of source to destination block IDs. Will be called after every swapping operations for now, and after every diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 985a1098b6cd..00c70e0a4140 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -192,6 +192,11 @@ def get_prefix_cache_hit_rate(self) -> float: """Prefix cache hit rate. 
-1 means not supported or disabled.""" pass + @abstractmethod + def reset_prefix_cache(self): + """Reset prefix cache.""" + pass + class NoFreeBlocksError(ValueError): pass @@ -297,6 +302,11 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float: """Prefix cache hit rate. -1 means not supported or disabled.""" pass + @abstractmethod + def reset_prefix_cache(self): + """Reset prefix cache.""" + pass + @abstractmethod def find_cached_blocks_prefix( self, diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 9b94918ab38e..74f4825a9e09 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -324,6 +324,10 @@ def swap_in(self, blocks: List[Block]) -> None: def get_prefix_cache_hit_rate(self) -> float: return -1 + + def reset_prefix_cache(self): + """No effect.""" + pass def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]: # Not applicable for naive block allocator. diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 1238303234de..e8aa39ad3d81 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -105,7 +105,8 @@ def __init__( # Evitor used to maintain how we want to handle those computed blocks # if we find memory pressure is high. - self.evictor: Evictor = make_evictor(eviction_policy) + self.eviction_policy = eviction_policy + self.evictor: Evictor = make_evictor(self.eviction_policy) # We share the refcounter between allocators. 
This allows us to promote # blocks originally allocated in the hashless allocator to immutable @@ -427,6 +428,19 @@ def all_block_ids(self) -> FrozenSet[int]: def get_prefix_cache_hit_rate(self) -> float: return self.metric_data.get_hit_rate() + + def reset_prefix_cache(self): + """Reset prefix cache.""" + num_used_blocks = self.get_num_total_blocks - self.get_num_free_blocks + if num_used_blocks > 0: + raise RuntimeError("Failed to reset prefix cache because some " + f"blocks ({num_used_blocks}) are not freed yet") + + # Reset the evictor. + self.evictor = make_evictor(self.eviction_policy) + + # Reset the metrics. + self.metric_data = CacheMetricData() def is_block_cached(self, block: Block) -> bool: assert block.content_hash is not None diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index b41e84822188..ab013cf9ce3e 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -455,6 +455,9 @@ def get_num_free_cpu_blocks(self) -> int: def get_prefix_cache_hit_rate(self, device: Device) -> float: return self.block_allocator.get_prefix_cache_hit_rate(device) + def reset_prefix_cache(self): + self.block_allocator.reset_prefix_cache() + def _can_swap(self, seq_group: SequenceGroup, device: Device, diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index b10b8d3f4a5b..400e60a290f2 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -122,6 +122,11 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float: """Prefix cache hit rate. 
-1 means not supported or disabled.""" pass + @abstractmethod + def reset_prefix_cache(self, device: Device): + """Reset prefix cache.""" + pass + @abstractmethod def get_num_cached_tokens(self, seq: Sequence) -> int: pass diff --git a/vllm/core/placeholder_block_space_manager.py b/vllm/core/placeholder_block_space_manager.py index a47e59451853..785291dec9cc 100644 --- a/vllm/core/placeholder_block_space_manager.py +++ b/vllm/core/placeholder_block_space_manager.py @@ -90,5 +90,8 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup, def get_prefix_cache_hit_rate(self, device: Device) -> float: return -1 + def reset_prefix_cache(self, device: Device): + pass + def get_num_cached_tokens(self, seq: Sequence) -> int: return 0 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index b3d396f9cedd..def697511608 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -504,6 +504,9 @@ def has_unfinished_seqs(self) -> bool: def get_prefix_cache_hit_rate(self, device: Device) -> float: return self.block_manager.get_prefix_cache_hit_rate(device) + def reset_prefix_cache(self): + self.block_manager.reset_prefix_cache() + def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 6a6b4a14a4c4..148a6e9ce4c3 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -914,6 +914,11 @@ def has_unfinished_requests_for_virtual_engine( """ return self.scheduler[virtual_engine].has_unfinished_seqs() + def reset_prefix_cache(self): + """Reset prefix cache for all devices.""" + for scheduler in self.scheduler: + scheduler.reset_prefix_cache() + @staticmethod def _process_sequence_group_outputs( seq_group: SequenceGroup, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 27386daa4bbc..fe622d9ab236 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1132,6 +1132,9 @@ def 
start_profile(self) -> None: def stop_profile(self) -> None: self.llm_engine.stop_profile() + def reset_prefix_cache(self) -> None: + self.llm_engine.reset_prefix_cache() + # LEGACY def _convert_v1_inputs( self, diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index bac77443c856..3ec28bc8fb5d 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -285,6 +285,23 @@ def free(self, request: Request) -> None: if block.ref_cnt == 0: self.free_block_queue.append(block) + def reset_prefix_cache(self): + """Reset prefix cache. This function may be used in RLHF + flows to invalid prefix caching after the weights are updated, + or used for resetting prefix caching status for benchmarking. + """ + num_used_blocks = (self.num_gpu_blocks - self.free_block_queue.num_free_blocks) + if num_used_blocks > 0: + raise RuntimeError("Failed to reset prefix cache because some " + f"blocks ({num_used_blocks}) are not freed yet") + + # Remove all hashes so that no new blocks will hit. + self.cached_block_hash_to_block = {} + # Remove all hashes from all blocks. 
+ for block in self.block_pool: + block.reset_hash() + + def get_num_common_prefix_blocks( self, request: Request, diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 64df21d59fef..870405a3355a 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -529,6 +529,9 @@ def get_num_unfinished_requests(self) -> int: def has_unfinished_requests(self) -> bool: return self.get_num_unfinished_requests() > 0 + def reset_prefix_cache(self): + self.kv_cache_manager.reset_prefix_cache() + def make_stats(self) -> SchedulerStats: return SchedulerStats( num_running_reqs=len(self.running), diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 26ebc7edcf03..2afe7b197220 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -135,6 +135,9 @@ def shutdown(self): def profile(self, is_start: bool = True): self.model_executor.profile(is_start) + def reset_prefix_cache(self): + self.scheduler.reset_prefix_cache() + class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index ac0f0f14bf1a..a01dc73fe715 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -108,12 +108,15 @@ def abort_requests(self, request_ids: List[str]) -> None: if len(request_ids) > 0: self.engine_core.abort_requests(request_ids) - def shutdown(self): + def shutdown(self) -> None: self.engine_core.shutdown() def profile(self, is_start: bool = True) -> None: self.engine_core.profile(is_start) + def reset_prefix_cache(self) -> None: + self.engine_core.reset_prefix_cache() + class MPClient(EngineCoreClient): """ diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index f5999ccda644..55d314ebeb95 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -162,6 +162,9 @@ def start_profile(self): def stop_profile(self): self.engine_core.profile(False) + def 
reset_prefix_cache(self): + self.engine_core.reset_prefix_cache() + def get_tokenizer_group( self, group_type: Type[_G] = BaseTokenizerGroup, From 376fd70049f974c16f16cb7acb85cc7d24778373 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 21 Jan 2025 19:21:33 +0000 Subject: [PATCH 2/6] done Signed-off-by: Cody Yu --- vllm/core/block/prefix_caching_block.py | 16 +++++++++++----- vllm/core/interfaces.py | 6 +++--- vllm/core/placeholder_block_space_manager.py | 2 +- vllm/v1/core/kv_cache_manager.py | 13 ++++++++----- vllm/v1/engine/__init__.py | 9 ++++++++- vllm/v1/engine/core.py | 8 ++++++-- vllm/v1/engine/core_client.py | 16 +++++++++++++++- 7 files changed, 52 insertions(+), 18 deletions(-) diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index e8aa39ad3d81..3f7c75ca2a1f 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -12,6 +12,7 @@ from vllm.core.block.naive_block import (BlockPool, NaiveBlock, NaiveBlockAllocator) from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor +from vllm.logger import init_logger from vllm.sequence import Sequence PrefixHash = int @@ -21,6 +22,8 @@ # then we know this block hasn't been accessed yet. 
_DEFAULT_LAST_ACCESSED_TIME = -1 +logger = init_logger(__name__) + class BlockTracker: """Used to track the status of a block inside the prefix caching allocator @@ -428,14 +431,17 @@ def all_block_ids(self) -> FrozenSet[int]: def get_prefix_cache_hit_rate(self) -> float: return self.metric_data.get_hit_rate() - + def reset_prefix_cache(self): """Reset prefix cache.""" - num_used_blocks = self.get_num_total_blocks - self.get_num_free_blocks + num_used_blocks = (self.get_num_total_blocks() - + self.get_num_free_blocks()) if num_used_blocks > 0: - raise RuntimeError("Failed to reset prefix cache because some " - f"blocks ({num_used_blocks}) are not freed yet") - + logger.warning( + "Failed to reset prefix cache because some " + "blocks (%d) are not freed yet", num_used_blocks) + return + # Reset the evictor. self.evictor = make_evictor(self.eviction_policy) diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 400e60a290f2..7fced2198861 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -123,9 +123,9 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float: pass @abstractmethod - def reset_prefix_cache(self, device: Device): - """Reset prefix cache.""" - pass + def reset_prefix_cache(self): + """Reset prefix cache for all devices.""" + pass @abstractmethod def get_num_cached_tokens(self, seq: Sequence) -> int: diff --git a/vllm/core/placeholder_block_space_manager.py b/vllm/core/placeholder_block_space_manager.py index 785291dec9cc..e841065accc6 100644 --- a/vllm/core/placeholder_block_space_manager.py +++ b/vllm/core/placeholder_block_space_manager.py @@ -90,7 +90,7 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup, def get_prefix_cache_hit_rate(self, device: Device) -> float: return -1 - def reset_prefix_cache(self, device: Device): + def reset_prefix_cache(self): pass def get_num_cached_tokens(self, seq: Sequence) -> int: diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 
3ec28bc8fb5d..1acd22e3daf0 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -290,18 +290,21 @@ def reset_prefix_cache(self): flows to invalid prefix caching after the weights are updated, or used for resetting prefix caching status for benchmarking. """ - num_used_blocks = (self.num_gpu_blocks - self.free_block_queue.num_free_blocks) + num_used_blocks = (self.num_gpu_blocks - + self.free_block_queue.num_free_blocks) if num_used_blocks > 0: - raise RuntimeError("Failed to reset prefix cache because some " - f"blocks ({num_used_blocks}) are not freed yet") + logger.warning( + "Failed to reset prefix cache because some " + "blocks (%d) are not freed yet", num_used_blocks) + return # Remove all hashes so that no new blocks will hit. - self.cached_block_hash_to_block = {} + self.cached_block_hash_to_block = defaultdict(dict) + # Remove all hashes from all blocks. for block in self.block_pool: block.reset_hash() - def get_num_common_prefix_blocks( self, request: Request, diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 6d90c38c72cf..abe4952c4baf 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -66,6 +66,11 @@ class EngineCoreProfile: is_start: bool +@dataclass +class EngineCoreResetPrefixCache: + pass + + class EngineCoreRequestType(enum.Enum): """ Request types defined as hex byte strings, so it can be sent over sockets @@ -74,6 +79,8 @@ class EngineCoreRequestType(enum.Enum): ADD = b'\x00' ABORT = b'\x01' PROFILE = b'\x02' + RESET_PREFIX_CACHE = b'\x03' -EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, List[str]] +EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, + EngineCoreResetPrefixCache, List[str]] diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 2afe7b197220..cf94033a38d9 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -20,7 +20,7 @@ from vllm.v1.core.scheduler import Scheduler from vllm.v1.engine 
import (EngineCoreOutputs, EngineCoreProfile, EngineCoreRequest, EngineCoreRequestType, - EngineCoreRequestUnion) + EngineCoreRequestUnion, EngineCoreResetPrefixCache) from vllm.v1.engine.mm_input_mapper import MMInputMapperServer from vllm.v1.executor.abstract import Executor from vllm.v1.request import Request, RequestStatus @@ -250,6 +250,8 @@ def _handle_client_request(self, request: EngineCoreRequestUnion) -> None: self.add_request(request) elif isinstance(request, EngineCoreProfile): self.model_executor.profile(request.is_start) + elif isinstance(request, EngineCoreResetPrefixCache): + self.reset_prefix_cache() else: # TODO: make an EngineCoreAbort wrapper assert isinstance(request, list) @@ -274,7 +276,9 @@ def process_input_socket(self, input_path: str): request = decoder_add_req.decode(request_data) elif request_type == EngineCoreRequestType.ABORT.value: request = decoder_abort_req.decode(request_data) - elif request_type == EngineCoreRequestType.PROFILE.value: + elif request_type in ( + EngineCoreRequestType.PROFILE.value, + EngineCoreRequestType.RESET_PREFIX_CACHE.value): request = pickle.loads(request_data) else: raise ValueError(f"Unknown RequestType: {request_type}") diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index a01dc73fe715..19b89003cc69 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -14,7 +14,7 @@ make_zmq_socket) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile, EngineCoreRequest, EngineCoreRequestType, - EngineCoreRequestUnion) + EngineCoreRequestUnion, EngineCoreResetPrefixCache) from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import PickleEncoder @@ -69,6 +69,9 @@ def add_request(self, request: EngineCoreRequest) -> None: def profile(self, is_start: bool = True) -> None: raise NotImplementedError + def reset_prefix_cache(self) -> None: + raise NotImplementedError + def 
abort_requests(self, request_ids: List[str]) -> None: raise NotImplementedError @@ -81,6 +84,9 @@ async def add_request_async(self, request: EngineCoreRequest) -> None: async def profile_async(self, is_start: bool = True) -> None: raise NotImplementedError + async def reset_prefix_cache_async(self) -> None: + raise NotImplementedError + async def abort_requests_async(self, request_ids: List[str]) -> None: raise NotImplementedError @@ -232,6 +238,10 @@ def profile(self, is_start: bool = True) -> None: self._send_input(EngineCoreRequestType.PROFILE, EngineCoreProfile(is_start)) + def reset_prefix_cache(self) -> None: + self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, + EngineCoreResetPrefixCache()) + class AsyncMPClient(MPClient): """Asyncio-compatible client for multi-proc EngineCore.""" @@ -269,3 +279,7 @@ async def abort_requests_async(self, request_ids: List[str]) -> None: async def profile_async(self, is_start: bool = True) -> None: await self._send_input(EngineCoreRequestType.PROFILE, EngineCoreProfile(is_start)) + + async def reset_prefix_cache_async(self) -> None: + await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, + EngineCoreResetPrefixCache()) From 714423f59355c086f8e7836fe7eb9bcf9b3a3e39 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 21 Jan 2025 21:32:21 +0000 Subject: [PATCH 3/6] test Signed-off-by: Cody Yu --- tests/core/block/test_prefix_caching_block.py | 38 ++++++++++++++++++ tests/v1/core/test_prefix_caching.py | 39 +++++++++++++++++++ vllm/core/block/cpu_gpu_block_allocator.py | 6 ++- vllm/core/block/interfaces.py | 4 +- vllm/core/block/naive_block.py | 26 ++++++++----- vllm/core/block/prefix_caching_block.py | 28 +++++++++++-- vllm/core/block_manager.py | 4 +- vllm/core/interfaces.py | 2 +- vllm/core/placeholder_block_space_manager.py | 4 +- vllm/core/scheduler.py | 4 +- vllm/engine/async_llm_engine.py | 3 ++ vllm/engine/llm_engine.py | 7 +++- vllm/engine/multiprocessing/__init__.py | 7 +++- 
vllm/engine/multiprocessing/client.py | 12 +++++- vllm/engine/multiprocessing/engine.py | 10 ++++- vllm/engine/protocol.py | 5 +++ vllm/entrypoints/llm.py | 4 +- vllm/entrypoints/openai/api_server.py | 11 ++++++ vllm/v1/core/kv_cache_manager.py | 11 +++++- vllm/v1/core/scheduler.py | 4 +- vllm/v1/engine/async_llm.py | 3 ++ 21 files changed, 196 insertions(+), 36 deletions(-) diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index 29ac3a3c86cb..6642174c17d8 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -796,6 +796,44 @@ def test_find_cached_blocks_prefix(): block_hashes=block_hashes_seq1) assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks + # Test reset prefix cache + @staticmethod + @pytest.mark.parametrize("num_blocks", [10]) + @pytest.mark.parametrize("block_size", [16]) + def test_reset_prefix_cache(num_blocks: int, block_size: int): + """This test case simulates the case of resetting the prefix cache.""" + + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, + block_size=block_size) + token_ids = list(range(3 * block_size)) + + first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids, + allocator=allocator, + ) + second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids, + allocator=allocator, + ) + + # Free each block in the first chain. + for block in first_chain: + allocator.free(block) + + # Failed to reset prefix cache because some blocks are not freed yet. + assert not allocator.reset_prefix_cache() + assert allocator.get_prefix_cache_hit_rate() > 0.0 + + # Free each block in the second chain. + for block in second_chain: + allocator.free(block) + + # Reset prefix cache. 
+ assert allocator.reset_prefix_cache() + assert allocator.get_prefix_cache_hit_rate() == 0.0 + @staticmethod def create_immutable_chain( block_size: int, diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index fafd9d0ce445..c5860809f9e6 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -587,3 +587,42 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): assert {block.ref_cnt for block in block_part1[:3]} == {1} # Block 3-5 are free. assert {block.ref_cnt for block in block_part1[3:]} == {0} + + +def test_reset_prefix_cache(): + manager = KVCacheManager( + block_size=16, + num_gpu_blocks=10, + max_model_len=8192, + sliding_window=None, + enable_caching=True, + num_preallocate_tokens=0, + ) + + full_block_token_ids = [i for i in range(3) for _ in range(16)] + unique_token_ids = [3] * 7 + all_token_ids = full_block_token_ids + unique_token_ids + req0 = make_request("0", all_token_ids) + blocks = manager.allocate_slots(req0, 55, []) + assert [b.block_id for b in blocks] == [0, 1, 2, 3] + + unique_token_ids = [4] * 7 + all_token_ids = full_block_token_ids + unique_token_ids + req1 = make_request("1", all_token_ids) + computed_blocks, _ = manager.get_computed_blocks(req1) + assert len(req1.kv_block_hashes) == 3 + assert len(computed_blocks) == 3 + blocks = manager.allocate_slots(req1, 7, computed_blocks) + assert [b.block_id for b in blocks] == [4] + + # Failed to reset prefix cache because some blocks are not freed yet. + assert not manager.reset_prefix_cache() + assert manager.cached_block_hash_to_block + + # Free the blocks. 
+ manager.free(req0) + manager.free(req1) + + assert manager.reset_prefix_cache() + assert not manager.cached_block_hash_to_block + assert all([blk.block_hash is None for blk in manager.block_pool]) diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index bbc67b3b37de..c3e1665b4464 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -339,10 +339,12 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float: assert device in self._allocators return self._allocators[device].get_prefix_cache_hit_rate() - def reset_prefix_cache(self): + def reset_prefix_cache(self) -> bool: """Reset prefix cache for all devices.""" + success = True for allocator in self._allocators.values(): - allocator.reset_prefix_cache() + success = success and allocator.reset_prefix_cache() + return success def get_and_reset_swaps(self) -> List[Tuple[int, int]]: """Returns and clears the mapping of source to destination block IDs. 
diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 00c70e0a4140..cb432db919c7 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -193,7 +193,7 @@ def get_prefix_cache_hit_rate(self) -> float: pass @abstractmethod - def reset_prefix_cache(self): + def reset_prefix_cache(self) -> bool: """Reset prefix cache.""" pass @@ -303,7 +303,7 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float: pass @abstractmethod - def reset_prefix_cache(self): + def reset_prefix_cache(self) -> bool: """Reset prefix cache.""" pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 74f4825a9e09..46c9c4ac2662 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -1,12 +1,15 @@ from collections import deque -from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple +from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter, get_all_blocks_recursively) from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device +from vllm.logger import init_logger Refcount = int +logger = init_logger(__name__) + class NaiveBlockAllocator(BlockAllocator): """A simple block allocator that manages blocks of memory without prefix @@ -136,16 +139,18 @@ def _allocate_block_id(self) -> BlockId: self._refcounter.incr(block_id) return block_id - def _free_block_id(self, block: Block) -> None: - block_id = block.block_id + def _free_block_id(self, block: Union[Block, BlockId]) -> None: + if isinstance(block, Block): + block_id = block.block_id + block.block_id = None + else: + block_id = block assert block_id is not None refcount = self._refcounter.decr(block_id) if refcount == 0: self._free_block_indices.appendleft(block_id) - block.block_id = None - def free(self, block: Block, keep_block_object: bool = False) -> None: # Release the physical block id 
self._free_block_id(block) @@ -154,6 +159,9 @@ def free(self, block: Block, keep_block_object: bool = False) -> None: if not keep_block_object: self._block_pool.free_block(block) + def free_block_id(self, block_id: BlockId) -> None: + self._free_block_id(block_id) + def fork(self, last_block: Block) -> List[Block]: """Creates a new sequence of blocks that shares the same underlying memory as the original sequence. @@ -324,10 +332,10 @@ def swap_in(self, blocks: List[Block]) -> None: def get_prefix_cache_hit_rate(self) -> float: return -1 - - def reset_prefix_cache(self): - """No effect.""" - pass + + def reset_prefix_cache(self) -> bool: + """No prefix cache for naive block allocator.""" + return True def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]: # Not applicable for naive block allocator. diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 3f7c75ca2a1f..ccdc5daa9595 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -432,22 +432,44 @@ def all_block_ids(self) -> FrozenSet[int]: def get_prefix_cache_hit_rate(self) -> float: return self.metric_data.get_hit_rate() - def reset_prefix_cache(self): - """Reset prefix cache.""" + def reset_prefix_cache(self) -> bool: + """Reset prefix cache. This function may be used in RLHF + flows to invalid prefix caching after the weights are updated, + or used for resetting prefix caching status for benchmarking. + + Returns: + bool: True if the prefix cache is successfully reset, + False otherwise. + """ num_used_blocks = (self.get_num_total_blocks() - self.get_num_free_blocks()) if num_used_blocks > 0: logger.warning( "Failed to reset prefix cache because some " "blocks (%d) are not freed yet", num_used_blocks) - return + return False + + # Free all blocks in the evictor. 
+ while (block_id := + self._maybe_allocate_evicted_block_id()) is not None: + self._hashless_allocator.free_block_id(block_id) + + # Should not have any cached blocks because all blocks are evicted. + assert not self._cached_blocks # Reset the evictor. self.evictor = make_evictor(self.eviction_policy) + # Reset the block tracker. + for block_id in self._block_tracker: + self._block_tracker[block_id] = BlockTracker() + # Reset the metrics. self.metric_data = CacheMetricData() + logger.info("Successfully reset prefix cache") + return True + def is_block_cached(self, block: Block) -> bool: assert block.content_hash is not None return block.content_hash in self._cached_blocks diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index ab013cf9ce3e..62a5f0bda061 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -455,8 +455,8 @@ def get_num_free_cpu_blocks(self) -> int: def get_prefix_cache_hit_rate(self, device: Device) -> float: return self.block_allocator.get_prefix_cache_hit_rate(device) - def reset_prefix_cache(self): - self.block_allocator.reset_prefix_cache() + def reset_prefix_cache(self) -> bool: + return self.block_allocator.reset_prefix_cache() def _can_swap(self, seq_group: SequenceGroup, diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 7fced2198861..9c7e246e3c4e 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -123,7 +123,7 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float: pass @abstractmethod - def reset_prefix_cache(self): + def reset_prefix_cache(self) -> bool: """Reset prefix cache for all devices.""" pass diff --git a/vllm/core/placeholder_block_space_manager.py b/vllm/core/placeholder_block_space_manager.py index e841065accc6..f9924be4a383 100644 --- a/vllm/core/placeholder_block_space_manager.py +++ b/vllm/core/placeholder_block_space_manager.py @@ -90,8 +90,8 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup, def 
get_prefix_cache_hit_rate(self, device: Device) -> float: return -1 - def reset_prefix_cache(self): - pass + def reset_prefix_cache(self) -> bool: + return True def get_num_cached_tokens(self, seq: Sequence) -> int: return 0 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index def697511608..b1630b34947b 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -504,8 +504,8 @@ def has_unfinished_seqs(self) -> bool: def get_prefix_cache_hit_rate(self, device: Device) -> float: return self.block_manager.get_prefix_cache_hit_rate(device) - def reset_prefix_cache(self): - self.block_manager.reset_prefix_cache() + def reset_prefix_cache(self) -> bool: + return self.block_manager.reset_prefix_cache() def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 08fef8250d48..739ea06ae381 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1182,6 +1182,9 @@ async def start_profile(self) -> None: async def stop_profile(self) -> None: self.engine.stop_profile() + async def reset_prefix_cache(self) -> None: + self.engine.reset_prefix_cache() + async def add_lora(self, lora_request: LoRARequest) -> None: self.engine.add_lora(lora_request) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 148a6e9ce4c3..af1eae1a6afd 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -914,10 +914,13 @@ def has_unfinished_requests_for_virtual_engine( """ return self.scheduler[virtual_engine].has_unfinished_seqs() - def reset_prefix_cache(self): + def reset_prefix_cache(self) -> bool: """Reset prefix cache for all devices.""" + + success = True for scheduler in self.scheduler: - scheduler.reset_prefix_cache() + success = success and scheduler.reset_prefix_cache() + return success @staticmethod def _process_sequence_group_outputs( diff --git 
a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 7132f9840001..d9703b820a77 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -121,6 +121,10 @@ class RPCUProfileRequest(Enum): STOP_PROFILE = 2 +class RPCResetPrefixCacheRequest(Enum): + RESET_PREFIX_CACHE = 1 + + @dataclass class RPCLoadAdapterRequest: lora_request: LoRARequest @@ -134,7 +138,8 @@ class RPCAdapterLoadedResponse: RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest, - RPCUProfileRequest, RPCLoadAdapterRequest] + RPCUProfileRequest, RPCLoadAdapterRequest, + RPCResetPrefixCacheRequest] REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse, RPCError] diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index a9ab89953518..5dd2eca16253 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -27,8 +27,9 @@ VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCAdapterLoadedResponse, RPCError, RPCLoadAdapterRequest, - RPCProcessRequest, RPCStartupRequest, - RPCStartupResponse, + RPCProcessRequest, + RPCResetPrefixCacheRequest, + RPCStartupRequest, RPCStartupResponse, RPCUProfileRequest) from vllm.engine.protocol import EngineClient # yapf: enable @@ -667,6 +668,13 @@ async def stop_profile(self) -> None: await self._send_one_way_rpc_request( request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket) + async def reset_prefix_cache(self) -> None: + """Reset the prefix cache""" + + await self._send_one_way_rpc_request( + request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE, + socket=self.input_socket) + async def add_lora(self, lora_request: LoRARequest) -> None: """Load a new LoRA adapter into the engine for future requests.""" # Uses the same I/O as generate requests diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 3aa9d30549f3..166f89743b3c 100644 --- 
a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -16,8 +16,9 @@ VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCAdapterLoadedResponse, RPCError, RPCLoadAdapterRequest, - RPCProcessRequest, RPCStartupRequest, - RPCStartupResponse, + RPCProcessRequest, + RPCResetPrefixCacheRequest, + RPCStartupRequest, RPCStartupResponse, RPCUProfileRequest) # yapf: enable from vllm.logger import init_logger @@ -237,6 +238,8 @@ def handle_new_input(self): self.stop_profile() elif isinstance(request, RPCLoadAdapterRequest): self._handle_load_adapter_request(request) + elif isinstance(request, RPCResetPrefixCacheRequest): + self.reset_prefix_cache() else: raise ValueError("Unknown RPCRequest Type: " f"{type(request)}") @@ -361,6 +364,9 @@ def start_profile(self) -> None: def stop_profile(self) -> None: self.engine.stop_profile() + def reset_prefix_cache(self) -> bool: + return self.engine.reset_prefix_cache() + def signal_handler(*_) -> None: raise KeyboardInterrupt("MQLLMEngine terminated") diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index f05ff62c4766..de7b2c1b91f5 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -271,6 +271,11 @@ async def stop_profile(self) -> None: """Start profiling the engine""" ... + @abstractmethod + async def reset_prefix_cache(self) -> None: + """Reset the prefix cache""" + ... 
+ @abstractmethod async def add_lora(self, lora_request: LoRARequest) -> None: """Load a new LoRA adapter into the engine for future requests.""" diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index fe622d9ab236..0681f57ec22c 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1132,8 +1132,8 @@ def start_profile(self) -> None: def stop_profile(self) -> None: self.llm_engine.stop_profile() - def reset_prefix_cache(self) -> None: - self.llm_engine.reset_prefix_cache() + def reset_prefix_cache(self) -> bool: + return self.llm_engine.reset_prefix_cache() # LEGACY def _convert_v1_inputs( diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 1aeefe86cd05..7015c610ad85 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -519,6 +519,17 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): } +@router.post("/reset_prefix_cache") +async def reset_prefix_cache(raw_request: Request): + """ + Reset the prefix cache. Note that we currently do not check if the + prefix cache is successfully reset in the API server. + """ + logger.info("Resetting prefix cache...") + await engine_client(raw_request).reset_prefix_cache() + return Response(status_code=200) + + @router.post("/invocations") async def invocations(raw_request: Request): """ diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 1acd22e3daf0..8c8c8b3b55c0 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -285,10 +285,14 @@ def free(self, request: Request) -> None: if block.ref_cnt == 0: self.free_block_queue.append(block) - def reset_prefix_cache(self): + def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF flows to invalidate prefix caching after the weights are updated, or used for resetting prefix caching status for benchmarking.
+ + Returns: + bool: True if the prefix cache is successfully reset, + False otherwise. """ num_used_blocks = (self.num_gpu_blocks - self.free_block_queue.num_free_blocks) @@ -296,7 +300,7 @@ def reset_prefix_cache(self): logger.warning( "Failed to reset prefix cache because some " "blocks (%d) are not freed yet", num_used_blocks) - return + return False # Remove all hashes so that no new blocks will hit. self.cached_block_hash_to_block = defaultdict(dict) @@ -305,6 +309,9 @@ def reset_prefix_cache(self): for block in self.block_pool: block.reset_hash() + logger.info("Successfully reset prefix cache") + return True + def get_num_common_prefix_blocks( self, request: Request, diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 870405a3355a..8ded5e578713 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -529,8 +529,8 @@ def get_num_unfinished_requests(self) -> int: def has_unfinished_requests(self) -> bool: return self.get_num_unfinished_requests() > 0 - def reset_prefix_cache(self): - self.kv_cache_manager.reset_prefix_cache() + def reset_prefix_cache(self) -> bool: + return self.kv_cache_manager.reset_prefix_cache() def make_stats(self) -> SchedulerStats: return SchedulerStats( diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index a74699f7513e..b4d3e441173d 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -321,6 +321,9 @@ async def start_profile(self) -> None: async def stop_profile(self) -> None: await self.engine_core.profile_async(False) + async def reset_prefix_cache(self) -> None: + await self.engine_core.reset_prefix_cache_async() + @property def is_running(self) -> bool: return True From 221897c5063b11b66f89de378e94fbf5fb95fecf Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 21 Jan 2025 22:38:57 +0000 Subject: [PATCH 4/6] revert Signed-off-by: Cody Yu --- vllm/core/block/naive_block.py | 3 --- 1 file changed, 3 deletions(-) diff --git 
a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 46c9c4ac2662..c38ae2dd6761 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -4,12 +4,9 @@ from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter, get_all_blocks_recursively) from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device -from vllm.logger import init_logger Refcount = int -logger = init_logger(__name__) - class NaiveBlockAllocator(BlockAllocator): """A simple block allocator that manages blocks of memory without prefix From 6e284e43142dce3c4a149aac390daca18511ce8e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 23 Jan 2025 00:39:56 +0800 Subject: [PATCH 5/6] automatically reset prefix cache when sleeping Signed-off-by: youkaichao --- vllm/entrypoints/llm.py | 1 + vllm/executor/executor_base.py | 5 ----- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 73b994a9dbef..563031cfadc4 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1153,6 +1153,7 @@ def sleep(self, level: int = 1): where previous model weights are not needed. It reduces CPU memory pressure. 
""" + self.reset_prefix_cache() self.llm_engine.sleep(level=level) def wake_up(self): diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 069be05eafa5..6be62d406857 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -194,11 +194,6 @@ def stop_profile(self) -> None: self.collective_rpc("stop_profile") def sleep(self, level: int = 1): - if self.cache_config.enable_prefix_caching: - # TODO: support sleep with prefix caching - # by resetting the prefix cache state, - # after https://github.com/vllm-project/vllm/pull/12284 - raise ValueError("Cannot sleep when prefix caching is enabled.") self.collective_rpc("sleep", kwargs=dict(level=level)) def wake_up(self): From db8cdc1c9c2df1d8b9cac319a759554d77a7a7f1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 23 Jan 2025 00:43:09 +0800 Subject: [PATCH 6/6] add env var Signed-off-by: youkaichao --- vllm/entrypoints/openai/api_server.py | 21 +++++++++++---------- vllm/envs.py | 7 +++++++ 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 7015c610ad85..9bb11907f740 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -518,16 +518,17 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): }, } - -@router.post("/reset_prefix_cache") -async def reset_prefix_cache(raw_request: Request): - """ - Reset the prefix cache. Note that we currently do not check if the - prefix cache is successfully reset in the API server. - """ - logger.info("Resetting prefix cache...") - await engine_client(raw_request).reset_prefix_cache() - return Response(status_code=200) +if envs.VLLM_SERVER_DEV_MODE: + + @router.post("/reset_prefix_cache") + async def reset_prefix_cache(raw_request: Request): + """ + Reset the prefix cache. Note that we currently do not check if the + prefix cache is successfully reset in the API server. 
+ """ + logger.info("Resetting prefix cache...") + await engine_client(raw_request).reset_prefix_cache() + return Response(status_code=200) @router.post("/invocations") diff --git a/vllm/envs.py b/vllm/envs.py index b7b597ea15af..1e68326b2d90 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -72,6 +72,7 @@ VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_DISABLE_COMPILE_CACHE: bool = False + VLLM_SERVER_DEV_MODE: bool = False def get_default_cache_root(): @@ -467,6 +468,12 @@ def get_default_config_root(): lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")), "VLLM_DISABLE_COMPILE_CACHE": lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))), + + # If set, vllm will run in development mode, which will enable + # some additional endpoints for developing and debugging, + # e.g. `/reset_prefix_cache` + "VLLM_SERVER_DEV_MODE": + lambda: bool(int(os.getenv("VLLM_SERVER_DEV_MODE", "0"))), } # end-env-vars-definition