From 5a9e3a4b3185204179c3448188017bd65d400788 Mon Sep 17 00:00:00 2001
From: ApostaC
Date: Thu, 26 Dec 2024 23:04:35 +0000
Subject: [PATCH 1/3] [Add] benchmark script for CPU offloading (long document QA use case)

Signed-off-by: ApostaC
Co-authored-by: KuntaiDu
---
 benchmarks/benchmark_long_document_qa.py | 258 +++++++++++++++++++++++
 1 file changed, 258 insertions(+)
 create mode 100644 benchmarks/benchmark_long_document_qa.py

diff --git a/benchmarks/benchmark_long_document_qa.py b/benchmarks/benchmark_long_document_qa.py
new file mode 100644
index 000000000000..8d4425b6fb8a
--- /dev/null
+++ b/benchmarks/benchmark_long_document_qa.py
@@ -0,0 +1,258 @@
+"""
+Benchmark the efficiency of prefix caching.
+
+This script allows you to benchmark the performance of
+a model with prefix-caching or cpu-offloading using fixed prompts
+
+Fixed example usage:
+    # This command run the vllm with 50GB CPU memory for offloading
+    # The workload samples 8 different prompts with a default input
+    # length of 20010 tokens, then replicates each prompt 2 times.
+    python benchmark_long_document_qa.py \
+        --model meta-llama/Llama-2-7b-chat-hf \
+        --enable-prefix-caching \
+        --block-allocator CpuOffloadingBlockAllocator \
+        --num-documents 8 \
+        --repeat-count 2 \
+        --cpu-memory-gb 50
+
+Commandline arguments:
+
+    # Basic arguments
+    --model: The model to use for the benchmark.
+
+    --enable-prefix-caching: Enable prefix caching or not.
+
+    --block-allocator: The block allocator that vLLM uses.
+        - CpuGpuBlockAllocator: The default block allocator.
+        - CpuOffloadingBlockAllocator: The block allocator that supports
+          cpu offloading
+
+    --gpu-memory-utilization: GPU memory utilization for vLLM.
+
+    --cpu-memory-gb: The amount of CPU memory (GB) that is used by vLLM.
+        NOTE: CPU memory should be larger than GPU KV cache size when
+        using CpuOffloadingBlockAllocator.
+
+    # Workload-related arguments
+    --num-documents: The number of documents to sample prompts from.
+
+    --repeat-count: The number of times to repeat each prompt.
+
+    # Other functionality
+    --seed: Random seed for reproducibility.
+
+    --profile-swap-blocks: Profile the swap_blocks function in the custom ops.
+"""
+
+import random
+import time
+
+import torch
+
+from vllm import LLM, SamplingParams
+from vllm.utils import FlexibleArgumentParser
+
+execution_times = {}
+
+
+def build_result_dict(start_time, end_time, *args):
+    total_time = end_time - start_time
+    length = -1
+    if len(args) > 1 and isinstance(args[1], torch.Tensor):
+        length = len(args[1])
+
+    return {
+        "start_time": start_time,
+        "total_time": total_time,
+        "swap_len": length
+    }
+
+
+def timing_decorator(func):
+
+    def wrapper(*args, **kwargs):
+        global execution_times
+        torch.cuda.synchronize()
+        start_time = time.time()  # Record the start time
+        result = func(*args, **kwargs)  # Call the wrapped function
+        torch.cuda.synchronize()
+        end_time = time.time()  # Record the end time
+        if func.__name__ not in execution_times:
+            execution_times[func.__name__] = []
+
+        res = build_result_dict(start_time, end_time, *args)
+        execution_times[func.__name__].append(res)
+        return result  # Return the result of the original function
+
+    return wrapper
+
+
+def process_timing_results():
+    global execution_times
+    for key in execution_times:
+        len_to_time = {}
+        len_to_count = {}
+        for item in execution_times[key]:
+            swap_len = item["swap_len"]
+            if swap_len not in len_to_time:
+                len_to_time[swap_len] = 0
+            len_to_time[swap_len] += item["total_time"]
+
+            if swap_len not in len_to_count:
+                len_to_count[swap_len] = 0
+            len_to_count[swap_len] += 1
+
+        for swap_len in len_to_time:
+            total_time = len_to_time[swap_len]
+            count = len_to_count[swap_len]
+            print(f"{key} on {swap_len} pages: "
+                  f"{(count * swap_len) / total_time} pages per second")
+
+
+def test_long_document_qa(llm=None, sampling_params=None, prompts=None):
+
+    start_time = time.time()
+    llm.generate(prompts, sampling_params=sampling_params)
+    end_time = time.time()
+    print(f"cost time {end_time - start_time}")
+
+
+def repeat_prompts(prompts, repeat_count):
+    repeated_prompts = prompts * repeat_count
+    random.shuffle(repeated_prompts)
+    return repeated_prompts
+
+
+def main(args):
+    if args.profile_swap_blocks:
+        from vllm.worker.cache_engine import CacheEngine
+        CacheEngine.swap_out = timing_decorator(CacheEngine.swap_out)
+        CacheEngine.swap_in = timing_decorator(CacheEngine.swap_in)
+
+    random.seed(args.seed)
+
+    # append the document id at the beginning to avoid any of the document
+    # being the prefix of other documents
+    prompts = [
+        str(i) + ' '.join(['hi'] * args.document_length)
+        for i in range(args.num_documents)
+    ]
+
+    preemption_mode = ""
+    if args.block_allocator == "CpuOffloadingBlockAllocator":
+        preemption_mode = "recompute"
+    else:
+        preemption_mode = "swap"
+
+    llm = LLM(model=args.model,
+              tokenizer_mode='auto',
+              trust_remote_code=True,
+              enforce_eager=True,
+              tensor_parallel_size=args.tensor_parallel_size,
+              enable_prefix_caching=args.enable_prefix_caching,
+              block_allocator=args.block_allocator,
+              preemption_mode=preemption_mode,
+              swap_space=args.cpu_memory_gb,
+              enable_chunked_prefill=False,
+              gpu_memory_utilization=args.gpu_memory_utilization,
+              max_model_len=30000)
+
+    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
+
+    prompts = repeat_prompts(prompts, args.repeat_count)
+
+    print("------warm up------")
+    test_long_document_qa(
+        llm=llm,
+        prompts=prompts,
+        sampling_params=sampling_params,
+    )
+
+    random.shuffle(prompts)
+
+    print("------start generating------")
+    test_long_document_qa(
+        llm=llm,
+        prompts=prompts,
+        sampling_params=sampling_params,
+    )
+
+    if args.profile_swap_blocks:
+        process_timing_results()
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser(
+        description=
+        'Benchmark the performance with or without automatic prefix caching.')
+    parser.add_argument(
+        '--model',
+        type=str,
+        # this test aims to test long document QA capability,
+        # so we use llama 3.1 8B as it can process long context
+        default='meta-llama/Llama-3.1-8B')
+    parser.add_argument("--dataset-path",
+                        type=str,
+                        default=None,
+                        help="Path to the dataset.")
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
+    parser.add_argument('--output-len', type=int, default=10)
+    parser.add_argument('--enable-prefix-caching',
+                        action='store_true',
+                        help='enable prefix caching')
+    parser.add_argument('--repeat-count',
+                        type=int,
+                        default=2,
+                        help='Number of times to repeat each prompt')
+    parser.add_argument(
+        '--document-length',
+        type=int,
+        # Roughly the number of tokens for a system paper,
+        # excluding images
+        default=20010,
+        help='Range of input lengths for sampling prompts,'
+        'specified as "min:max" (e.g., "128:256").')
+    parser.add_argument('--num-documents',
+                        type=int,
+                        default=8,
+                        help='Range of input lengths for sampling prompts,'
+                        'specified as "min:max" (e.g., "128:256").')
+    parser.add_argument("--seed",
+                        type=int,
+                        default=0,
+                        help='Random seed for reproducibility')
+    parser.add_argument('--gpu-memory-utilization',
+                        type=float,
+                        default=0.9,
+                        help='GPU memory utilization for vLLM. Should be a '
+                        'float point number ranging from 0 to 1. For this '
+                        'test please use a small value so that the GPU '
+                        'cannot hold all KV caches of all documents, '
+                        'and the effect of CPU offloading can be tested.')
+    parser.add_argument(
+        '--cpu-memory-gb',
+        type=float,
+        default=1,
+        help="The amount of CPU memory (GB) that is used by vLLM. Not very "
+        "useful for CpuGpuBlockAllocator, but useful for "
+        "CpuOffloadingBlockAllocator to have more CPU KV cache space")
+    parser.add_argument(
+        '--block-allocator',
+        type=str,
+        default='CpuGpuBlockAllocator',
+        choices=['CpuGpuBlockAllocator', 'CpuOffloadingBlockAllocator'],
+        help='The block allocator that vLLM uses. Currently'
+        ' can be CpuGpuBlockAllocator (the default) and '
+        'CpuOffloadingBlockAllocator (experimental) that '
+        'supports offloading the KV cache to CPU . '
+        'When using CpuOffloadingBlockAllocator, the '
+        'preemption mode must be recompute.')
+
+    parser.add_argument(
+        '--profile-swap-blocks',
+        action='store_true',
+        help='Profile the swap_blocks function in the custom ops')
+
+    args = parser.parse_args()
+    main(args)

