From 41d38c1d948fcc742ebd9796496b433f29061248 Mon Sep 17 00:00:00 2001 From: Shaoting Feng Date: Fri, 24 Jan 2025 19:14:51 +0000 Subject: [PATCH 1/4] Fix format Signed-off-by: Shaoting Feng --- .../disaggregated_prefill.py | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 examples/offline_inference/disaggregated_prefill.py diff --git a/examples/offline_inference/disaggregated_prefill.py b/examples/offline_inference/disaggregated_prefill.py new file mode 100644 index 000000000000..bf8778ecde49 --- /dev/null +++ b/examples/offline_inference/disaggregated_prefill.py @@ -0,0 +1,80 @@ +import os +import time +from multiprocessing import Event, Process + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + + +def run_prefill(prefill_done): + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + + prompts = [ + "Hello, my name is", + # "Hi, your name is", # To simulate transmission failure + "Tell me a very long story", + ] + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) + + ktc = KVTransferConfig.from_cli( + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' + ) + llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", + kv_transfer_config=ktc, + max_model_len=2000, + gpu_memory_utilization=0.8) + + llm.generate(prompts, sampling_params) + print("Prefill node is finished.") + prefill_done.set() + + # To keep the prefill node running in case the decode node is not done + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + print("Script stopped by user.") + + +def run_decode(prefill_done): + os.environ["CUDA_VISIBLE_DEVICES"] = "1" + + prompts = [ + "Hello, my name is", + "Hi, your name is", + "Tell me a very long story", + ] + sampling_params = SamplingParams(temperature=0, top_p=0.95) + + ktc = KVTransferConfig.from_cli( + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' + ) + llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", + kv_transfer_config=ktc, + max_model_len=2000, + gpu_memory_utilization=0.8) + + # Wait for the producer to start the pipe + print("Waiting for prefill node to finish...") + prefill_done.wait() + + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +if __name__ == "__main__": + prefill_done = Event() + process_a = Process(target=run_prefill, args=(prefill_done, )) + process_b = Process(target=run_decode, args=(prefill_done, )) + + # Start prefill node + process_a.start() + + # Start decode node + process_b.start() + + process_b.join() + process_a.terminate() From cb3d827d150de469a5dc7a8f755c7351c35c7a43 Mon Sep 17 00:00:00 2001 From: Shaoting Feng Date: Mon, 27 Jan 2025 19:44:06 +0000 Subject: [PATCH 2/4] Add comments and explanations Signed-off-by: Shaoting Feng --- .../disaggregated_prefill.py | 32 +++++++++++++++---- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/examples/offline_inference/disaggregated_prefill.py b/examples/offline_inference/disaggregated_prefill.py index bf8778ecde49..f9daf9f285f9 100644 --- a/examples/offline_inference/disaggregated_prefill.py +++ b/examples/offline_inference/disaggregated_prefill.py @@ -1,3 +1,10 @@ +""" +This file demonstrates the example usage of disaggregated prefilling +We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode), +and then transfer the KV cache 
between them. + +Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html +""" import os import time from multiprocessing import Event, Process @@ -7,11 +14,16 @@ def run_prefill(prefill_done): + # We use GPU 0 for prefill node. os.environ["CUDA_VISIBLE_DEVICES"] = "0" + # The prefill node receives two requests, while the decode node receives + # three requests. So the decode node will only receive the KV Cache for + # requests 1 and 3. The decode node will use the KV Cache of requests 1 + # and 3 and do prefilling on request 2. prompts = [ "Hello, my name is", - # "Hi, your name is", # To simulate transmission failure + # "Hi, your name is", # To trigger partial prefill of batched requests "Tell me a very long story", ] sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) @@ -19,6 +31,8 @@ def run_prefill(prefill_done): ktc = KVTransferConfig.from_cli( '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' ) + # Example: Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB + # memory. Reduce the value if your GPU has less memory. llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", kv_transfer_config=ktc, max_model_len=2000, @@ -37,6 +51,7 @@ def run_prefill(prefill_done): def run_decode(prefill_done): + # We use GPU 1 for decode node. os.environ["CUDA_VISIBLE_DEVICES"] = "1" prompts = [ @@ -49,6 +64,8 @@ def run_decode(prefill_done): ktc = KVTransferConfig.from_cli( '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' ) + # Example: Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB + # of memory. Reduce the value if your GPU has less memory. llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", kv_transfer_config=ktc, max_model_len=2000, @@ -67,14 +84,15 @@ def run_decode(prefill_done): if __name__ == "__main__": prefill_done = Event() - process_a = Process(target=run_prefill, args=(prefill_done, )) - process_b = Process(target=run_decode, args=(prefill_done, )) + prefill_process = Process(target=run_prefill, args=(prefill_done, )) + decode_process = Process(target=run_decode, args=(prefill_done, )) # Start prefill node - process_a.start() + prefill_process.start() # Start decode node - process_b.start() + decode_process.start() - process_b.join() - process_a.terminate() + # Terminate the prefill node when decode is finished + decode_process.join() + prefill_process.terminate() From ee0c687e315ba1604401035f10d6239190a0223d Mon Sep 17 00:00:00 2001 From: Shaoting Date: Sat, 8 Feb 2025 00:18:21 -0600 Subject: [PATCH 3/4] Add comments for better clarification Signed-off-by: Shaoting --- .../disaggregated_prefill.py | 28 +++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/examples/offline_inference/disaggregated_prefill.py b/examples/offline_inference/disaggregated_prefill.py index f9daf9f285f9..f14a6a70d5c8 100644 --- a/examples/offline_inference/disaggregated_prefill.py +++ b/examples/offline_inference/disaggregated_prefill.py @@ -2,8 +2,6 @@ This file demonstrates the example usage of disaggregated prefilling We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode), and then transfer the KV cache between them. - -Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html """ import os import time @@ -23,16 +21,22 @@ def run_prefill(prefill_done): # and 3 and do prefilling on request 2. 
prompts = [ "Hello, my name is", - # "Hi, your name is", # To trigger partial prefill of batched requests + # "Hi, your name is", + # The decode node will actually "prefill" this request. "Tell me a very long story", ] sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) + # Using PyNcclConnector to transmit KV caches between vLLM instances. + # This instance is the prefill node (kv_producer, rank 0). + # The number of parallel instances for KV cache transfer is set to 2, + # as required for PyNcclConnector. ktc = KVTransferConfig.from_cli( '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' ) - # Example: Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB - # memory. Reduce the value if your GPU has less memory. + + # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB + # memory. You may need to adjust the value to fit your GPU. llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", kv_transfer_config=ktc, max_model_len=2000, @@ -42,7 +46,8 @@ def run_prefill(prefill_done): print("Prefill node is finished.") prefill_done.set() - # To keep the prefill node running in case the decode node is not done + # To keep the prefill node running in case the decode node is not done; + # otherwise, the script might exit prematurely, causing incomplete decoding. try: while True: time.sleep(1) @@ -61,11 +66,16 @@ def run_decode(prefill_done): ] sampling_params = SamplingParams(temperature=0, top_p=0.95) + # Using PyNcclConnector to transmit KV caches between vLLM instances. + # This instance is the decode node (kv_consumer, rank 1). + # The number of parallel instances for KV cache transfer is set to 2, + # as required for PyNcclConnector. ktc = KVTransferConfig.from_cli( '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' ) - # Example: Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB - # of memory. Reduce the value if your GPU has less memory. + + # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB + # memory. You may need to adjust the value to fit your GPU. llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", kv_transfer_config=ktc, max_model_len=2000, @@ -75,6 +85,8 @@ def run_decode(prefill_done): print("Waiting for prefill node to finish...") prefill_done.wait() + # At this point when the prefill_done is set, the kv-cache should have been + # transferred to this decode node, so we can start decoding. outputs = llm.generate(prompts, sampling_params) for output in outputs: prompt = output.prompt From 57685b8349fb43211481c4a0472ee42d3ba1c107 Mon Sep 17 00:00:00 2001 From: Shaoting Date: Sat, 8 Feb 2025 00:26:29 -0600 Subject: [PATCH 4/4] Add Apache-2.0 Signed-off-by: Shaoting --- examples/offline_inference/disaggregated_prefill.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/offline_inference/disaggregated_prefill.py b/examples/offline_inference/disaggregated_prefill.py index f14a6a70d5c8..2e41cabaccaf 100644 --- a/examples/offline_inference/disaggregated_prefill.py +++ b/examples/offline_inference/disaggregated_prefill.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 """ This file demonstrates the example usage of disaggregated prefilling We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode),
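
A note on the construction pattern in the patches above: run_prefill() and run_decode() build their engines identically except for the kv_role/kv_rank fields of the connector config. The sketch below is not part of the patch series; the helper name build_llm is hypothetical, and the only vLLM calls it relies on are the KVTransferConfig.from_cli and LLM constructors that already appear in the diff.

import json

from vllm import LLM
from vllm.config import KVTransferConfig


def build_llm(kv_role: str, kv_rank: int) -> LLM:
    # kv_parallel_size=2 matches the two-instance PyNcclConnector setup
    # (one kv_producer, one kv_consumer) used by the example script.
    ktc = KVTransferConfig.from_cli(
        json.dumps({
            "kv_connector": "PyNcclConnector",
            "kv_role": kv_role,
            "kv_rank": kv_rank,
            "kv_parallel_size": 2,
        }))
    # Same engine settings as the example; adjust gpu_memory_utilization
    # and max_model_len to fit your GPU.
    return LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct",
               kv_transfer_config=ktc,
               max_model_len=2000,
               gpu_memory_utilization=0.8)

With such a helper, the prefill worker would call build_llm("kv_producer", 0) and the decode worker build_llm("kv_consumer", 1), keeping the producer/consumer pairing and kv_parallel_size=2 in one place.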