From 5cd8f3dd874daba51800b97fcf20339c4418311d Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sun, 9 Apr 2023 06:31:11 +0000
Subject: [PATCH] Add use-dummy-weights option

---
 benchmark/benchmark_latency.py              |  1 +
 cacheflow/http_frontend/fastapi_frontend.py |  1 +
 cacheflow/master/server.py                  |  3 +++
 cacheflow/models/llama.py                   |  4 ++++
 cacheflow/models/model_utils.py             | 23 +++++++++++++++------
 cacheflow/models/opt.py                     |  4 ++++
 cacheflow/worker/controller.py              |  2 ++
 cacheflow/worker/worker.py                  |  5 +++--
 simple_server.py                            |  1 +
 9 files changed, 36 insertions(+), 8 deletions(-)

diff --git a/benchmark/benchmark_latency.py b/benchmark/benchmark_latency.py
index 24727713ff88..c6f1bacb8408 100644
--- a/benchmark/benchmark_latency.py
+++ b/benchmark/benchmark_latency.py
@@ -29,6 +29,7 @@ def main(args: argparse.Namespace):
     server = Server(
         model=args.model,
         model_path=args.model_path,
+        use_dummy_weights=args.use_dummy_weights,
         pipeline_parallel_size=args.pipeline_parallel_size,
         tensor_parallel_size=args.tensor_parallel_size,
         block_size=args.block_size,
diff --git a/cacheflow/http_frontend/fastapi_frontend.py b/cacheflow/http_frontend/fastapi_frontend.py
index 5390806a9122..209536310e87 100644
--- a/cacheflow/http_frontend/fastapi_frontend.py
+++ b/cacheflow/http_frontend/fastapi_frontend.py
@@ -47,6 +47,7 @@ def __init__(
         self.server = remote_server_class.remote(
             model=model,
             model_path=model_path,
+            use_dummy_weights=False,
             pipeline_parallel_size=pipeline_parallel_size,
             tensor_parallel_size=tensor_parallel_size,
             block_size=block_size,
diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py
index b2a97100e4d3..ff8b549eb4c4 100644
--- a/cacheflow/master/server.py
+++ b/cacheflow/master/server.py
@@ -16,6 +16,7 @@ def __init__(
         self,
         model: str,
         model_path: str,
+        use_dummy_weights: bool,
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
         block_size: int,
@@ -66,6 +67,7 @@ def __init__(
                 dtype=dtype,
                 seed=seed,
                 model_path=model_path,
+                use_dummy_weights=use_dummy_weights,
                 max_num_batched_tokens=max_num_batched_tokens,
             )
             self.controllers.append(controller)
@@ -179,4 +181,5 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--seed', type=int, default=0, help='random seed')
     parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
     parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens')
+    parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
     return parser
diff --git a/cacheflow/models/llama.py b/cacheflow/models/llama.py
index 2a3c8b007adf..a8572e4f362a 100644
--- a/cacheflow/models/llama.py
+++ b/cacheflow/models/llama.py
@@ -286,3 +286,7 @@ def get_weights(model_name: str, path: str):
                 np.save(f, param.cpu().detach().numpy())
 
         return path
+
+    def initialize_dummy_weights(self) -> None:
+        for param in self.state_dict().values():
+            param.data.uniform_(-0.1, 0.1)
diff --git a/cacheflow/models/model_utils.py b/cacheflow/models/model_utils.py
index aaf81bc2b513..8df6acf79a6a 100644
--- a/cacheflow/models/model_utils.py
+++ b/cacheflow/models/model_utils.py
@@ -28,18 +28,29 @@ def get_model(
     model_name: str,
     dtype: Union[torch.dtype, str],
     path: str,
+    use_dummy_weights: bool,
 ) -> nn.Module:
     torch_dtype = get_torch_dtype(dtype)
     torch.set_default_dtype(torch_dtype)
     config = AutoConfig.from_pretrained(model_name)
     for model_class_name, model_class in _MODELS.items():
         if model_class_name in model_name:
-            # Download model weights if it's not cached.
-            weights_dir = model_class.get_weights(model_name, path=path)
-            # Create a model instance.
-            model = model_class(config)
-            # Load the weights from the cached or downloaded files.
-            model.load_weights(weights_dir)
+            if use_dummy_weights:
+                # Create a model instance.
+                # The weights will be initialized as empty tensors.
+                model = model_class(config)
+                model = model.cuda()
+                # NOTE(woosuk): For precise performance evaluation, we assign
+                # random values to the weights.
+                model.initialize_dummy_weights()
+            else:
+                # Download model weights if it's not cached.
+                weights_dir = model_class.get_weights(model_name, path=path)
+                # Create a model instance.
+                model = model_class(config)
+                # Load the weights from the cached or downloaded files.
+                model.load_weights(weights_dir)
+                model = model.cuda()
             return model.eval(), torch_dtype
     raise ValueError(f'Unsupported model name: {model_name}')
 
diff --git a/cacheflow/models/opt.py b/cacheflow/models/opt.py
index 9ecd9e70f138..90c6f54e3fca 100644
--- a/cacheflow/models/opt.py
+++ b/cacheflow/models/opt.py
@@ -324,3 +324,7 @@ def get_weights(model_name: str, path: str):
                 np.save(f, param.cpu().detach().numpy())
 
         return path
+
+    def initialize_dummy_weights(self) -> None:
+        for param in self.state_dict().values():
+            param.data.uniform_(-0.1, 0.1)
diff --git a/cacheflow/worker/controller.py b/cacheflow/worker/controller.py
index 577dbc1615a8..dce3fddf89fa 100644
--- a/cacheflow/worker/controller.py
+++ b/cacheflow/worker/controller.py
@@ -27,6 +27,7 @@ def __init__(
         dtype: str,
         seed: int,
         model_path: str,
+        use_dummy_weights: bool,
         max_num_batched_tokens: int,
     ) -> None:
         self.stage_id = stage_id
@@ -58,6 +59,7 @@ def __init__(
                 tensor_parallel_size=tensor_parallel_size,
                 pipeline_parallel_size=pipeline_parallel_size,
                 model_path=model_path,
+                use_dummy_weights=use_dummy_weights,
                 max_num_batched_tokens=max_num_batched_tokens,
             )
             self.workers.append(worker)
diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py
index 3e92d9597f8e..95ce2c6a869e 100644
--- a/cacheflow/worker/worker.py
+++ b/cacheflow/worker/worker.py
@@ -29,6 +29,7 @@ def __init__(
         rank: int,
         world_size: int,
         model_path: str,
+        use_dummy_weights: bool,
         max_num_batched_tokens: int,
         tensor_parallel_size: int = 1,
         pipeline_parallel_size: int = 1,
@@ -43,8 +44,8 @@ def __init__(
         set_random_seed(seed)
 
         # Initialize the model.
-        self.model, self.dtype = get_model(model_name, dtype=dtype, path=model_path)
-        self.model = self.model.cuda()
+        self.model, self.dtype = get_model(
+            model_name, dtype=dtype, path=model_path, use_dummy_weights=use_dummy_weights)
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         initialize_all_reduce_launcher(
diff --git a/simple_server.py b/simple_server.py
index d333ece34c07..08842f7b7da7 100644
--- a/simple_server.py
+++ b/simple_server.py
@@ -22,6 +22,7 @@ def main(args: argparse.Namespace):
     server = Server(
         model=args.model,
         model_path=args.model_path,
+        use_dummy_weights=args.use_dummy_weights,
         pipeline_parallel_size=args.pipeline_parallel_size,
         tensor_parallel_size=args.tensor_parallel_size,
         block_size=args.block_size,
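Note: the change to get_model() in cacheflow/models/model_utils.py is the heart of this patch. With --use-dummy-weights set, the model is constructed without touching a checkpoint and every parameter is then overwritten in place with uniform random values in [-0.1, 0.1], so latency benchmarks can run without downloading weights while kernels still operate on non-degenerate values (per the NOTE, empty tensors could distort performance measurements). Below is a minimal, self-contained sketch of that pattern; ToyModel and get_toy_model are illustrative stand-ins, not CacheFlow code (the real models here are the Llama and OPT classes), and it runs on CPU for simplicity where the patch moves the model to GPU first.

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self, hidden_size: int = 256):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.relu(self.fc1(x)))

    def initialize_dummy_weights(self) -> None:
        # Same loop as the patch: state_dict() tensors share storage with
        # the parameters, so the in-place uniform_ fills the real weights.
        for param in self.state_dict().values():
            param.data.uniform_(-0.1, 0.1)


def get_toy_model(use_dummy_weights: bool) -> nn.Module:
    model = ToyModel()
    if use_dummy_weights:
        # Skip downloading/loading a checkpoint entirely.
        model.initialize_dummy_weights()
    else:
        # The real path would load cached or downloaded weights here,
        # e.g. model.load_state_dict(torch.load('checkpoint.pt')).
        pass
    return model.eval()


model = get_toy_model(use_dummy_weights=True)
print(model(torch.zeros(1, 256)).shape)  # torch.Size([1, 256])

Because only Worker ever loads weights, the flag is threaded down the whole stack (Server -> Controller -> Worker -> get_model), and the frontend hard-codes use_dummy_weights=False since a serving endpoint must never answer with random weights.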