99 changes: 99 additions & 0 deletions benchmark/benchmark_latency.py
@@ -0,0 +1,99 @@
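"""Benchmark the latency of processing a single batch of requests.

Note: this module docstring is an editorial addition summarizing the script
below; it is not part of the original diff.
"""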
import argparse
import time
from typing import List

from tqdm import tqdm
import numpy as np
import torch

from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
initialize_ray_cluster)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory


def main(args: argparse.Namespace):
# TODO(zhuohan): Support pipeline parallelism.
assert args.pipeline_parallel_size == 1, (
'Pipeline parallelism is not supported yet.')

(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_ray_cluster(
address='local',
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))

# Create a server.
server = Server(
model=args.model,
model_path=args.model_path,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
max_batch_size=args.max_batch_size,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
)

# Create a frontend.
frontend = SimpleFrontend(
model_name=args.model,
block_size=args.block_size,
)
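    # Greedy sampling: one sequence per prompt, temperature 0, no beam search,
    # generating exactly `output_len` tokens.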
sampling_params_dict = {
'n': 1,
'temperature': 0.0,
'top_p': 1.0,
'use_beam_search': False,
'stop_token_ids': set(),
'max_num_steps': args.output_len,
}
sampling_params = SamplingParams.from_dict(sampling_params_dict)
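    # Dummy prompt of all-zero token IDs; only the prompt length matters here.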
input_token_ids = [0] * args.input_len

def profile_step(profile=False):
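        # Enqueue `batch_size` copies of the dummy prompt, then step the server
        # until every request finishes, measuring end-to-end latency. With
        # profile=True the run is bracketed by CUDA profiler start/stop markers
        # (useful when capturing with an external profiler such as Nsight Systems).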
if profile:
torch.cuda.cudart().cudaProfilerStart()
for _ in range(args.batch_size):
frontend._add_query(input_token_ids, sampling_params)
server.add_sequence_groups(frontend.get_inputs())
start_time = time.time()
while True:
server.step()
if not server.has_unfinished_requests():
break
end_time = time.time()
latency = end_time - start_time
if profile:
torch.cuda.cudart().cudaProfilerStop()
return latency

print("Warm up step")
profile_step()

# Benchmark.
latencies = []
for _ in tqdm(range(3), desc="Profile step"):
latencies.append(profile_step())
print(f'Avg latency: {np.mean(latencies)} seconds')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Benchmark the latency of processing a single batch of requests.')
parser = add_server_arguments(parser)
parser.add_argument('--input-len', type=int, default=32)
parser.add_argument('--output-len', type=int, default=128)
parser.add_argument('--batch-size', type=int, default=8)
args = parser.parse_args()
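    # Raise max_batch_size if needed so the whole batch of prompts fits in one step.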
args.max_batch_size = max(args.max_batch_size, args.batch_size * args.input_len)
print(args)
main(args)
3 changes: 0 additions & 3 deletions cacheflow/parallel_utils/tensor_parallel/__init__.py
@@ -6,8 +6,6 @@
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes,
param_is_not_tensor_parallel_duplicate,
linear_with_grad_accumulation_and_async_allreduce

)

from .mappings import (
@@ -39,7 +37,6 @@
"set_defaults_if_not_set_tensor_model_parallel_attributes",
"copy_tensor_model_parallel_attributes",
"param_is_not_tensor_parallel_duplicate",
"linear_with_grad_accumulation_and_async_allreduce",
# mappings.py
"copy_to_tensor_model_parallel_region",
"gather_from_tensor_model_parallel_region",