Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions benchmark/benchmark_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def main(args: argparse.Namespace):
server = Server(
model=args.model,
model_path=args.model_path,
use_dummy_weights=args.use_dummy_weights,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,
Expand Down
1 change: 1 addition & 0 deletions cacheflow/http_frontend/fastapi_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
self.server = remote_server_class.remote(
model=model,
model_path=model_path,
use_dummy_weights=False,
pipeline_parallel_size=pipeline_parallel_size,
tensor_parallel_size=tensor_parallel_size,
block_size=block_size,
Expand Down
3 changes: 3 additions & 0 deletions cacheflow/master/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(
self,
model: str,
model_path: str,
use_dummy_weights: bool,
pipeline_parallel_size: int,
tensor_parallel_size: int,
block_size: int,
Expand Down Expand Up @@ -66,6 +67,7 @@ def __init__(
dtype=dtype,
seed=seed,
model_path=model_path,
use_dummy_weights=use_dummy_weights,
max_num_batched_tokens=max_num_batched_tokens,
)
self.controllers.append(controller)
Expand Down Expand Up @@ -179,4 +181,5 @@ def add_server_arguments(parser: argparse.ArgumentParser):
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens')
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
return parser
4 changes: 4 additions & 0 deletions cacheflow/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,7 @@ def get_weights(model_name: str, path: str):
np.save(f, param.cpu().detach().numpy())

return path

def initialize_dummy_weights(self) -> None:
for param in self.state_dict().values():
param.data.uniform_(-0.1, 0.1)
23 changes: 17 additions & 6 deletions cacheflow/models/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,29 @@ def get_model(
model_name: str,
dtype: Union[torch.dtype, str],
path: str,
use_dummy_weights: bool,
) -> nn.Module:
torch_dtype = get_torch_dtype(dtype)
torch.set_default_dtype(torch_dtype)
config = AutoConfig.from_pretrained(model_name)
for model_class_name, model_class in _MODELS.items():
if model_class_name in model_name:
# Download model weights if it's not cached.
weights_dir = model_class.get_weights(model_name, path=path)
# Create a model instance.
model = model_class(config)
# Load the weights from the cached or downloaded files.
model.load_weights(weights_dir)
if use_dummy_weights:
# Create a model instance.
# The weights will be initialized as empty tensors.
model = model_class(config)
model = model.cuda()
# NOTE(woosuk): For precise performance evaluation, we assign
# random values to the weights.
model.initialize_dummy_weights()
else:
# Download model weights if it's not cached.
weights_dir = model_class.get_weights(model_name, path=path)
# Create a model instance.
model = model_class(config)
# Load the weights from the cached or downloaded files.
model.load_weights(weights_dir)
model = model.cuda()
return model.eval(), torch_dtype
raise ValueError(f'Unsupported model name: {model_name}')

Expand Down
4 changes: 4 additions & 0 deletions cacheflow/models/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,7 @@ def get_weights(model_name: str, path: str):
np.save(f, param.cpu().detach().numpy())

return path

def initialize_dummy_weights(self) -> None:
for param in self.state_dict().values():
param.data.uniform_(-0.1, 0.1)
2 changes: 2 additions & 0 deletions cacheflow/worker/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
dtype: str,
seed: int,
model_path: str,
use_dummy_weights: bool,
max_num_batched_tokens: int,
) -> None:
self.stage_id = stage_id
Expand Down Expand Up @@ -58,6 +59,7 @@ def __init__(
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
model_path=model_path,
use_dummy_weights=use_dummy_weights,
max_num_batched_tokens=max_num_batched_tokens,
)
self.workers.append(worker)
Expand Down
5 changes: 3 additions & 2 deletions cacheflow/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
rank: int,
world_size: int,
model_path: str,
use_dummy_weights: bool,
max_num_batched_tokens: int,
tensor_parallel_size: int = 1,
pipeline_parallel_size: int = 1,
Expand All @@ -43,8 +44,8 @@ def __init__(
set_random_seed(seed)

# Initialize the model.
self.model, self.dtype = get_model(model_name, dtype=dtype, path=model_path)
self.model = self.model.cuda()
self.model, self.dtype = get_model(
model_name, dtype=dtype, path=model_path, use_dummy_weights=use_dummy_weights)
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
initialize_all_reduce_launcher(
Expand Down
1 change: 1 addition & 0 deletions simple_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def main(args: argparse.Namespace):
server = Server(
model=args.model,
model_path=args.model_path,
use_dummy_weights=args.use_dummy_weights,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,
Expand Down