From 46cbd938e8a926aa918a3c21a9fbf7e3abfcad9d Mon Sep 17 00:00:00 2001
From: ApostaC
Date: Tue, 31 Dec 2024 20:38:32 +0000
Subject: [PATCH 2/3] Simplify the script and address the reviews

Signed-off-by: ApostaC
---
 benchmarks/benchmark_long_document_qa.py      | 258 ------------------
 .../benchmark_long_document_qa_throughput.py  | 169 ++++++++++++
 2 files changed, 169 insertions(+), 258 deletions(-)
 delete mode 100644 benchmarks/benchmark_long_document_qa.py
 create mode 100644 benchmarks/benchmark_long_document_qa_throughput.py

diff --git a/benchmarks/benchmark_long_document_qa.py b/benchmarks/benchmark_long_document_qa.py
deleted file mode 100644
index 8d4425b6fb8a..000000000000
--- a/benchmarks/benchmark_long_document_qa.py
+++ /dev/null
@@ -1,258 +0,0 @@
-"""
-Benchmark the efficiency of prefix caching.
-
-This script allows you to benchmark the performance of
-a model with prefix-caching or cpu-offloading using fixed prompts
-
-Fixed example usage:
-    # This command run the vllm with 50GB CPU memory for offloading
-    # The workload samples 8 different prompts with a default input
-    # length of 20010 tokens, then replicates each prompt 2 times.
-    python benchmark_long_document_qa.py \
-        --model meta-llama/Llama-2-7b-chat-hf \
-        --enable-prefix-caching \
-        --block-allocator CpuOffloadingBlockAllocator \
-        --num-documents 8 \
-        --repeat-count 2 \
-        --cpu-memory-gb 50
-
-Commandline arguments:
-
-    # Basic arguments
-    --model: The model to use for the benchmark.
-
-    --enable-prefix-caching: Enable prefix caching or not.
-
-    --block-allocator: The block allocator that vLLM uses.
-        - CpuGpuBlockAllocator: The default block allocator.
-        - CpuOffloadingBlockAllocator: The block allocator that supports
-          cpu offloading
-
-    --gpu-memory-utilization: GPU memory utilization for vLLM.
-
-    --cpu-memory-gb: The amount of CPU memory (GB) that is used by vLLM.
-        NOTE: CPU memory should be larger than GPU KV cache size when
-        using CpuOffloadingBlockAllocator.
-
-    # Workload-related arguments
-    --num-documents: The number of documents to sample prompts from.
-
-    --repeat-count: The number of times to repeat each prompt.
-
-    # Other functionality
-    --seed: Random seed for reproducibility.
-
-    --profile-swap-blocks: Profile the swap_blocks function in the custom ops.
-"""
-
-import random
-import time
-
-import torch
-
-from vllm import LLM, SamplingParams
-from vllm.utils import FlexibleArgumentParser
-
-execution_times = {}
-
-
-def build_result_dict(start_time, end_time, *args):
-    total_time = end_time - start_time
-    length = -1
-    if len(args) > 1 and isinstance(args[1], torch.Tensor):
-        length = len(args[1])
-
-    return {
-        "start_time": start_time,
-        "total_time": total_time,
-        "swap_len": length
-    }
-
-
-def timing_decorator(func):
-
-    def wrapper(*args, **kwargs):
-        global execution_times
-        torch.cuda.synchronize()
-        start_time = time.time()  # Record the start time
-        result = func(*args, **kwargs)  # Call the wrapped function
-        torch.cuda.synchronize()
-        end_time = time.time()  # Record the end time
-        if func.__name__ not in execution_times:
-            execution_times[func.__name__] = []
-
-        res = build_result_dict(start_time, end_time, *args)
-        execution_times[func.__name__].append(res)
-        return result  # Return the result of the original function
-
-    return wrapper
-
-
-def process_timing_results():
-    global execution_times
-    for key in execution_times:
-        len_to_time = {}
-        len_to_count = {}
-        for item in execution_times[key]:
-            swap_len = item["swap_len"]
-            if swap_len not in len_to_time:
-                len_to_time[swap_len] = 0
-            len_to_time[swap_len] += item["total_time"]
-
-            if swap_len not in len_to_count:
-                len_to_count[swap_len] = 0
-            len_to_count[swap_len] += 1
-
-        for swap_len in len_to_time:
-            total_time = len_to_time[swap_len]
-            count = len_to_count[swap_len]
-            print(f"{key} on {swap_len} pages: "
-                  f"{(count * swap_len) / total_time} pages per second")
-
-
-def test_long_document_qa(llm=None, sampling_params=None, prompts=None):
-
-    start_time = time.time()
-    llm.generate(prompts, sampling_params=sampling_params)
-    end_time = time.time()
-    print(f"cost time {end_time - start_time}")
-
-
-def repeat_prompts(prompts, repeat_count):
-    repeated_prompts = prompts * repeat_count
-    random.shuffle(repeated_prompts)
-    return repeated_prompts
-
-
-def main(args):
-    if args.profile_swap_blocks:
-        from vllm.worker.cache_engine import CacheEngine
-        CacheEngine.swap_out = timing_decorator(CacheEngine.swap_out)
-        CacheEngine.swap_in = timing_decorator(CacheEngine.swap_in)
-
-    random.seed(args.seed)
-
-    # append the document id at the beginning to avoid any of the document
-    # being the prefix of other documents
-    prompts = [
-        str(i) + ' '.join(['hi'] * args.document_length)
-        for i in range(args.num_documents)
-    ]
-
-    preemption_mode = ""
-    if args.block_allocator == "CpuOffloadingBlockAllocator":
-        preemption_mode = "recompute"
-    else:
-        preemption_mode = "swap"
-
-    llm = LLM(model=args.model,
-              tokenizer_mode='auto',
-              trust_remote_code=True,
-              enforce_eager=True,
-              tensor_parallel_size=args.tensor_parallel_size,
-              enable_prefix_caching=args.enable_prefix_caching,
-              block_allocator=args.block_allocator,
-              preemption_mode=preemption_mode,
-              swap_space=args.cpu_memory_gb,
-              enable_chunked_prefill=False,
-              gpu_memory_utilization=args.gpu_memory_utilization,
-              max_model_len=30000)
-
-    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
-
-    prompts = repeat_prompts(prompts, args.repeat_count)
-
-    print("------warm up------")
-    test_long_document_qa(
-        llm=llm,
-        prompts=prompts,
-        sampling_params=sampling_params,
-    )
-
-    random.shuffle(prompts)
-
-    print("------start generating------")
-    test_long_document_qa(
-        llm=llm,
-        prompts=prompts,
-        sampling_params=sampling_params,
-    )
-
-    if args.profile_swap_blocks:
-        process_timing_results()
-
-
-if __name__ == "__main__":
-    parser = FlexibleArgumentParser(
-        description=
-        'Benchmark the performance with or without automatic prefix caching.')
-    parser.add_argument(
-        '--model',
-        type=str,
-        # this test aims to test long document QA capability,
-        # so we use llama 3.1 8B as it can process long context
-        default='meta-llama/Llama-3.1-8B')
-    parser.add_argument("--dataset-path",
-                        type=str,
-                        default=None,
-                        help="Path to the dataset.")
-    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
-    parser.add_argument('--output-len', type=int, default=10)
-    parser.add_argument('--enable-prefix-caching',
-                        action='store_true',
-                        help='enable prefix caching')
-    parser.add_argument('--repeat-count',
-                        type=int,
-                        default=2,
-                        help='Number of times to repeat each prompt')
-    parser.add_argument(
-        '--document-length',
-        type=int,
-        # Roughly the number of tokens for a system paper,
-        # excluding images
-        default=20010,
-        help='Range of input lengths for sampling prompts,'
-        'specified as "min:max" (e.g., "128:256").')
-    parser.add_argument('--num-documents',
-                        type=int,
-                        default=8,
-                        help='Range of input lengths for sampling prompts,'
-                        'specified as "min:max" (e.g., "128:256").')
-    parser.add_argument("--seed",
-                        type=int,
-                        default=0,
-                        help='Random seed for reproducibility')
-    parser.add_argument('--gpu-memory-utilization',
-                        type=float,
-                        default=0.9,
-                        help='GPU memory utilization for vLLM. Should be a '
-                        'float point number ranging from 0 to 1. For this '
-                        'test please use a small value so that the GPU '
-                        'cannot hold all KV caches of all documents, '
-                        'and the effect of CPU offloading can be tested.')
-    parser.add_argument(
-        '--cpu-memory-gb',
-        type=float,
-        default=1,
-        help="The amount of CPU memory (GB) that is used by vLLM. Not very "
-        "useful for CpuGpuBlockAllocator, but useful for "
-        "CpuOffloadingBlockAllocator to have more CPU KV cache space")
-    parser.add_argument(
-        '--block-allocator',
-        type=str,
-        default='CpuGpuBlockAllocator',
-        choices=['CpuGpuBlockAllocator', 'CpuOffloadingBlockAllocator'],
-        help='The block allocator that vLLM uses. Currently'
-        ' can be CpuGpuBlockAllocator (the default) and '
-        'CpuOffloadingBlockAllocator (experimental) that '
-        'supports offloading the KV cache to CPU . '
-        'When using CpuOffloadingBlockAllocator, the '
-        'preemption mode must be recompute.')
-
-    parser.add_argument(
-        '--profile-swap-blocks',
-        action='store_true',
-        help='Profile the swap_blocks function in the custom ops')
-
-    args = parser.parse_args()
-    main(args)
diff --git a/benchmarks/benchmark_long_document_qa_throughput.py b/benchmarks/benchmark_long_document_qa_throughput.py
new file mode 100644
index 000000000000..7f0864aeccba
--- /dev/null
+++ b/benchmarks/benchmark_long_document_qa_throughput.py
@@ -0,0 +1,169 @@
+"""
+Offline benchmark to test the long document QA throughput.
+
+Example usage:
+    # This command runs vLLM with prefix caching enabled.
+    # The workload samples 8 different prompts with a default input
+    # length of 20000 tokens, then replicates each prompt 2 times
+    # in random order.
+    python benchmark_long_document_qa_throughput.py \
+        --model meta-llama/Llama-2-7b-chat-hf \
+        --enable-prefix-caching \
+        --num-documents 8 \
+        --repeat-count 2
+
+Command-line arguments:
+    --num-documents: The number of documents to sample prompts from.
+
+    --document-length: The length of each document in tokens.
+                       (Optional, default: 20000)
+
+    --output-len: The number of tokens to generate for each prompt.
+                  (Optional, default: 10)
+
+    --repeat-count: The number of times to repeat each prompt.
+                    (Optional, default: 2)
+
+    --repeat-mode: The mode to repeat prompts. The supported modes are:
+        - 'random': shuffle the prompts randomly. (Default)
+        - 'tile': the entire prompt list is repeated in sequence. (Potentially
+          the lowest cache hit rate)
+        - 'interleave': each prompt is repeated consecutively before
+          moving to the next element. (Highest cache hit rate)
+
+    --shuffle-seed: Random seed when the repeat mode is "random".
+                    (Optional, default: 0)
+
+In addition, the script accepts all the vLLM engine arguments used to
+initialize the LLM engine. Refer to `vllm.engine.arg_utils.EngineArgs` for
+more details.
+"""
+
+import dataclasses
+import random
+import time
+
+from vllm import LLM, SamplingParams
+from vllm.engine.arg_utils import EngineArgs
+from vllm.utils import FlexibleArgumentParser
+
+
+def test_long_document_qa(llm=None, sampling_params=None, prompts=None):
+    """
+    Test long document QA with the given prompts and sampling parameters.
+    Print the time cost processing all the prompts.
+    """
+    start_time = time.time()
+    llm.generate(prompts, sampling_params=sampling_params)
+    end_time = time.time()
+    print(f"Time to execute all requests: {end_time - start_time:.4f} secs")
+
+
+def repeat_prompts(prompts, repeat_count, mode: str):
+    """
+    Repeat each prompt in the list for repeat_count times.
+    The order of prompts in the output list depends on the mode.
+    Currently, we support the following modes:
+    - 'random': shuffle the prompts randomly
+    - 'tile': the entire prompt list is repeated in sequence. Ex. [1, 2, 3]
+      -> [1, 2, 3, 1, 2, 3] (1, 2, 3 are prompts)
+    - 'interleave': each prompt is repeated consecutively before moving to the
+      next element. Ex. [1, 2, 3] -> [1, 1, 2, 2, 3, 3]
+    """
+    print("Repeat mode: ", mode)
+    if mode == 'random':
+        repeated_prompts = prompts * repeat_count
+        random.shuffle(repeated_prompts)
+        return repeated_prompts
+    elif mode == 'tile':
+        return prompts * repeat_count
+    elif mode == 'interleave':
+        repeated_prompts = []
+        for prompt in prompts:
+            repeated_prompts.extend([prompt] * repeat_count)
+        return repeated_prompts
+    else:
+        raise ValueError(f"Invalid mode: {mode}; supported modes are "
+                         "'random', 'tile', and 'interleave'.")
+
+
+def main(args):
+    random.seed(args.shuffle_seed)
+
+    # Prepare the prompts:
+    # we prepend the document id to each prompt so that no document
+    # is a prefix of another document
+    prompts = [
+        str(i) + ' '.join(['hi'] * args.document_length)
+        for i in range(args.num_documents)
+    ]
+
+    prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode)
+
+    warmup_prompts = [
+        "This is warm up request " + str(i) + \
+        ' '.join(['hi'] * args.document_length)
+        for i in range(args.num_documents)]
+
+    # Create the LLM engine
+    engine_args = EngineArgs.from_cli_args(args)
+    llm = LLM(**dataclasses.asdict(engine_args))
+    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
+
+    print("------warm up------")
+    test_long_document_qa(
+        llm=llm,
+        prompts=warmup_prompts,
+        sampling_params=sampling_params,
+    )
+
+    print("------start generating------")
+    test_long_document_qa(
+        llm=llm,
+        prompts=prompts,
+        sampling_params=sampling_params,
+    )
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser(
+        description=
+        'Benchmark the performance with or without automatic prefix caching.')
+
+    parser.add_argument(
+        '--document-length',
+        type=int,
+        # Roughly the number of tokens for a system paper,
+        # excluding images
+        default=20000,
+        help='The length of each document in tokens, used'
+        ' to build one synthetic prompt per document.')
+
+    parser.add_argument('--num-documents',
+                        type=int,
+                        default=8,
+                        help='The number of documents to sample prompts'
+                        ' from (each document yields one prompt).')
+
+    parser.add_argument('--output-len', type=int, default=10)
+
+    parser.add_argument('--repeat-count',
+                        type=int,
+                        default=2,
+                        help='Number of times to repeat each prompt')
+
+    parser.add_argument("--repeat-mode",
+                        type=str,
+                        default='random',
+                        help='The mode to repeat prompts. The supported '
+                        'modes are "random", "tile", and "interleave". '
+                        'See repeat_prompts() in the source code for details.')
+
+    parser.add_argument("--shuffle-seed",
+                        type=int,
+                        default=0,
+                        help='Random seed when the repeat mode is "random"')
+
+    parser = EngineArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    main(args)

