diff --git a/benchmark/benchmark_latency.py b/benchmark/benchmark_latency.py
index a18ef98f40a8..aa611ac0d336 100644
--- a/benchmark/benchmark_latency.py
+++ b/benchmark/benchmark_latency.py
@@ -35,7 +35,7 @@ def main(args: argparse.Namespace):
         dtype=args.dtype,
         seed=args.seed,
         swap_space=args.swap_space,
-        max_batch_size=args.max_batch_size,
+        max_num_batched_tokens=args.max_num_batched_tokens,
         num_nodes=num_nodes,
         num_devices_per_node=num_devices_per_node,
         distributed_init_method=distributed_init_method,
@@ -94,6 +94,7 @@ def profile_step(profile=False):
     parser.add_argument('--output-len', type=int, default=128)
     parser.add_argument('--batch-size', type=int, default=8)
     args = parser.parse_args()
-    args.max_batch_size = max(args.max_batch_size, args.batch_size * args.input_len)
+    args.max_num_batched_tokens = max(
+        args.max_num_batched_tokens, args.batch_size * args.input_len)
     print(args)
     main(args)
diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py
index 1f224316c01b..b2a97100e4d3 100644
--- a/cacheflow/master/server.py
+++ b/cacheflow/master/server.py
@@ -22,7 +22,7 @@ def __init__(
         dtype: str,
         seed: int,
         swap_space: int,
-        max_batch_size: int,
+        max_num_batched_tokens: int,
         num_nodes: int,
         num_devices_per_node: int,
         distributed_init_method: str,
@@ -43,7 +43,7 @@ def __init__(
             tensor_parallel_size=tensor_parallel_size,
         )
         self.num_gpu_blocks = self.memory_analyzer.get_max_num_gpu_blocks(
-            max_num_batched_tokens=max_batch_size)
+            max_num_batched_tokens=max_num_batched_tokens)
         self.num_cpu_blocks = self.memory_analyzer.get_max_num_cpu_blocks(
             swap_space=swap_space)
         print(f'# GPU blocks: {self.num_gpu_blocks}, '
@@ -66,6 +66,7 @@ def __init__(
                 dtype=dtype,
                 seed=seed,
                 model_path=model_path,
+                max_num_batched_tokens=max_num_batched_tokens,
             )
             self.controllers.append(controller)
 
@@ -75,7 +76,7 @@ def __init__(
             block_size=block_size,
             num_gpu_blocks=self.num_gpu_blocks,
             num_cpu_blocks=self.num_cpu_blocks,
-            max_num_batched_tokens=max_batch_size,
+            max_num_batched_tokens=max_num_batched_tokens,
         )
         # Connect the controllers.
         for i in range(len(self.controllers) - 1):
@@ -168,8 +169,8 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
                         help='model path to download and load the weights')
     # Parallel arguments
-    parser.add_argument('--pipeline-parallel-size', type=int, default=1, help='number of pipeline stages')
-    parser.add_argument('--tensor-parallel-size', type=int, default=1, help='number of tensor parallel replicas')
+    parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
     # KV cache arguments
     parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
     # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
@@ -177,5 +178,5 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
     parser.add_argument('--seed', type=int, default=0, help='random seed')
     parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
-    parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens')
+    parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens')
     return parser
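
The renamed flag reflects what the limit actually bounds: the number of tokens processed in one step, not the number of sequences. benchmark_latency.py therefore raises the budget so that a whole prompt batch fits into a single step. A small worked example of that sizing rule, with an assumed input length (only the 2560-token default and the batch size of 8 come from the patch):

    # Hypothetical sizing: 8 prompts of 512 tokens each need 4096 batched
    # tokens, so the default budget of 2560 is raised to fit the prompt step.
    batch_size, input_len = 8, 512  # input_len chosen for illustration
    max_num_batched_tokens = max(2560, batch_size * input_len)  # -> 4096
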
diff --git a/cacheflow/parallel_utils/parallel_state.py b/cacheflow/parallel_utils/parallel_state.py
index ef4e886d874b..8bb8c4023f57 100644
--- a/cacheflow/parallel_utils/parallel_state.py
+++ b/cacheflow/parallel_utils/parallel_state.py
@@ -47,6 +47,7 @@
 # Memory buffers to avoid dynamic memory allocation
 _GLOBAL_MEMORY_BUFFER = None
 
+_ALL_REDUCE_LAUNCHER: Optional['GraphAllReduce'] = None
 
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
@@ -205,6 +206,20 @@ def initialize_model_parallel(
 
     _set_global_memory_buffer()
 
+def initialize_all_reduce_launcher(
+    max_num_tokens: int,
+    hidden_size: int,
+    dtype: torch.dtype,
+    disable_graph: bool = False,
+) -> None:
+    global _ALL_REDUCE_LAUNCHER
+    _ALL_REDUCE_LAUNCHER = GraphAllReduce(
+        max_num_tokens=max_num_tokens,
+        hidden_size=hidden_size,
+        dtype=dtype,
+        disable_graph=disable_graph,
+    )
+
 def model_parallel_is_initialized():
     """Check if model and data parallel groups are initialized."""
     if _TENSOR_MODEL_PARALLEL_GROUP is None or \
@@ -491,6 +506,9 @@ def get_global_memory_buffer():
     assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
     return _GLOBAL_MEMORY_BUFFER
 
+def get_all_reduce_launcher() -> 'GraphAllReduce':
+    assert _ALL_REDUCE_LAUNCHER is not None, 'all reduce launcher is not initialized'
+    return _ALL_REDUCE_LAUNCHER
 
 def destroy_model_parallel():
     """Set the groups to none."""
@@ -520,3 +538,56 @@ def destroy_model_parallel():
     _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
     global _GLOBAL_MEMORY_BUFFER
     _GLOBAL_MEMORY_BUFFER = None
+
+
+class GraphAllReduce:
+
+    def __init__(
+        self,
+        max_num_tokens: int,
+        hidden_size: int,
+        dtype: torch.dtype,
+        disable_graph: bool = False,
+    ) -> None:
+        self.max_num_tokens = max_num_tokens
+        self.hidden_size = hidden_size
+        self.disable_graph = disable_graph
+
+        tp_world_size = get_tensor_model_parallel_world_size()
+        if tp_world_size == 1:
+            return
+
+        self.group = get_tensor_model_parallel_group()
+        self.buffer = torch.empty(
+            size=(max_num_tokens, hidden_size),
+            dtype=dtype,
+            device='cuda',
+        )
+
+        # Build graphs for different number of tokens.
+        if not self.disable_graph:
+            self.graphs = {}
+            for num_tokens in range(8, max_num_tokens + 1, 8):
+                self.graphs[num_tokens] = self._build_graph(num_tokens)
+
+    def _build_graph(self, num_tokens: int) -> torch.cuda.CUDAGraph:
+        # Warm up.
+        torch.distributed.all_reduce(self.buffer[:num_tokens], group=self.group)
+        torch.cuda.synchronize()
+
+        # Build graph.
+        graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(graph):
+            torch.distributed.all_reduce(
+                self.buffer[:num_tokens], group=self.group)
+        torch.cuda.synchronize()
+        return graph
+
+    def launch(self, x: torch.Tensor) -> torch.Tensor:
+        # NOTE: x must be a slice of self.buffer.
+        num_tokens = x.shape[0]
+        if self.disable_graph:
+            torch.distributed.all_reduce(x, group=self.group)
+        else:
+            self.graphs[num_tokens].replay()
+        return x
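
The new GraphAllReduce captures one CUDA graph per token count (multiples of 8 up to max_num_tokens), each replaying an all-reduce over a slice of a single persistent buffer, so the hot path is a graph replay with no allocation and no per-call launch overhead from torch.distributed. When the tensor-parallel world size is 1 the launcher holds no buffer and is never used. A minimal usage sketch, not part of the patch, assuming torch.distributed and the tensor-parallel groups are already initialized and using an illustrative hidden size and dtype:

    import torch
    from cacheflow.parallel_utils import parallel_state

    # Capture one graph per token count (8, 16, ..., 2560).
    parallel_state.initialize_all_reduce_launcher(
        max_num_tokens=2560, hidden_size=4096, dtype=torch.half)

    launcher = parallel_state.get_all_reduce_launcher()
    num_tokens = 16                     # must match a captured size (multiple of 8)
    out = launcher.buffer[:num_tokens]  # callers fill a slice of the shared buffer
    out = launcher.launch(out)          # replays the pre-captured all-reduce graph
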
diff --git a/cacheflow/parallel_utils/tensor_parallel/layers.py b/cacheflow/parallel_utils/tensor_parallel/layers.py
index f9ba8385cc9c..2cbe2b8a6d76 100644
--- a/cacheflow/parallel_utils/tensor_parallel/layers.py
+++ b/cacheflow/parallel_utils/tensor_parallel/layers.py
@@ -12,6 +12,7 @@
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    get_all_reduce_launcher,
 )
 from .mappings import (
     copy_to_tensor_model_parallel_region,
@@ -407,8 +408,7 @@ def __init__(self, input_size, output_size, *,
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)
-
-
+        self.weight_t = self.weight.t()
 
     def forward(self, input_):
         """Forward of RowParallelLinear
@@ -425,11 +425,18 @@ def forward(self, input_):
             input_parallel = input_
         else:
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
-        # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight)
+        if get_tensor_model_parallel_world_size() == 1:
+            # Matrix multiply.
+            output_ = F.linear(input_parallel, self.weight)
+        else:
+            # Matrix multiply.
+            all_reduce_launcher = get_all_reduce_launcher()
+            num_tokens = input_parallel.shape[0]
+            output_buffer = all_reduce_launcher.buffer[:num_tokens]
+            torch.matmul(input_parallel, self.weight_t, out=output_buffer)
+            # All-reduce across all the partitions.
+            output_ = all_reduce_launcher.launch(output_buffer)
+
-
-        # All-reduce across all the partitions.
-        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
             output_bias = None
diff --git a/cacheflow/worker/controller.py b/cacheflow/worker/controller.py
index bb357b132665..577dbc1615a8 100644
--- a/cacheflow/worker/controller.py
+++ b/cacheflow/worker/controller.py
@@ -27,6 +27,7 @@ def __init__(
         dtype: str,
         seed: int,
         model_path: str,
+        max_num_batched_tokens: int,
     ) -> None:
         self.stage_id = stage_id
         self.stage_devices = stage_devices
@@ -57,6 +58,7 @@ def __init__(
                 tensor_parallel_size=tensor_parallel_size,
                 pipeline_parallel_size=pipeline_parallel_size,
                 model_path=model_path,
+                max_num_batched_tokens=max_num_batched_tokens,
             )
             self.workers.append(worker)
 
diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py
index db0d46aabe9e..3e92d9597f8e 100644
--- a/cacheflow/worker/worker.py
+++ b/cacheflow/worker/worker.py
@@ -9,7 +9,9 @@
 from cacheflow.sequence import SequenceOutputs
 from cacheflow.worker.cache_engine import CacheEngine
 from cacheflow.parallel_utils.parallel_state import (
-    initialize_model_parallel, get_tensor_model_parallel_world_size)
+    initialize_model_parallel,
+    initialize_all_reduce_launcher,
+    get_tensor_model_parallel_world_size)
 from cacheflow.utils import set_random_seed
 
 
@@ -27,6 +29,7 @@ def __init__(
         rank: int,
         world_size: int,
         model_path: str,
+        max_num_batched_tokens: int,
         tensor_parallel_size: int = 1,
         pipeline_parallel_size: int = 1,
     ) -> None:
@@ -44,6 +47,8 @@ def __init__(
         self.model = self.model.cuda()
         tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
+        initialize_all_reduce_launcher(
+            max_num_batched_tokens, self.model.config.hidden_size, self.dtype)
         self.num_layers = self.model.config.num_hidden_layers
         assert self.model.config.num_attention_heads % tensor_model_parallel_world_size == 0
         self.num_heads = self.model.config.num_attention_heads // tensor_model_parallel_world_size
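
With tensor parallelism enabled, RowParallelLinear.forward now writes its partial matmul result straight into the launcher's buffer (torch.matmul with out=, using the weight transpose cached as self.weight_t at construction) and then replays the captured all-reduce graph in place, replacing the old F.linear plus reduce_from_tensor_model_parallel_region pair and its extra output allocation. A sketch of the equivalent eager computation, assuming buffer is the launcher's persistent tensor and group is the tensor-parallel process group:

    # Eager equivalent of the new TP > 1 branch (illustration only).
    out = buffer[:num_tokens]
    torch.matmul(input_parallel, weight.t(), out=out)  # partial sum on this rank
    torch.distributed.all_reduce(out, group=group)     # summed in place across ranks
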
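
Each worker now sizes the launcher from max_num_batched_tokens and the model's hidden size right after the model is moved to the GPU, so the persistent all-reduce buffer costs max_num_batched_tokens * hidden_size * bytes-per-element of GPU memory per worker. A rough figure with the default token budget and an assumed fp16 model with hidden size 4096:

    # Hypothetical buffer footprint: 2560 tokens x 4096 hidden x 2 bytes.
    buffer_bytes = 2560 * 4096 * 2  # = 20,971,520 bytes, about 20 MiB
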