From 0fdc98e423f9dc6dcce3970a15d86951f7aebe7a Mon Sep 17 00:00:00 2001
From: ApostaC
Date: Tue, 31 Dec 2024 22:30:59 +0000
Subject: [PATCH 3/3] [fix] docstring style

Signed-off-by: ApostaC
---
 .../benchmark_long_document_qa_throughput.py  | 31 ++++++++++++++-----
 1 file changed, 23 insertions(+), 8 deletions(-)

diff --git a/benchmarks/benchmark_long_document_qa_throughput.py b/benchmarks/benchmark_long_document_qa_throughput.py
index 7f0864aeccba..13477ef535e8 100644
--- a/benchmarks/benchmark_long_document_qa_throughput.py
+++ b/benchmarks/benchmark_long_document_qa_throughput.py
@@ -51,7 +51,12 @@ def test_long_document_qa(llm=None, sampling_params=None, prompts=None):
     """
     Test long document QA with the given prompts and sampling parameters.
-    Print the time cost processing all the prompts.
+    Print the time spent processing all the prompts.
+
+    Args:
+        llm: The language model used for generating responses.
+        sampling_params: Sampling parameters used to generate the responses.
+        prompts: A list of prompt strings to be processed by the LLM.
     """
     start_time = time.time()
     llm.generate(prompts, sampling_params=sampling_params)
     end_time = time.time()
@@ -61,14 +66,24 @@ def test_long_document_qa(llm=None, sampling_params=None, prompts=None):
 def repeat_prompts(prompts, repeat_count, mode: str):
     """
-    Repeat each prompt in the list for repeat_count times.
+    Repeat each prompt in the list a specified number of times.
     The order of prompts in the output list depends on the mode.
-    Currently, we support the following modes:
-    - 'random': shuffle the prompts randomly
-    - 'tile': the entire prompt list is repeated in sequence. Ex. [1, 2, 3]
-      -> [1, 2, 3, 1, 2, 3] (1, 2, 3 are prompts)
-    - 'interleave': each prompt is repeated consecutively before moving to the
-      next element. Ex. [1, 2, 3] -> [1, 1, 2, 2, 3, 3]
+
+    Args:
+        prompts: A list of prompts to be repeated.
+        repeat_count: The number of times each prompt is repeated.
+        mode: The mode of repetition. Supported modes are:
+            - 'random': Shuffle the prompts randomly after repetition.
+            - 'tile': Repeat the entire prompt list in sequence.
+              Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
+            - 'interleave': Repeat each prompt consecutively before moving to
+              the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
+
+    Returns:
+        A list of repeated prompts in the specified order.
+
+    Raises:
+        ValueError: If an invalid mode is provided.
     """
     print("Repeat mode: ", mode)
     if mode == 'random':