diff --git a/README.md b/README.md index 7cce45b9efbc..7008d8a204cb 100644 --- a/README.md +++ b/README.md @@ -11,5 +11,6 @@ pip install -e . ## Run ```bash -python server.py +ray start --head +python server.py [--tensor-parallel-size <N>] ``` diff --git a/cacheflow/models/__init__.py b/cacheflow/models/__init__.py index cd8f134a5a74..511a6822214c 100644 --- a/cacheflow/models/__init__.py +++ b/cacheflow/models/__init__.py @@ -1,12 +1,10 @@ from cacheflow.models.input_metadata import InputMetadata from cacheflow.models.model_utils import get_memory_analyzer from cacheflow.models.model_utils import get_model -from cacheflow.models.utils import set_seed __all__ = [ 'InputMetadata', 'get_memory_analyzer', 'get_model', - 'set_seed', ] diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py index fe1e194d94a8..7f24670b7eaa 100644 --- a/cacheflow/models/attention.py +++ b/cacheflow/models/attention.py @@ -112,7 +112,7 @@ def forward( output[:num_prompt_tokens], query[:num_prompt_tokens], key[:num_prompt_tokens], - value[:num_prompt_tokens], + value[:num_prompt_tokens], input_metadata.prompt_lens, ) diff --git a/cacheflow/models/input_metadata.py b/cacheflow/models/input_metadata.py index e4787d4a8271..8a341fbac627 100644 --- a/cacheflow/models/input_metadata.py +++ b/cacheflow/models/input_metadata.py @@ -43,4 +43,8 @@ def __repr__(self) -> str: f'num_generation_tokens={self.num_generation_tokens}, ' f'num_valid_tokens={self.num_valid_tokens}, ' f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, ' - f'max_context_len={self.max_context_len})') + f'max_context_len={self.max_context_len}, ' + f'prompt_lens={self.prompt_lens}, ' + f'slot_mapping={self.slot_mapping}, ' + f'context_lens={self.context_lens}, ' + f'block_tables={self.block_tables})') diff --git a/cacheflow/models/memory_analyzer.py b/cacheflow/models/memory_analyzer.py index 11668b588da5..69675588c3c4 100644 --- a/cacheflow/models/memory_analyzer.py +++ b/cacheflow/models/memory_analyzer.py @@ -31,12 +31,13 @@ def __init__( model_name: str, block_size: int, dtype: torch.dtype, + tensor_parallel_size: int, ) -> None: self.model_name = model_name self.block_size = block_size self.dtype = dtype + self.tensor_parallel_size = tensor_parallel_size - # TODO(woosuk): Support tensor parallelism. config = AutoConfig.from_pretrained(model_name) self.num_layers = config.num_hidden_layers self.hidden_size = config.hidden_size @@ -48,26 +49,25 @@ def __init__( self.max_position = config.max_position_embeddings def _get_param_size(self) -> int: - # TODO(woosuk): Support tensor parallelism. - word_embedding = self.vocab_size * self.embedding_size + word_embedding = self.vocab_size * self.embedding_size // self.tensor_parallel_size if self.embedding_size != self.vocab_size: # Project in/out.
word_embedding += 2 * self.embedding_size * self.vocab_size position_embedding = self.max_position * self.hidden_size ln1 = 2 * self.hidden_size - q = self.hidden_size * self.hidden_size + self.hidden_size - k = self.hidden_size * self.hidden_size + self.hidden_size - v = self.hidden_size * self.hidden_size + self.hidden_size - out = self.hidden_size * self.hidden_size + self.hidden_size + q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size + k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size + v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size + out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size mha = ln1 + q + k + v + out ln2 = 2 * self.hidden_size - ffn1 = self.hidden_size * self.ffn_size + self.ffn_size - ffn2 = self.ffn_size * self.hidden_size + self.hidden_size + ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size + ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size ffn = ln2 + ffn1 + ffn2 - total = (word_embedding + position_embedding + + total = (word_embedding + position_embedding + self.num_layers * (mha + ffn)) dtype_size = get_dtype_size(self.dtype) return dtype_size * total @@ -76,15 +76,17 @@ def _get_max_act_size( self, max_num_batched_tokens: int, ) -> int: - # TODO(woosuk): Support tensor parallelism. # NOTE: We approxmiately calculate the maximum activation size by - # 1) estimating the maximum activation tensor size during inference, and - # 2) multiplying it by 4. + # estimating + # 1) the maximum activation tensor size during inference + # 2) the residual tensor size during inference # Here, we assume that FlashAttention is used and # thus the attention maps are never materialized in GPU DRAM. - qkv = 3 * (max_num_batched_tokens * self.hidden_size) - ffn = max_num_batched_tokens * self.ffn_size - max_act = 4 * max(qkv, ffn) + residual = max_num_batched_tokens * self.hidden_size + qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size + ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size + # Double the activation size for input and output. + max_act = 2 * (max(qkv, ffn) + residual) dtype_size = get_dtype_size(self.dtype) return dtype_size * max_act diff --git a/cacheflow/models/model_utils.py b/cacheflow/models/model_utils.py index 98ff6d44ebb0..b1fdacea075a 100644 --- a/cacheflow/models/model_utils.py +++ b/cacheflow/models/model_utils.py @@ -1,7 +1,9 @@ from typing import Union +import numpy as np import torch import torch.nn as nn +from transformers import AutoConfig from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer @@ -21,13 +23,20 @@ def get_model( model_name: str, dtype: Union[torch.dtype, str], + path: str, ) -> nn.Module: torch_dtype = get_torch_dtype(dtype) - for model_class, hf_model in _MODELS.items(): - if model_class in model_name: - model = hf_model.from_pretrained( - model_name, torch_dtype=torch_dtype) - return model.eval() + torch.set_default_dtype(torch_dtype) + config = AutoConfig.from_pretrained(model_name) + for model_class_name, model_class in _MODELS.items(): + if model_class_name in model_name: + # Download model weights if it's not cached. + weights_dir = model_class.download_weights(model_name, path=path) + # Create a model instance. 
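As a side note on the `_get_param_size` arithmetic above: the sketch below re-derives the per-GPU weight footprint outside the class, assuming `word_embed_proj_dim == hidden_size` so the project-in/out branch is ignored. The OPT-13B-like shapes and the fp16 element size are assumed for illustration, not taken from this diff.

```python
# Hedged re-derivation of the per-GPU parameter bytes computed by the memory
# analyzer above. All concrete numbers below are assumptions for illustration.

def per_gpu_param_bytes(hidden: int, ffn: int, vocab: int, max_pos: int,
                        num_layers: int, tp: int, dtype_size: int = 2) -> int:
    """Mirror _get_param_size: shard the large matmul weights across `tp` ranks,
    keep positional embeddings, layer norms, and biases replicated."""
    word_embedding = vocab * hidden // tp            # vocab-parallel embedding
    position_embedding = max_pos * hidden            # replicated on every rank
    ln1 = ln2 = 2 * hidden                           # weight + bias per layer norm
    qkv_out = 4 * (hidden * hidden // tp + hidden)   # q/k/v/out projections
    ffn1 = hidden * ffn // tp + ffn
    ffn2 = ffn * hidden // tp + hidden
    per_layer = ln1 + qkv_out + ln2 + ffn1 + ffn2
    total = word_embedding + position_embedding + num_layers * per_layer
    return dtype_size * total

# OPT-13B-like shapes (assumed): hidden=5120, ffn=20480, 40 layers, fp16 weights.
for tp in (1, 2, 4):
    gib = per_gpu_param_bytes(5120, 20480, 50272, 2048, 40, tp) / 2**30
    print(f"tensor_parallel_size={tp}: ~{gib:.1f} GiB of weights per GPU")
```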
+ model = model_class(config) + # Load the weights from the cached or downloaded files. + model.load_weights(weights_dir) + return model.eval(), torch_dtype raise ValueError(f'Unsupported model name: {model_name}') @@ -35,10 +44,11 @@ def get_memory_analyzer( model_name: str, block_size: int, dtype: Union[torch.dtype, str], + tensor_parallel_size: int = 1, ) -> CacheFlowMemoryAnalyzer: torch_dtype = get_torch_dtype(dtype) for model_class, memory_analyzer in _MEMORY_ANALYZERS.items(): if model_class in model_name: return memory_analyzer( - model_name, block_size, torch_dtype) + model_name, block_size, torch_dtype, tensor_parallel_size) raise ValueError(f'Unsupported model name: {model_name}') diff --git a/cacheflow/models/opt.py b/cacheflow/models/opt.py index 579033589059..e9d8e853cc08 100644 --- a/cacheflow/models/opt.py +++ b/cacheflow/models/opt.py @@ -1,14 +1,24 @@ """1D OPT model compatible with HuggingFace weights.""" +import os +import glob +import filelock +from tqdm import tqdm from typing import Dict, List, Optional, Tuple +import numpy as np import torch from torch import nn from transformers import OPTConfig -from transformers import PreTrainedModel +from huggingface_hub import snapshot_download from cacheflow.models import InputMetadata from cacheflow.models.attention import OPTCacheFlowAttention from cacheflow.models.sample import Sampler +from cacheflow.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding, + ColumnParallelLinear, + RowParallelLinear) from cacheflow.sequence import SequenceOutputs KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -36,15 +46,26 @@ def __init__( ) -> None: super().__init__() self.embed_dim = embed_dim - self.num_heads = num_heads - self.head_dim = embed_dim // num_heads + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() + total_num_heads = num_heads + assert num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = total_num_heads // tensor_model_parallel_world_size + self.head_dim = embed_dim // total_num_heads self.scaling = self.head_dim**-0.5 # TODO(woosuk): Fuse the three linear layers into one QKV linear layer. 
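For readers following the head bookkeeping above: with `gather_output=False`, each column-parallel q/k/v projection emits only its rank's slice of the hidden dimension, which corresponds to exactly `num_heads // tensor_parallel_size` heads. A tiny sketch with assumed example sizes:

```python
# Illustrative only: how the attention shapes fall out under tensor parallelism.
# embed_dim, total_num_heads, and tp are assumed example values.
embed_dim, total_num_heads, tp = 5120, 40, 2

assert total_num_heads % tp == 0
heads_per_rank = total_num_heads // tp      # 20 heads live on each rank
head_dim = embed_dim // total_num_heads     # 128, identical on every rank

# A column-parallel projection with gather_output=False produces an activation of
# shape [num_tokens, embed_dim // tp], i.e. exactly heads_per_rank local heads:
assert heads_per_rank * head_dim == embed_dim // tp
```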
- self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.k_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias, + gather_output=False, + perform_initialization=False) + self.v_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias, + gather_output=False, + perform_initialization=False) + self.q_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias, + gather_output=False, + perform_initialization=False) + self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias, + input_is_parallel=True, + perform_initialization=False) self.attn = OPTCacheFlowAttention(scale=self.scaling) @@ -55,13 +76,13 @@ def forward( input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: - q = self.q_proj(hidden_states) - k = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) + q, _ = self.q_proj(hidden_states) + k, _ = self.k_proj(hidden_states) + v, _ = self.v_proj(hidden_states) key_cache, value_cache = kv_cache attn_output = self.attn( q, k, v, key_cache, value_cache, input_metadata, cache_event) - output = self.out_proj(attn_output) + output, _ = self.out_proj(attn_output) return output @@ -69,6 +90,7 @@ class OPTDecoderLayer(nn.Module): def __init__(self, config: OPTConfig): super().__init__() + self.config = config self.embed_dim = config.hidden_size self.self_attn = OPTAttention( embed_dim=self.embed_dim, @@ -81,9 +103,16 @@ def __init__(self, config: OPTConfig): self.self_attn_layer_norm = nn.LayerNorm( self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) - self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias) - self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias) - self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + self.fc1 = ColumnParallelLinear(self.embed_dim, config.ffn_dim, + bias=config.enable_bias, + gather_output=False, + perform_initialization=False) + self.fc2 = RowParallelLinear(config.ffn_dim, self.embed_dim, + bias=config.enable_bias, + input_is_parallel=True, + perform_initialization=False) + self.final_layer_norm = nn.LayerNorm( + self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) def forward( self, @@ -112,9 +141,9 @@ def forward( # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention if self.do_layer_norm_before: hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.fc1(hidden_states) + hidden_states, _ = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(hidden_states) + hidden_states, _ = self.fc2(hidden_states) hidden_states = residual + hidden_states # 350m applies layer norm AFTER attention if not self.do_layer_norm_before: @@ -122,29 +151,23 @@ def forward( return hidden_states -class OPTPreTrainedModel(PreTrainedModel): - config_class = OPTConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["OPTDecoderLayer"] - _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] - - def _init_weights(self, module) -> None: - del module # unused - return - - -class OPTDecoder(OPTPreTrainedModel): +class OPTDecoder(nn.Module): def __init__(self, config: OPTConfig): - super().__init__(config) + super().__init__() + self.config = config self.padding_idx = 
config.pad_token_id self.max_target_positions = config.max_position_embeddings self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx) - self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size) + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.word_embed_proj_dim, + perform_initialization=False) + # Positional embeddings are replicated (not sharded). + self.embed_positions = OPTLearnedPositionalEmbedding( + config.max_position_embeddings, config.hidden_size) + # Project out & in will be replicated if they exist. if config.word_embed_proj_dim != config.hidden_size: self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) else: @@ -167,9 +190,6 @@ def __init__(self, config: OPTConfig): self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - # Initialize weights and apply final processing - self.post_init() - def forward( self, input_ids: torch.LongTensor, @@ -200,13 +220,11 @@ def forward( return hidden_states -class OPTModel(OPTPreTrainedModel): +class OPTModel(nn.Module): def __init__(self, config: OPTConfig): - super().__init__(config) + super().__init__() self.decoder = OPTDecoder(config) - # Initialize weights and apply final processing - self.post_init() def forward( self, @@ -220,41 +238,17 @@ def forward( input_ids, positions, kv_caches, input_metadata, cache_events) -class OPTForCausalLM(OPTPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"lm_head.weight"] +class OPTForCausalLM(nn.Module): def __init__(self, config): - super().__init__(config) + super().__init__() + self.config = config self.model = OPTModel(config) - # the lm_head weight is automatically tied to the embed tokens weight - self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + # TODO(zhuohan): create a new weight after implementing pipeline + # parallelism + self.lm_head_weight = self.model.decoder.embed_tokens.weight self.sampler = Sampler() - # Initialize weights and apply final processing - self.post_init() - - # NOTE(woosuk): While the following methods are not called in the model code, - # they may be internally used by the transformers library. - # For example, tie_weights() does not work without these methods. - # Thus, do not delete these methods. 
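The fc1/fc2 and q/k/v/out_proj changes above follow the usual column-then-row sharding pattern: the first weight is split by columns with `gather_output=False`, the second by rows with `input_is_parallel=True`, so a single all-reduce on the second layer's output restores the full result. A self-contained, single-process sketch in plain PyTorch with toy shapes; `sharded_mlp` is a made-up helper, not part of this diff:

```python
import torch

def sharded_mlp(x, w1, w2, tp):
    """Emulate tp ranks on one device; in the real layers each rank holds one shard."""
    hidden, ffn = w1.shape
    partials = []
    for rank in range(tp):
        w1_shard = w1[:, rank * (ffn // tp):(rank + 1) * (ffn // tp)]  # column shard
        w2_shard = w2[rank * (ffn // tp):(rank + 1) * (ffn // tp), :]  # row shard
        partials.append(torch.relu(x @ w1_shard) @ w2_shard)
    return sum(partials)  # stands in for the all-reduce in the row-parallel layer

x = torch.randn(4, 16)
w1, w2 = torch.randn(16, 64), torch.randn(64, 16)
reference = torch.relu(x @ w1) @ w2
assert torch.allclose(sharded_mlp(x, w1, w2, tp=4), reference, atol=1e-4)
```

Because the elementwise activation commutes with the column split, no communication is needed between fc1 and fc2; the only collective is the reduction on fc2's output.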
- def get_input_embeddings(self): - return self.model.decoder.embed_tokens - - def set_input_embeddings(self, value): - self.model.decoder.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model.decoder = decoder - - def get_decoder(self): - return self.model.decoder - def forward( self, input_ids: torch.LongTensor, @@ -266,5 +260,72 @@ def forward( hidden_states = self.model( input_ids, positions, kv_caches, input_metadata, cache_events) next_tokens = self.sampler( - self.lm_head.weight, hidden_states, input_metadata) + self.lm_head_weight, hidden_states, input_metadata) return next_tokens + + _column_parallel_weights = ["embed_tokens.weight", + "q_proj.weight", "k_proj.weight", + "v_proj.weight", "fc1.weight"] + _column_parallel_biases = ["q_proj.bias", "k_proj.bias", + "v_proj.bias", "fc1.bias"] + _row_parallel_weights = ["out_proj.weight", "fc2.weight"] + + def load_weights(self, weights_path: str): + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + state_dict = self.state_dict() + for name, param in state_dict.items(): + if "lm_head_weight" in name: + continue + loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path, + name))) + for p in (self._column_parallel_weights + + self._column_parallel_biases): + if p in name: + shard_size = param.shape[0] + loaded_weight = loaded_weight[ + shard_size * tensor_model_parallel_rank + :shard_size * (tensor_model_parallel_rank + 1)] + break + for p in self._row_parallel_weights: + if p in name: + shard_size = param.shape[1] + loaded_weight = loaded_weight[ + :, + shard_size * tensor_model_parallel_rank + :shard_size * (tensor_model_parallel_rank + 1)] + break + + assert param.shape == loaded_weight.shape + param.data.copy_(loaded_weight) + + @staticmethod + def download_weights(model_name: str, path: str): + path = os.path.join(path, f"{model_name}-np") + path = os.path.abspath(os.path.expanduser(path)) + os.makedirs(path, exist_ok=True) + lock_path = os.path.join(path, "file_lock") + lock = filelock.FileLock(lock_path) + + with lock: + test_weight_path = os.path.join( + path, "model.decoder.embed_positions.weight") + if os.path.exists(test_weight_path): + return path + + folder = snapshot_download(model_name, allow_patterns="*.bin", + cache_dir=os.path.join(path, "cache")) + bin_files = glob.glob(os.path.join(folder, "*.bin")) + + if "/" in model_name: + model_name = model_name.split("/")[1].lower() + + for bin_file in tqdm(bin_files, desc="Convert format"): + state = torch.load(bin_file) + for name, param in tqdm(state.items(), leave=False): + if name.startswith("decoder."): + name = "model." + name + param_path = os.path.join(path, name) + with open(param_path, "wb") as f: + np.save(f, param.cpu().detach().numpy()) + + return path diff --git a/cacheflow/models/sample.py b/cacheflow/models/sample.py index a8b60c145208..371986a75177 100644 --- a/cacheflow/models/sample.py +++ b/cacheflow/models/sample.py @@ -6,7 +6,7 @@ from cacheflow.models import InputMetadata from cacheflow.sampling_params import SamplingParams from cacheflow.sequence import SequenceOutputs - +from cacheflow.parallel_utils.tensor_parallel import gather_from_tensor_model_parallel_region class Sampler(nn.Module): @@ -24,6 +24,7 @@ def forward( # Get the logits for the next tokens. 
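To make the slicing in `load_weights` above concrete, here is a toy, single-tensor version of the same rule: column-parallel parameters take a contiguous slice along dim 0, row-parallel ones along dim 1, indexed by the tensor-parallel rank. `shard_for_rank` is a hypothetical helper for illustration only:

```python
import torch

def shard_for_rank(full: torch.Tensor, param_shape, rank: int, *, column: bool):
    """Return the slice of a full checkpoint tensor owned by `rank`."""
    if column:                      # e.g. q_proj.weight, fc1.weight
        size = param_shape[0]
        return full[size * rank:size * (rank + 1)]
    else:                           # e.g. out_proj.weight, fc2.weight
        size = param_shape[1]
        return full[:, size * rank:size * (rank + 1)]

# fc1.weight is stored as [ffn, hidden]; with tp=2 each rank keeps half the rows.
full_fc1 = torch.arange(8 * 4).reshape(8, 4)
local = shard_for_rank(full_fc1, (4, 4), rank=1, column=True)
assert local.shape == (4, 4) and torch.equal(local, full_fc1[4:8])
```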
logits = torch.matmul(hidden_states, embedding.t()) + logits = gather_from_tensor_model_parallel_region(logits) # Apply temperature scaling. temperatures = _get_temperatures(input_metadata) diff --git a/cacheflow/models/utils.py b/cacheflow/models/utils.py index 4b705bf7d969..cdad84f03831 100644 --- a/cacheflow/models/utils.py +++ b/cacheflow/models/utils.py @@ -27,14 +27,6 @@ def get_dtype_size(dtype: Union[torch.dtype, str]) -> int: return torch.tensor([], dtype=torch_dtype).element_size() -def set_seed(seed: int) -> None: - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - - def get_gpu_memory(gpu: int = 0) -> int: return torch.cuda.get_device_properties(gpu).total_memory diff --git a/cacheflow/parallel_utils/README.md b/cacheflow/parallel_utils/README.md new file mode 100644 index 000000000000..b25e3afddad9 --- /dev/null +++ b/cacheflow/parallel_utils/README.md @@ -0,0 +1 @@ +The files in this folder are ported from [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core). We only keep the codes that are used in inference. \ No newline at end of file diff --git a/cacheflow/parallel_utils/__init__.py b/cacheflow/parallel_utils/__init__.py new file mode 100644 index 000000000000..78a7c4463eac --- /dev/null +++ b/cacheflow/parallel_utils/__init__.py @@ -0,0 +1,12 @@ +import cacheflow.parallel_utils.parallel_state +import cacheflow.parallel_utils.tensor_parallel +import cacheflow.parallel_utils.utils + +# Alias parallel_state as mpu, its legacy name +mpu = parallel_state + +__all__ = [ + "parallel_state", + "tensor_parallel", + "utils", +] diff --git a/cacheflow/parallel_utils/parallel_state.py b/cacheflow/parallel_utils/parallel_state.py new file mode 100644 index 000000000000..ef4e886d874b --- /dev/null +++ b/cacheflow/parallel_utils/parallel_state.py @@ -0,0 +1,522 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Model and data parallel groups.""" + +import torch +from typing import Optional + +from .utils import GlobalMemoryBuffer + +# Intra-layer model parallel group that the current rank belongs to. +_TENSOR_MODEL_PARALLEL_GROUP = None +# Inter-layer model parallel group that the current rank belongs to. +_PIPELINE_MODEL_PARALLEL_GROUP = None +# Model parallel group (both intra- and pipeline) that the current rank belongs to. +_MODEL_PARALLEL_GROUP = None +# Embedding group. +_EMBEDDING_GROUP = None +# Position embedding group. +_POSITION_EMBEDDING_GROUP = None +# Data parallel group that the current rank belongs to. +_DATA_PARALLEL_GROUP = None + +_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None +_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None +_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None + +# These values enable us to change the mpu sizes on the fly. +_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None +_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None +_MPU_TENSOR_MODEL_PARALLEL_RANK = None +_MPU_PIPELINE_MODEL_PARALLEL_RANK = None + +# A list of ranks that have a copy of the embedding. +_EMBEDDING_GLOBAL_RANKS = None + +# A list of ranks that have a copy of the position embedding. +_POSITION_EMBEDDING_GLOBAL_RANKS = None + +# A list of global ranks for each pipeline group to ease calculation of the source +# rank when broadcasting from the first or last pipeline stage. 
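The sampler change above is needed because `lm_head_weight` is tied to the vocab-parallel embedding shard, so each rank's matmul only covers `vocab_size // tensor_parallel_size` entries; the gather reassembles the full vocabulary dimension before sampling. A rough single-process illustration with assumed toy shapes:

```python
import torch

vocab, hidden, tp, tokens = 8, 4, 2, 3
hidden_states = torch.randn(tokens, hidden)
full_embedding = torch.randn(vocab, hidden)

# Each "rank" multiplies by its vocab shard and gets [tokens, vocab // tp] logits.
per_rank_logits = [
    hidden_states @ full_embedding[r * (vocab // tp):(r + 1) * (vocab // tp)].t()
    for r in range(tp)
]
# Concatenating along the last dim is what the tensor-parallel gather reconstructs.
logits = torch.cat(per_rank_logits, dim=-1)
assert torch.allclose(logits, hidden_states @ full_embedding.t())
```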
+_PIPELINE_GLOBAL_RANKS = None + +# A list of global ranks for each data parallel group to ease calculation of the source +# rank when broadcasting weights from src to all other data parallel ranks +_DATA_PARALLEL_GLOBAL_RANKS = None + +# Memory buffers to avoid dynamic memory allocation +_GLOBAL_MEMORY_BUFFER = None + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + virtual_pipeline_model_parallel_size: Optional[int] = None, + pipeline_model_parallel_split_rank: Optional[int] = None, +) -> None: + """ + Initialize model data parallel groups. + + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism. + virtual_pipeline_model_parallel_size: number of virtual stages (interleaved + pipeline). + pipeline_model_parallel_split_rank: for models with both encoder and decoder, + rank in pipeline with split point. + + Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 8 tensor model-parallel groups, 4 pipeline model-parallel groups + and 8 data-parallel groups as: + 8 data_parallel groups: + [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15] + 8 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15] + 4 pipeline model-parallel groups: + [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + + if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0: + raise RuntimeError( + f"world_size ({world_size}) is not divisible by tensor_model_parallel_size " + f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})" + ) + + data_parallel_size: int = world_size // (tensor_model_parallel_size * + pipeline_model_parallel_size) + + num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + num_data_parallel_groups: int = world_size // data_parallel_size + + if virtual_pipeline_model_parallel_size is not None: + if not pipeline_model_parallel_size > 2: + raise RuntimeError("pipeline-model-parallel size should be greater than 2 with " + "interleaved schedule") + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0 + _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size + + if pipeline_model_parallel_split_rank is not None: + global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank + + rank = torch.distributed.get_rank() + + # Build the data-parallel groups. 
+ global _DATA_PARALLEL_GROUP + global _DATA_PARALLEL_GLOBAL_RANKS + assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized' + all_data_parallel_group_ranks = [] + for i in range(pipeline_model_parallel_size): + start_rank = i * num_pipeline_model_parallel_groups + end_rank = (i + 1) * num_pipeline_model_parallel_groups + for j in range(tensor_model_parallel_size): + ranks = range(start_rank + j, end_rank, tensor_model_parallel_size) + all_data_parallel_group_ranks.append(list(ranks)) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _DATA_PARALLEL_GROUP = group + _DATA_PARALLEL_GLOBAL_RANKS = ranks + + # Build the model-parallel groups. + global _MODEL_PARALLEL_GROUP + assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized' + for i in range(data_parallel_size): + ranks = [data_parallel_group_ranks[i] + for data_parallel_group_ranks in all_data_parallel_group_ranks] + group = torch.distributed.new_group(ranks) + if rank in ranks: + _MODEL_PARALLEL_GROUP = group + + # Build the tensor model-parallel groups. + global _TENSOR_MODEL_PARALLEL_GROUP + assert _TENSOR_MODEL_PARALLEL_GROUP is None, \ + 'tensor model parallel group is already initialized' + for i in range(num_tensor_model_parallel_groups): + ranks = range(i * tensor_model_parallel_size, + (i + 1) * tensor_model_parallel_size) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _TENSOR_MODEL_PARALLEL_GROUP = group + + # Build the pipeline model-parallel groups and embedding groups + # (first and last rank in each pipeline model-parallel group). + global _PIPELINE_MODEL_PARALLEL_GROUP + global _PIPELINE_GLOBAL_RANKS + assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \ + 'pipeline model parallel group is already initialized' + global _EMBEDDING_GROUP + global _EMBEDDING_GLOBAL_RANKS + assert _EMBEDDING_GROUP is None, 'embedding group is already initialized' + global _POSITION_EMBEDDING_GROUP + global _POSITION_EMBEDDING_GLOBAL_RANKS + assert _POSITION_EMBEDDING_GROUP is None, \ + 'position embedding group is already initialized' + for i in range(num_pipeline_model_parallel_groups): + ranks = range(i, world_size, num_pipeline_model_parallel_groups) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _PIPELINE_MODEL_PARALLEL_GROUP = group + _PIPELINE_GLOBAL_RANKS = ranks + # Setup embedding group (to exchange gradients between + # first and last stages). + if len(ranks) > 1: + embedding_ranks = [ranks[0], ranks[-1]] + position_embedding_ranks = [ranks[0]] + if pipeline_model_parallel_split_rank is not None: + if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks: + embedding_ranks = [ranks[0], + ranks[pipeline_model_parallel_split_rank], + ranks[-1]] + if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks: + position_embedding_ranks = [ranks[0], + ranks[pipeline_model_parallel_split_rank]] + else: + embedding_ranks = ranks + position_embedding_ranks = ranks + + group = torch.distributed.new_group(embedding_ranks) + if rank in embedding_ranks: + _EMBEDDING_GROUP = group + if rank in ranks: + _EMBEDDING_GLOBAL_RANKS = embedding_ranks + + group = torch.distributed.new_group(position_embedding_ranks) + if rank in position_embedding_ranks: + _POSITION_EMBEDDING_GROUP = group + if rank in ranks: + _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks + + # Initialize global memory buffer + # This isn't really "parallel state" but there isn't another good place to + # put this. 
If we end up with a more generic initialization of megatron-core + # we could stick it there + _set_global_memory_buffer() + + +def model_parallel_is_initialized(): + """Check if model and data parallel groups are initialized.""" + if _TENSOR_MODEL_PARALLEL_GROUP is None or \ + _PIPELINE_MODEL_PARALLEL_GROUP is None or \ + _DATA_PARALLEL_GROUP is None: + return False + return True + + +def get_model_parallel_group(): + """Get the model parallel group the caller rank belongs to.""" + assert _MODEL_PARALLEL_GROUP is not None, \ + 'model parallel group is not initialized' + return _MODEL_PARALLEL_GROUP + + +def get_tensor_model_parallel_group(): + """Get the tensor model parallel group the caller rank belongs to.""" + assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \ + 'intra_layer_model parallel group is not initialized' + return _TENSOR_MODEL_PARALLEL_GROUP + + +def get_pipeline_model_parallel_group(): + """Get the pipeline model parallel group the caller rank belongs to.""" + assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \ + 'pipeline_model parallel group is not initialized' + return _PIPELINE_MODEL_PARALLEL_GROUP + + +def get_data_parallel_group(): + """Get the data parallel group the caller rank belongs to.""" + assert _DATA_PARALLEL_GROUP is not None, \ + 'data parallel group is not initialized' + return _DATA_PARALLEL_GROUP + + +def get_embedding_group(): + """Get the embedding group the caller rank belongs to.""" + assert _EMBEDDING_GROUP is not None, \ + 'embedding group is not initialized' + return _EMBEDDING_GROUP + + +def get_position_embedding_group(): + """Get the position embedding group the caller rank belongs to.""" + assert _POSITION_EMBEDDING_GROUP is not None, \ + 'position embedding group is not initialized' + return _POSITION_EMBEDDING_GROUP + + +def set_tensor_model_parallel_world_size(world_size): + """Set the tensor model parallel size""" + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size + + +def set_pipeline_model_parallel_world_size(world_size): + """Set the pipeline model parallel size""" + global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None: + return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) + + +def get_pipeline_model_parallel_world_size(): + """Return world size for the pipeline model parallel group.""" + global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None: + return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group()) + + +def set_tensor_model_parallel_rank(rank): + """Set tensor model parallel rank.""" + global _MPU_TENSOR_MODEL_PARALLEL_RANK + _MPU_TENSOR_MODEL_PARALLEL_RANK = rank + + +def set_pipeline_model_parallel_rank(rank): + """Set pipeline model parallel rank.""" + global _MPU_PIPELINE_MODEL_PARALLEL_RANK + _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank + + +def set_pipeline_model_parallel_split_rank(rank): + """Set pipeline model parallel split rank.""" + global _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK + _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel 
group.""" + global _MPU_TENSOR_MODEL_PARALLEL_RANK + if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None: + return _MPU_TENSOR_MODEL_PARALLEL_RANK + return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) + + +def get_pipeline_model_parallel_rank(): + """Return my rank for the pipeline model parallel group.""" + global _MPU_PIPELINE_MODEL_PARALLEL_RANK + if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None: + return _MPU_PIPELINE_MODEL_PARALLEL_RANK + return torch.distributed.get_rank(group=get_pipeline_model_parallel_group()) + + + +def is_pipeline_first_stage(ignore_virtual=False): + """Return True if in the first pipeline model-parallel stage, False otherwise.""" + if not ignore_virtual: + if get_virtual_pipeline_model_parallel_world_size() is not None and \ + get_virtual_pipeline_model_parallel_rank() != 0: + return False + return get_pipeline_model_parallel_rank() == 0 + + +def is_pipeline_last_stage(ignore_virtual=False): + """Return True if in the last pipeline model-parallel stage, False otherwise.""" + if not ignore_virtual: + virtual_pipeline_model_parallel_world_size = \ + get_virtual_pipeline_model_parallel_world_size() + if virtual_pipeline_model_parallel_world_size is not None and \ + get_virtual_pipeline_model_parallel_rank() != ( + virtual_pipeline_model_parallel_world_size - 1): + return False + return get_pipeline_model_parallel_rank() == ( + get_pipeline_model_parallel_world_size() - 1) + + +def is_rank_in_embedding_group(ignore_virtual=False): + """Return true if current rank is in embedding group, False otherwise.""" + rank = torch.distributed.get_rank() + global _EMBEDDING_GLOBAL_RANKS + if ignore_virtual: + return rank in _EMBEDDING_GLOBAL_RANKS + if rank in _EMBEDDING_GLOBAL_RANKS: + if rank == _EMBEDDING_GLOBAL_RANKS[0]: + return is_pipeline_first_stage(ignore_virtual=False) + elif rank == _EMBEDDING_GLOBAL_RANKS[-1]: + return is_pipeline_last_stage(ignore_virtual=False) + else: + return True + return False + + +def is_rank_in_position_embedding_group(): + """Return true if current rank is in position embedding group, False otherwise.""" + rank = torch.distributed.get_rank() + global _POSITION_EMBEDDING_GLOBAL_RANKS + return rank in _POSITION_EMBEDDING_GLOBAL_RANKS + + +def is_pipeline_stage_before_split(rank=None): + """Return True if pipeline stage executes encoder block for a model + with both encoder and decoder.""" + if get_pipeline_model_parallel_world_size() == 1: + return True + if rank is None: + rank = get_pipeline_model_parallel_rank() + global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None: + return True + if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK: + return True + return False + + +def is_pipeline_stage_after_split(rank=None): + """Return True if pipeline stage executes decoder block for a model + with both encoder and decoder.""" + if get_pipeline_model_parallel_world_size() == 1: + return True + if rank is None: + rank = get_pipeline_model_parallel_rank() + global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None: + return True + if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK: + return True + return False + + +def is_pipeline_stage_at_split(): + """Return true if pipeline stage executes decoder block and next + stage executes encoder block for a model with both encoder and + decoder.""" + rank = get_pipeline_model_parallel_rank() + return is_pipeline_stage_before_split(rank) and \ + is_pipeline_stage_after_split(rank+1) + + +def 
get_virtual_pipeline_model_parallel_rank(): + """Return the virtual pipeline-parallel rank.""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + + +def set_virtual_pipeline_model_parallel_rank(rank): + """Set the virtual pipeline-parallel rank.""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank + + +def get_virtual_pipeline_model_parallel_world_size(): + """Return the virtual pipeline-parallel world size.""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + + +def get_tensor_model_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_tensor_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size + + +def get_data_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the data parallel group.""" + assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \ + "Data parallel group is not initialized" + return _DATA_PARALLEL_GLOBAL_RANKS[0] + + +def get_pipeline_model_parallel_first_rank(): + """Return the global rank of the first process in the pipeline for the + current tensor parallel group""" + assert _PIPELINE_GLOBAL_RANKS is not None, \ + "Pipeline parallel group is not initialized" + return _PIPELINE_GLOBAL_RANKS[0] + + +def get_pipeline_model_parallel_last_rank(): + """Return the global rank of the last process in the pipeline for the + current tensor parallel group""" + assert _PIPELINE_GLOBAL_RANKS is not None, \ + "Pipeline parallel group is not initialized" + last_rank_local = get_pipeline_model_parallel_world_size() - 1 + return _PIPELINE_GLOBAL_RANKS[last_rank_local] + +def get_pipeline_model_parallel_next_rank(): + """Return the global rank that follows the caller in the pipeline""" + assert _PIPELINE_GLOBAL_RANKS is not None, \ + "Pipeline parallel group is not initialized" + rank_in_pipeline = get_pipeline_model_parallel_rank() + world_size = get_pipeline_model_parallel_world_size() + return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] + + +def get_pipeline_model_parallel_prev_rank(): + """Return the global rank that preceeds the caller in the pipeline""" + assert _PIPELINE_GLOBAL_RANKS is not None, \ + "Pipeline parallel group is not initialized" + rank_in_pipeline = get_pipeline_model_parallel_rank() + world_size = get_pipeline_model_parallel_world_size() + return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] + + +def get_data_parallel_world_size(): + """Return world size for the data parallel group.""" + return torch.distributed.get_world_size(group=get_data_parallel_group()) + + +def get_data_parallel_rank(): + """Return my rank for the data parallel group.""" + return torch.distributed.get_rank(group=get_data_parallel_group()) + +def _set_global_memory_buffer(): + """Initialize global buffer""" + global _GLOBAL_MEMORY_BUFFER + assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized' + _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer() + +def get_global_memory_buffer(): + """Return the global GlobalMemoryBuffer object""" + assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized' + return _GLOBAL_MEMORY_BUFFER + + +def destroy_model_parallel(): + """Set the groups to none.""" + global _MODEL_PARALLEL_GROUP + _MODEL_PARALLEL_GROUP = None + 
global _TENSOR_MODEL_PARALLEL_GROUP + _TENSOR_MODEL_PARALLEL_GROUP = None + global _PIPELINE_MODEL_PARALLEL_GROUP + _PIPELINE_MODEL_PARALLEL_GROUP = None + global _DATA_PARALLEL_GROUP + _DATA_PARALLEL_GROUP = None + global _EMBEDDING_GROUP + _EMBEDDING_GROUP = None + global _POSITION_EMBEDDING_GROUP + _POSITION_EMBEDDING_GROUP = None + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None + global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None + global _MPU_TENSOR_MODEL_PARALLEL_RANK + _MPU_TENSOR_MODEL_PARALLEL_RANK = None + global _MPU_PIPELINE_MODEL_PARALLEL_RANK + _MPU_PIPELINE_MODEL_PARALLEL_RANK = None + global _GLOBAL_MEMORY_BUFFER + _GLOBAL_MEMORY_BUFFER = None diff --git a/cacheflow/parallel_utils/tensor_parallel/__init__.py b/cacheflow/parallel_utils/tensor_parallel/__init__.py new file mode 100644 index 000000000000..fba4f9abe116 --- /dev/null +++ b/cacheflow/parallel_utils/tensor_parallel/__init__.py @@ -0,0 +1,58 @@ +from .layers import ( + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, + set_tensor_model_parallel_attributes, + set_defaults_if_not_set_tensor_model_parallel_attributes, + copy_tensor_model_parallel_attributes, + param_is_not_tensor_parallel_duplicate, + linear_with_grad_accumulation_and_async_allreduce + +) + +from .mappings import ( + copy_to_tensor_model_parallel_region, + gather_from_tensor_model_parallel_region, + gather_from_sequence_parallel_region, + scatter_to_tensor_model_parallel_region, + scatter_to_sequence_parallel_region, +) + +from .random import ( + checkpoint, + get_cuda_rng_tracker, + model_parallel_cuda_manual_seed, +) + +from .utils import ( + split_tensor_along_last_dim, + split_tensor_into_1d_equal_chunks, + gather_split_1d_tensor, +) + +__all__ = [ + #layers.py + "ColumnParallelLinear", + "RowParallelLinear", + "VocabParallelEmbedding", + "set_tensor_model_parallel_attributes", + "set_defaults_if_not_set_tensor_model_parallel_attributes", + "copy_tensor_model_parallel_attributes", + "param_is_not_tensor_parallel_duplicate", + "linear_with_grad_accumulation_and_async_allreduce", + # mappings.py + "copy_to_tensor_model_parallel_region", + "gather_from_tensor_model_parallel_region", + "gather_from_sequence_parallel_region", +# "reduce_from_tensor_model_parallel_region", + "scatter_to_tensor_model_parallel_region", + "scatter_to_sequence_parallel_region", + # random.py + "checkpoint", + "get_cuda_rng_tracker", + "model_parallel_cuda_manual_seed", + # utils.py + "split_tensor_along_last_dim", + "split_tensor_into_1d_equal_chunks", + "gather_split_1d_tensor", +] diff --git a/cacheflow/parallel_utils/tensor_parallel/layers.py b/cacheflow/parallel_utils/tensor_parallel/layers.py new file mode 100644 index 000000000000..978ca04e690e --- /dev/null +++ b/cacheflow/parallel_utils/tensor_parallel/layers.py @@ -0,0 +1,719 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
+ +# Parts of the code here are adapted from PyTorch +# repo: https://github.com/pytorch/pytorch + +import math +import os +from typing import Optional +import warnings + +import torch +import torch.nn.functional as F +import torch.nn.init as init +from torch.nn.parameter import Parameter + +from cacheflow.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tensor_model_parallel_group, + get_global_memory_buffer, +) +from .mappings import ( + copy_to_tensor_model_parallel_region, + gather_from_tensor_model_parallel_region, + gather_from_sequence_parallel_region, + reduce_from_tensor_model_parallel_region, + scatter_to_tensor_model_parallel_region, + reduce_scatter_to_sequence_parallel_region, +) + +from .random import get_cuda_rng_tracker +from .utils import ( + divide, + split_tensor_along_last_dim, + VocabUtility, +) + +_grad_accum_fusion_available = True +try: + import fused_weight_gradient_mlp_cuda +except ImportError: + _grad_accum_fusion_available = False + +_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False, + 'partition_dim': -1, + 'partition_stride': 1} + +def param_is_not_tensor_parallel_duplicate(param): + return (hasattr(param, 'tensor_model_parallel') and + param.tensor_model_parallel) or ( + get_tensor_model_parallel_rank() == 0) + + +def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride): + # Make sure the attributes are not set. + for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: + assert not hasattr(tensor, attribute) + # Set the attributes. + setattr(tensor, 'tensor_model_parallel', is_parallel) + setattr(tensor, 'partition_dim', dim) + setattr(tensor, 'partition_stride', stride) + + +def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor): + def maybe_set(attribute, value): + if not hasattr(tensor, attribute): + setattr(tensor, attribute, value) + for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: + maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute]) + + +def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor): + def maybe_copy(attribute): + if hasattr(source_tensor, attribute): + setattr(destination_tensor, attribute, + getattr(source_tensor, attribute)) + for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: + maybe_copy(attribute) + + +def _initialize_affine_weight_gpu(weight, init_method, + partition_dim, stride=1): + """Initialize affine weight for model parallel on GPU.""" + + set_tensor_model_parallel_attributes(tensor=weight, + is_parallel=True, + dim=partition_dim, + stride=stride) + + with get_cuda_rng_tracker().fork(): + init_method(weight) + + +def _initialize_affine_weight_cpu(weight, output_size, input_size, + per_partition_size, partition_dim, + init_method, stride=1, + return_master_weight=False, + *, params_dtype=None): + """Initialize affine weight for model parallel. 
+ + Build the master weight on all processes and scatter + the relevant chunk.""" + + set_tensor_model_parallel_attributes(tensor=weight, + is_parallel=True, + dim=partition_dim, + stride=stride) + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + # Initialize master weight + master_weight = torch.empty(output_size, input_size, + dtype=torch.float, + requires_grad=False) + init_method(master_weight) + master_weight = master_weight.to(dtype=params_dtype) + + # Split and copy + per_partition_per_stride_size = divide(per_partition_size, stride) + weight_list = torch.split(master_weight, per_partition_per_stride_size, + dim=partition_dim) + rank = get_tensor_model_parallel_rank() + world_size = get_tensor_model_parallel_world_size() + my_weight_list = weight_list[rank::world_size] + + with torch.no_grad(): + torch.cat(my_weight_list, dim=partition_dim, out=weight) + if return_master_weight: + return master_weight + return None + + +class VocabParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + Arguments: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + + Keyword Arguments: + init_method: method to initialize weights. + params_dtype + use_cpu_initialization + perform_initialization + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, *, + init_method=init.xavier_normal_, + params_dtype: torch.dtype=None, + use_cpu_initialization: bool=False, + perform_initialization: bool=True): + super(VocabParallelEmbedding, self).__init__() + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + # Set the defaults for compatibility. + self.padding_idx = None + self.max_norm = None + self.norm_type = 2. + self.scale_grad_by_freq = False + self.sparse = False + self._weight = None + self.tensor_model_parallel_size = get_tensor_model_parallel_world_size() + # Divide the weight matrix along the vocaburaly dimension. + self.vocab_start_index, self.vocab_end_index = \ + VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, get_tensor_model_parallel_rank(), + self.tensor_model_parallel_size) + self.num_embeddings_per_partition = self.vocab_end_index - \ + self.vocab_start_index + + # Allocate weights and initialize. + if use_cpu_initialization: + self.weight = Parameter(torch.empty( + self.num_embeddings_per_partition, self.embedding_dim, + dtype=params_dtype)) + if perform_initialization: + _initialize_affine_weight_cpu( + self.weight, self.num_embeddings, self.embedding_dim, + self.num_embeddings_per_partition, 0, init_method, + params_dtype=params_dtype) + else: + self.weight = Parameter(torch.empty( + self.num_embeddings_per_partition, self.embedding_dim, + device=torch.cuda.current_device(), dtype=params_dtype)) + if perform_initialization: + _initialize_affine_weight_gpu(self.weight, init_method, + partition_dim=0, stride=1) + + def forward(self, input_): + if self.tensor_model_parallel_size > 1: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | \ + (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + else: + masked_input = input_ + # Get the embeddings. 
+ output_parallel = F.embedding(masked_input, self.weight, + self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, + self.sparse) + # Mask the output embedding. + if self.tensor_model_parallel_size > 1: + output_parallel[input_mask, :] = 0.0 + # Reduce across all the model parallel GPUs. + output = reduce_from_tensor_model_parallel_region(output_parallel) + return output + + +class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): + """See linear_with_grad_accumulation_and_async_allreduce""" + + @staticmethod + def forward(ctx, input, weight, bias, gradient_accumulation_fusion, + async_grad_allreduce, sequence_parallel): + ctx.save_for_backward(input, weight) + ctx.use_bias = bias is not None + ctx.gradient_accumulation_fusion = gradient_accumulation_fusion + ctx.async_grad_allreduce = async_grad_allreduce + ctx.sequence_parallel = sequence_parallel + + if sequence_parallel: + world_size = get_tensor_model_parallel_world_size() + dim_size = list(input.size()) + dim_size[0] = dim_size[0] * world_size + + all_gather_buffer = \ + get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu") + torch.distributed._all_gather_base( + all_gather_buffer, + input, + group=get_tensor_model_parallel_group()) + total_input = all_gather_buffer + else: + total_input = input + + output = torch.matmul(total_input, weight.t()) + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + use_bias = ctx.use_bias + + if ctx.sequence_parallel: + world_size = get_tensor_model_parallel_world_size() + dim_size = list(input.size()) + dim_size[0] = dim_size[0] * world_size + + all_gather_buffer = \ + get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu") + handle = torch.distributed._all_gather_base( + all_gather_buffer, + input, + group=get_tensor_model_parallel_group(), async_op=True) + + # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the + # gather is scheduled before the input gradient computation + total_input = all_gather_buffer + else: + total_input = input + grad_input = grad_output.matmul(weight) + + if ctx.sequence_parallel: + handle.wait() + + # Convert the tensor shapes to 2D for execution compatibility + grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], + grad_output.shape[2]) + total_input = total_input.view(total_input.shape[0] * total_input.shape[1], + total_input.shape[2]) + + if ctx.async_grad_allreduce: + # Asynchronous all-reduce + handle = torch.distributed.all_reduce( + grad_input, group=get_tensor_model_parallel_group(), async_op=True) + # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the + # all-reduce is scheduled before the weight gradient computation + + if ctx.sequence_parallel: + assert not ctx.async_grad_allreduce + dim_size = list(input.size()) + sub_grad_input = torch.empty(dim_size, dtype=input.dtype, + device=torch.cuda.current_device(), + requires_grad=False) + # reduce_scatter + handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input, + group=get_tensor_model_parallel_group(), + async_op=True) + # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the + # reduce scatter is scheduled before the weight gradient computation + + + if ctx.gradient_accumulation_fusion: + if weight.main_grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad) + elif weight.main_grad.dtype == 
torch.float16: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, weight.main_grad) + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + grad_weight = None + else: + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.sequence_parallel: + handle.wait() + return sub_grad_input, grad_weight, grad_bias, None, None, None + + if ctx.async_grad_allreduce: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None + +def linear_with_grad_accumulation_and_async_allreduce( + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + gradient_accumulation_fusion: bool, + async_grad_allreduce: bool, + sequence_parallel_enabled: bool, +) -> torch.Tensor: + """Linear layer execution with asynchronous communication and + gradient accumulation fusion in backprop. + + This has the option to accumulate the result of backprop + calculation into an existing gradient buffer, preventing the need + to do an additional addition kernel after the gradient + calculation. + + Additionally, the tensor parallel all reduce of the input + gradients can be done asynchronously with the calculation of + the weight gradients. + + In the case of sequence parallelism, the reduce scatter of the + input gradients is done asynchronously with the calcluation of the + weight gradients. + + Use of this module requires that the environment variable + CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective + operations, noted in the code, that should be scheduled before + compute kernels to overlap the communication with the computation, + which is necessary for a speedup but not for correctness so that + ordering isn't imposed by the scheduler. Setting + CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled + in the order they are called. + + Arguments: + + input (torch.Tensor required): input like torch.nn.functional.linear + + weight (torch.Tensor required): weight like torch.nn.functional.linear + + bias (torch.Tensor optional): bias like torch.nn.functional.linear + + gradient_accumulation_fusion (bool required): Perform the gradient + accumulation fusion, requires the custom CUDA extension + fused_weight_gradient_mlp_cuda module. To use + gradient_accumulation_fusion you must install APEX with + --cpp_ext and --cuda_ext. For example: "pip install + --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" + " Note that the extension requires CUDA>=11. Otherwise, you + must turn off gradient accumulation fusion." + + async_grad_allreduce (bool required): Do the allreduce of input + gradients asyncronously with the computation of weight + gradients. If sequence_parallel_enabled is True, this must be + False, as no all reduce is performed. + + sequence_parallel_enabled (bool required): Indicates that sequence + parallelism is used and thus in the forward pass the input is + all gathered, and the backward pass the input gradients are + reduce scattered. 
+ """ + args = [ + input, + weight, + bias, + gradient_accumulation_fusion, + async_grad_allreduce, + sequence_parallel_enabled, + ] + + if not linear_with_grad_accumulation_and_async_allreduce.warned: + if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1": + if sequence_parallel_enabled: + warnings.warn( + "When using sequence parallelism it is recommended to set the " + "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for " + "maximum speedup") + linear_with_grad_accumulation_and_async_allreduce.warned = True + + if async_grad_allreduce: + warnings.warn( + "When using async grad allreduce it is recommended to set the " + "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for " + "maximum speedup") + linear_with_grad_accumulation_and_async_allreduce.warned = True + + with torch.cuda.amp.autocast(enabled=False): + return LinearWithGradAccumulationAndAsyncCommunication.apply(*args) +linear_with_grad_accumulation_and_async_allreduce.warned = False + +class ColumnParallelLinear(torch.nn.Module): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + + Keyword Arguments + bias: If true, add bias + gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: This was added to enable performance optimations where bias + can be fused with other elementwise operations. we skip + adding bias but instead return it. + async_tensor_model_parallel_allreduce: + params_dtype: + use_cpu_initialization: + gradient_accumulation_fusion: + sequence_parallel_enabled: + """ + + def __init__(self, input_size, output_size, *, + bias=True, gather_output=True, + init_method=init.xavier_normal_, stride=1, + keep_master_weight_for_test=False, + skip_bias_add=False, + async_tensor_model_parallel_allreduce=True, + params_dtype=None, + use_cpu_initialization=False, + perform_initialization=True, + gradient_accumulation_fusion=False, + sequence_parallel_enabled: bool = False, + ): + super(ColumnParallelLinear, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.gather_output = gather_output + # Divide the weight matrix along the last dimension. + world_size = get_tensor_model_parallel_world_size() + self.output_size_per_partition = divide(output_size, world_size) + self.skip_bias_add = skip_bias_add + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + # Parameters. + # Note: torch.nn.functional.linear performs XA^T + b and as a result + # we allocate the transpose. + # Initialize weight. 
+ if use_cpu_initialization: + self.weight = Parameter(torch.empty(self.output_size_per_partition, + self.input_size, + dtype=params_dtype)) + if perform_initialization: + self.master_weight = _initialize_affine_weight_cpu( + self.weight, self.output_size, self.input_size, + self.output_size_per_partition, 0, init_method, + stride=stride, return_master_weight=keep_master_weight_for_test) + else: + self.weight = Parameter(torch.empty( + self.output_size_per_partition, self.input_size, + device=torch.cuda.current_device(), dtype=params_dtype)) + if perform_initialization: + _initialize_affine_weight_gpu(self.weight, init_method, + partition_dim=0, stride=stride) + + if bias: + if use_cpu_initialization: + self.bias = Parameter(torch.empty( + self.output_size_per_partition, dtype=params_dtype)) + else: + self.bias = Parameter(torch.empty( + self.output_size_per_partition, + device=torch.cuda.current_device(), + dtype=params_dtype)) + set_tensor_model_parallel_attributes(self.bias, True, 0, stride) + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + self.async_tensor_model_parallel_allreduce = ( + async_tensor_model_parallel_allreduce and + world_size > 1) + if sequence_parallel_enabled: + if world_size <= 1: + warnings.warn( + f"`sequence_parallel_enabled` is set to `True`, but tensor model parallel size is {world_size}. " + f"Disabling sequence parallel." + ) + sequence_parallel_enabled = False + self.sequence_parallel_enabled = sequence_parallel_enabled + + if gradient_accumulation_fusion: + if not _grad_accum_fusion_available: + raise RuntimeError( + "ColumnParallelLinear was called with gradient_accumulation_fusion set " + "to True but the custom CUDA extension fused_weight_gradient_mlp_cuda " + "module is not found. To use gradient_accumulation_fusion you must " + "install APEX with --cpp_ext and --cuda_ext. For example: " + "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" " + "Note that the extension requires CUDA>=11. Otherwise, you must turn off " + "gradient accumulation fusion." + ) + self.gradient_accumulation_fusion = gradient_accumulation_fusion + + if self.async_tensor_model_parallel_allreduce and self.sequence_parallel_enabled: + raise RuntimeError( + "`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` " + "cannot be enabled at the same time." + ) + + + def forward(self, input_): + """Forward of ColumnParallelLinear + + Args: + input_: 3D tensor whose order of dimension is [sequence, batch, hidden] + + Returns: + - output + - bias + """ + bias = self.bias if not self.skip_bias_add else None + + if self.async_tensor_model_parallel_allreduce or \ + self.sequence_parallel_enabled: + input_parallel = input_ + else: + input_parallel = copy_to_tensor_model_parallel_region(input_) + # Matrix multiply. + output_parallel = linear_with_grad_accumulation_and_async_allreduce( + input=input_parallel, + weight=self.weight, + bias=bias, + gradient_accumulation_fusion=self.gradient_accumulation_fusion, + async_grad_allreduce=self.async_tensor_model_parallel_allreduce, + sequence_parallel_enabled=self.sequence_parallel_enabled, + ) + if self.gather_output: + # All-gather across the partitions. 
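
For orientation, a rough usage sketch of the layer defined above. It assumes torch.distributed and the tensor model parallel state have already been initialized, CUDA is available, and the import path (a `layers` module in this package) is an assumption, since the package layout is not shown in this diff.

```python
import torch
# Import path assumed; ColumnParallelLinear is defined in this patch.
from cacheflow.parallel_utils.tensor_parallel.layers import ColumnParallelLinear

# With a tensor model parallel world size of 2, a 1024 -> 4096 projection keeps a
# (2048, 1024) weight shard on each rank.
layer = ColumnParallelLinear(
    input_size=1024,
    output_size=4096,
    bias=True,
    gather_output=False,        # keep Y_i = X A_i partitioned for the next layer
    params_dtype=torch.half,
)

x = torch.randn(128, 4, 1024, dtype=torch.half, device="cuda")  # [sequence, batch, hidden]
out, out_bias = layer(x)        # out: [128, 4, 4096 // world_size]; out_bias is None here
```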
+ assert not self.sequence_parallel_enabled + output = gather_from_tensor_model_parallel_region(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class RowParallelLinear(torch.nn.Module): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + + Keyword Arguments: + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: This was added to enable performance optimization where bias + can be fused with other elementwise operations. We skip + adding bias but instead return it. + params_dtype: + use_cpu_initialization: + perform_initialization: + gradient_accumulation_fusion: + sequence_parallel_enabled: + """ + + def __init__(self, input_size, output_size, *, + bias=True, input_is_parallel=False, + init_method=init.xavier_normal_, stride=1, + keep_master_weight_for_test=False, + skip_bias_add=False, + params_dtype=None, + use_cpu_initialization=False, + perform_initialization=True, + gradient_accumulation_fusion=False, + sequence_parallel_enabled: bool = False, + ): + super(RowParallelLinear, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.input_is_parallel = input_is_parallel + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + # Divide the weight matrix along the last dimension. + world_size = get_tensor_model_parallel_world_size() + self.input_size_per_partition = divide(input_size, world_size) + self.skip_bias_add = skip_bias_add + self.gradient_accumulation_fusion = gradient_accumulation_fusion + self.sequence_parallel_enabled = sequence_parallel_enabled + if self.sequence_parallel_enabled and not self.input_is_parallel: + raise RuntimeError("To enable `sequence_parallel_enabled`, `input_is_parallel` must be `True`") + + # Parameters. + # Note: torch.nn.functional.linear performs XA^T + b and as a result + # we allocate the transpose. + # Initialize weight. 
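
The ASCII diagram above is easy to verify numerically: splitting A along its first dimension and X along its last dimension turns the product into a sum of per-partition products, which is exactly what the all-reduce in RowParallelLinear computes. A single-process sketch of that identity (plain torch, no distributed setup):

```python
import torch

torch.manual_seed(0)
X = torch.randn(3, 8)
A = torch.randn(8, 5)
world_size = 4

X_parts = torch.chunk(X, world_size, dim=-1)   # X = [X_1, ..., X_p]
A_parts = torch.chunk(A, world_size, dim=0)    # A stacked as [A_1; ...; A_p]

# Summing the per-partition products recovers the full matmul.
partial_sum = sum(Xi @ Ai for Xi, Ai in zip(X_parts, A_parts))
assert torch.allclose(X @ A, partial_sum, atol=1e-5)
```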
+ if use_cpu_initialization: + self.weight = Parameter(torch.empty(self.output_size, + self.input_size_per_partition, + dtype=params_dtype)) + if perform_initialization: + self.master_weight = _initialize_affine_weight_cpu( + self.weight, self.output_size, self.input_size, + self.input_size_per_partition, 1, init_method, + stride=stride, return_master_weight=keep_master_weight_for_test, + params_dtype=params_dtype) + else: + self.weight = Parameter(torch.empty( + self.output_size, self.input_size_per_partition, + device=torch.cuda.current_device(), dtype=params_dtype)) + if perform_initialization: + _initialize_affine_weight_gpu(self.weight, init_method, + partition_dim=1, stride=stride) + if bias: + if use_cpu_initialization: + self.bias = Parameter(torch.empty(self.output_size, + dtype=params_dtype)) + else: + self.bias = Parameter(torch.empty( + self.output_size, device=torch.cuda.current_device(), + dtype=params_dtype)) + setattr(self.bias, 'sequence_parallel', sequence_parallel_enabled) + + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + + + def forward(self, input_): + """Forward of RowParallelLinear + + Args: + input_: 3D tensor whose order of dimension is [sequence, batch, hidden] + + Returns: + - output + - bias + """ + # Set up backprop all-reduce. + if self.input_is_parallel: + input_parallel = input_ + else: + assert not self.sequence_parallel_enabled + input_parallel = scatter_to_tensor_model_parallel_region(input_) + # Matrix multiply. + output_parallel = linear_with_grad_accumulation_and_async_allreduce( + input=input_parallel, + weight=self.weight, + bias=None, + gradient_accumulation_fusion=self.gradient_accumulation_fusion, + async_grad_allreduce=False, + sequence_parallel_enabled=False, + ) + + # All-reduce across all the partitions. + if self.sequence_parallel_enabled: + output_ = reduce_scatter_to_sequence_parallel_region(output_parallel) + else: + output_ = reduce_from_tensor_model_parallel_region(output_parallel) + if not self.skip_bias_add: + output = output_ + self.bias if self.bias is not None else output_ + output_bias = None + else: + output = output_ + output_bias = self.bias + return output, output_bias diff --git a/cacheflow/parallel_utils/tensor_parallel/mappings.py b/cacheflow/parallel_utils/tensor_parallel/mappings.py new file mode 100644 index 000000000000..d9ca3b460d7b --- /dev/null +++ b/cacheflow/parallel_utils/tensor_parallel/mappings.py @@ -0,0 +1,279 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import torch + +from cacheflow.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tensor_model_parallel_group, +) +from .utils import split_tensor_along_last_dim + + +def _reduce(input_): + """All-reduce the input tensor across model parallel group.""" + + # Bypass the function if we are using only 1 GPU. + if get_tensor_model_parallel_world_size()==1: + return input_ + + # All-reduce. + torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) + + return input_ + + +def _split_along_last_dim(input_): + """Split the tensor along its last dimension and keep the + corresponding slice.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # Split along last dimension. 
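
The two layers are usually paired so that only one all-reduce per MLP is needed: a column-parallel up-projection with gather_output=False feeds a row-parallel down-projection with input_is_parallel=True. A hedged sketch of that pattern (import path and model parallel initialization assumed; not part of the patch):

```python
import torch
import torch.nn.functional as F
# Import path assumed; both classes are defined in this patch.
from cacheflow.parallel_utils.tensor_parallel.layers import (
    ColumnParallelLinear, RowParallelLinear)


class ParallelMLP(torch.nn.Module):
    def __init__(self, hidden_size: int, ffn_size: int):
        super().__init__()
        # No communication in forward; the output stays partitioned.
        self.fc1 = ColumnParallelLinear(hidden_size, ffn_size,
                                        gather_output=False)
        # Consumes the partitioned activation; all-reduces the result.
        self.fc2 = RowParallelLinear(ffn_size, hidden_size,
                                     input_is_parallel=True)

    def forward(self, x):
        h, _ = self.fc1(x)      # [s, b, ffn_size // world_size]
        h = F.gelu(h)
        out, _ = self.fc2(h)    # [s, b, hidden_size], replicated on every rank
        return out
```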
+    input_list = split_tensor_along_last_dim(input_, world_size)
+
+    # Note: torch.split does not create contiguous tensors by default.
+    rank = get_tensor_model_parallel_rank()
+    output = input_list[rank].contiguous()
+
+    return output
+
+
+def _split_along_first_dim(input_):
+    """Split the tensor along its first dimension and keep the
+    corresponding slice."""
+
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    # Split along first dimension.
+    dim_size = input_.size()[0]
+    assert dim_size % world_size == 0, \
+        "First dimension of the tensor should be divisible by tensor parallel size"
+    local_dim_size = dim_size // world_size
+    rank = get_tensor_model_parallel_rank()
+    dim_offset = rank * local_dim_size
+
+    output = input_[dim_offset:dim_offset+local_dim_size].contiguous()
+
+    return output
+
+
+def _gather_along_last_dim(input_):
+    """Gather tensors and concatenate along the last dimension."""
+
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    # Size and dimension.
+    last_dim = input_.dim() - 1
+    rank = get_tensor_model_parallel_rank()
+
+    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+    tensor_list[rank] = input_
+    torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
+
+    # Note: torch.cat already creates a contiguous tensor.
+    output = torch.cat(tensor_list, dim=last_dim).contiguous()
+
+    return output
+
+
+def _gather_along_first_dim(input_):
+    """Gather tensors and concatenate along the first dimension."""
+
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    dim_size = list(input_.size())
+    dim_size[0] = dim_size[0] * world_size
+
+    output = torch.empty(dim_size, dtype=input_.dtype,
+                         device=torch.cuda.current_device())
+    torch.distributed._all_gather_base(output, input_.contiguous(),
+                                       group=get_tensor_model_parallel_group())
+
+    return output
+
+def _reduce_scatter_along_first_dim(input_):
+    """Reduce-scatter the input tensor across the model parallel group."""
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
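
A single-process illustration of the round trip these split/gather helpers implement across ranks (plain torch, no process group needed): splitting along the last dimension and concatenating the chunks back recovers the original tensor.

```python
import torch

world_size, rank = 4, 1
x = torch.arange(2 * 8, dtype=torch.float32).reshape(2, 8)

chunks = torch.split(x, x.size(-1) // world_size, dim=-1)
local = chunks[rank].contiguous()      # what this rank keeps: shape (2, 2)

reassembled = torch.cat(chunks, dim=-1)  # what all_gather + cat reconstructs
assert torch.equal(reassembled, x)
```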
+    if world_size == 1:
+        return input_
+
+    dim_size = list(input_.size())
+    assert dim_size[0] % world_size == 0, \
+        "First dimension of the tensor should be divisible by tensor parallel size"
+
+    dim_size[0] = dim_size[0] // world_size
+
+    output = torch.empty(dim_size, dtype=input_.dtype,
+                         device=torch.cuda.current_device())
+    torch.distributed._reduce_scatter_base(output, input_.contiguous(),
+                                           group=get_tensor_model_parallel_group())
+    return output
+
+
+class _CopyToModelParallelRegion(torch.autograd.Function):
+    """Pass the input to the model parallel region."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return input_
+
+    @staticmethod
+    def forward(ctx, input_):
+        return input_
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _reduce(grad_output)
+
+
+class _ReduceFromModelParallelRegion(torch.autograd.Function):
+    """All-reduce the input from the model parallel region."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _reduce(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _reduce(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output
+
+
+class _ScatterToModelParallelRegion(torch.autograd.Function):
+    """Split the input and keep only the chunk corresponding to the rank."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _split_along_last_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _split_along_last_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather_along_last_dim(grad_output)
+
+
+class _GatherFromModelParallelRegion(torch.autograd.Function):
+    """Gather the input from the model parallel region and concatenate."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _gather_along_last_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _gather_along_last_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _split_along_last_dim(grad_output)
+
+
+class _ScatterToSequenceParallelRegion(torch.autograd.Function):
+    """Split the input and keep only the chunk corresponding to the rank."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _split_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _split_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather_along_first_dim(grad_output)
+
+
+class _GatherFromSequenceParallelRegion(torch.autograd.Function):
+    """Gather the input from the sequence parallel region and concatenate."""
+
+    @staticmethod
+    def symbolic(graph, input_, tensor_parallel_output_grad=True):
+        return _gather_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_, tensor_parallel_output_grad=True):
+        ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
+        return _gather_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        tensor_parallel_output_grad = ctx.tensor_parallel_output_grad
+
+        # If the computation graph after the gather operation is in the
+        # tensor parallel mode, the output gradients need to be
+        # reduce-scattered; whereas if the computation is duplicated,
+        # the output gradients only need to be split (scattered).
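
For intuition, the arithmetic that _reduce_scatter_along_first_dim realizes can be reproduced in a single process; this is just the math, not the NCCL implementation, and the shapes are illustrative.

```python
import torch

world_size = 2
# Per-rank inputs of shape [seq, hidden]; seq must be divisible by world_size.
inputs = [torch.ones(4, 3) * (r + 1) for r in range(world_size)]

reduced = torch.stack(inputs).sum(dim=0)                  # the "reduce" part
outputs = list(torch.chunk(reduced, world_size, dim=0))   # the "scatter" part
# Rank 0 gets reduced[0:2], rank 1 gets reduced[2:4]; every element equals 1 + 2 = 3.
assert all(torch.equal(o, torch.full((2, 3), 3.0)) for o in outputs)
```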
+ if tensor_parallel_output_grad: + return _reduce_scatter_along_first_dim(grad_output), None + else: + return _split_along_first_dim(grad_output), None + + +class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): + """Reduce scatter the input from the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return _reduce_scatter_along_first_dim(input_) + + @staticmethod + def forward(ctx, input_): + return _reduce_scatter_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather_along_first_dim(grad_output) + + +# ----------------- +# Helper functions. +# ----------------- + +def copy_to_tensor_model_parallel_region(input_): + return _CopyToModelParallelRegion.apply(input_) + + +def reduce_from_tensor_model_parallel_region(input_): + return _ReduceFromModelParallelRegion.apply(input_) + + +def scatter_to_tensor_model_parallel_region(input_): + return _ScatterToModelParallelRegion.apply(input_) + + +def gather_from_tensor_model_parallel_region(input_): + return _GatherFromModelParallelRegion.apply(input_) + + +def scatter_to_sequence_parallel_region(input_): + return _ScatterToSequenceParallelRegion.apply(input_) + + +def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True): + return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad) + + +def reduce_scatter_to_sequence_parallel_region(input_): + return _ReduceScatterToSequenceParallelRegion.apply(input_) + diff --git a/cacheflow/parallel_utils/tensor_parallel/random.py b/cacheflow/parallel_utils/tensor_parallel/random.py new file mode 100644 index 000000000000..6f36febdeb64 --- /dev/null +++ b/cacheflow/parallel_utils/tensor_parallel/random.py @@ -0,0 +1,253 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +# Parts of the code here are adapted from PyTorch +# repo: https://github.com/pytorch/pytorch + +import contextlib + +import torch +from torch import _C +from torch.cuda import _lazy_call, device as device_ctx_manager +from torch.utils.checkpoint import detach_variable + +from cacheflow.parallel_utils.parallel_state import ( + get_data_parallel_rank, + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) + +from .utils import ( + split_tensor_into_1d_equal_chunks, + gather_split_1d_tensor, +) + +from cacheflow.parallel_utils.utils import safely_set_viewless_tensor_data + +# Default name for the model parallel rng tracker. +_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' + + +def _set_cuda_rng_state(new_state, device=-1): + """Sets the random number generator state of the current GPU. + + Argumentss: + new_state (torch.ByteTensor): The desired state + This function is adapted from PyTorch repo (torch.cuda.set_rng_state) + with a single change: the input state is not cloned. Cloning caused + major performance issues for +4 GPU cases. 
+ """ + if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState): + # older PyTorch + def cb(): + with device_ctx_manager(device): + _C._cuda_setRNGState(new_state) + else: + # newer PyTorch + if device == -1: + device = torch.device('cuda') + elif isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device('cuda', device) + + def cb(): + idx = device.index + if idx is None: + idx = torch.cuda.current_device() + default_generator = torch.cuda.default_generators[idx] + default_generator.set_state(new_state) + + _lazy_call(cb) + + + +class CudaRNGStatesTracker: + """Tracker for the cuda RNG states. + + Using the `add` method, a cuda rng state is initialized based on + the input `seed` and is assigned to `name`. Later, by forking the + rng state, we can perform operations and return to our starting + cuda state. + """ + + def __init__(self): + # Map from a string name to the cuda rng state. + self.states_ = {} + # Seeds are just for book keeping and ensure no seed is set twice. + self.seeds_ = set() + + def reset(self): + """Set to the initial state (no tracker).""" + self.states_ = {} + self.seeds_ = set() + + def get_states(self): + """Get rng states. Copy the dictionary so we have direct + pointers to the states, not just a pointer to the dictionary.""" + states = {} + for name in self.states_: + states[name] = self.states_[name] + return states + + def set_states(self, states): + """Set the rng states. For efficiency purposes, we do not check + the size of seed for compatibility.""" + self.states_ = states + + def add(self, name, seed): + """Track the rng state.""" + # Check seed is not already used. + if seed in self.seeds_: + raise Exception('seed {} already exists'.format(seed)) + self.seeds_.add(seed) + # Check that state is not already defined. + if name in self.states_: + raise Exception('cuda rng state {} already exists'.format(name)) + # Get the current rng state. + orig_rng_state = torch.cuda.get_rng_state() + # Set the new state and store it. + torch.cuda.manual_seed(seed) + self.states_[name] = torch.cuda.get_rng_state() + # Reset rng state to what it was. + _set_cuda_rng_state(orig_rng_state) + + @contextlib.contextmanager + def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): + """Fork the cuda rng state, perform operations, and exit with + the original state.""" + # Check if we have added the state + if name not in self.states_: + raise Exception('cuda rng state {} is not added'.format(name)) + # Store current rng state. + orig_cuda_rng_state = torch.cuda.get_rng_state() + # Set rng state to the desired one + _set_cuda_rng_state(self.states_[name]) + # Do the stuff we wanted to do. + try: + yield + finally: + # Update the current rng state for later use. + self.states_[name] = torch.cuda.get_rng_state() + # And set the state to the original state we started with. + _set_cuda_rng_state(orig_cuda_rng_state) + + +# RNG tracker object. +_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + + +def get_cuda_rng_tracker(): + """Get cuda rng tracker.""" + return _CUDA_RNG_STATE_TRACKER + + +def model_parallel_cuda_manual_seed(seed): + """Initialize model parallel cuda seed. + + This function should be called after the model parallel is + initialized. Also, no torch.cuda.manual_seed should be called + after this function. Basically, this is replacement for that + function. 
+ Two set of RNG states are tracked: + default state: This is for data parallelism and is the same among a + set of model parallel GPUs but different across + different model paralle groups. This is used for + example for dropout in the non-tensor-model-parallel regions. + tensor-model-parallel state: This state is different among a set of model + parallel GPUs, but the same across data parallel + groups. This is used for example for dropout in + model parallel regions. + """ + # 2718 is just for fun and any POSITIVE value will work. + offset = seed + 2718 + tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank() + # Data parallel gets the original seed. + data_parallel_seed = seed + + _CUDA_RNG_STATE_TRACKER.reset() + # Set the default state. + torch.cuda.manual_seed(data_parallel_seed) + # and model parallel state. + _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, + tensor_model_parallel_seed) + + +class CheckpointFunction(torch.autograd.Function): + """This function is adapted from torch.utils.checkpoint with + two main changes: + 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` + 2) the states in the model parallel tracker are also properly + tracked/set/reset. + """ + @staticmethod + def forward(ctx, run_function, distribute_saved_activations, *args): + ctx.run_function = run_function + ctx.distribute_saved_activations \ + = distribute_saved_activations + + # Copy the rng states. + ctx.fwd_cpu_rng_state = torch.get_rng_state() + ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() + ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() + + with torch.no_grad(): + outputs = run_function(*args) + + # Divide hidden states across model parallel group and only keep + # the chunk corresponding to the current rank. + if distribute_saved_activations: + ctx.input_0_shape = args[0].data.shape + safely_set_viewless_tensor_data( + args[0], + split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True)) + + # Store everything. + ctx.save_for_backward(*args) + + return outputs + + @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError("Checkpointing is not compatible with .grad(), " + "please use .backward() if possible") + inputs = ctx.saved_tensors + if ctx.distribute_saved_activations: + safely_set_viewless_tensor_data( + inputs[0], + gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape)) + + # Store the current states. + bwd_cpu_rng_state = torch.get_rng_state() + bwd_cuda_rng_state = torch.cuda.get_rng_state() + bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() + + # Set the states to what it used to be before the forward pass. + torch.set_rng_state(ctx.fwd_cpu_rng_state) + _set_cuda_rng_state(ctx.fwd_cuda_rng_state) + get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) + + # Compute the forward pass. + detached_inputs = detach_variable(inputs) + with torch.enable_grad(): + outputs = ctx.run_function(*detached_inputs) + + # Set the states back to what it was at the start of this function. 
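
Worked example of the seed derivation in model_parallel_cuda_manual_seed above (values are illustrative): every tensor parallel rank gets a distinct CUDA seed for the model parallel RNG stream, while the default stream keeps the base seed on all ranks.

```python
# Mirrors the arithmetic above; a base seed of 0 and 4 tensor parallel ranks are assumed.
seed = 0
offset = seed + 2718
tensor_model_parallel_seeds = [offset + rank for rank in range(4)]   # [2718, 2719, 2720, 2721]
data_parallel_seed = seed                                            # 0 on every rank
```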
+ torch.set_rng_state(bwd_cpu_rng_state) + _set_cuda_rng_state(bwd_cuda_rng_state) + get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) + + if isinstance(outputs, torch.Tensor): + outputs = (outputs,) + torch.autograd.backward(outputs, args) + grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp + for inp in detached_inputs) + return (None, None) + grads + + +def checkpoint(function, distribute_saved_activations, *args): + """Checkpoint a model or part of the model. + This has been directly copied from torch.utils.checkpoint.""" + return CheckpointFunction.apply(function, + distribute_saved_activations, *args) diff --git a/cacheflow/parallel_utils/tensor_parallel/utils.py b/cacheflow/parallel_utils/tensor_parallel/utils.py new file mode 100644 index 000000000000..dd15faa9a65b --- /dev/null +++ b/cacheflow/parallel_utils/tensor_parallel/utils.py @@ -0,0 +1,108 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import torch +from typing import List, Sequence + +from cacheflow.parallel_utils.utils import divide +from cacheflow.parallel_utils import parallel_state + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """ Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + +def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): + """ Break a tensor into equal 1D chunks across tensor parallel ranks. + + Returns a Tensor or View with this rank's portion of the data. + + Arguments: + tensor: The tensor to split + + Keyword Arguments: + new_buffer (bool): If True, returns a new Tensor. + If False, returns a view into the existing Tensor. + Default is False + + """ + partition_size = torch.numel(tensor) // \ + parallel_state.get_tensor_model_parallel_world_size() + start_index = partition_size * parallel_state.get_tensor_model_parallel_rank() + end_index = start_index + partition_size + if new_buffer: + data = torch.empty(partition_size, dtype=tensor.dtype, + device=torch.cuda.current_device(), + requires_grad=False) + data.copy_(tensor.view(-1)[start_index:end_index]) + else: + data = tensor.view(-1)[start_index:end_index] + return data + + +def gather_split_1d_tensor(tensor): + """ Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor + model parallel ranks. + + Returns a new Tensor with the gathered data. + + Arguments: + tensor: A Tensor or view of this rank's portion of the data. + """ + numel_gathered = torch.numel(tensor) * \ + parallel_state.get_tensor_model_parallel_world_size() + gathered = torch.empty(numel_gathered, dtype=tensor.dtype, + device=torch.cuda.current_device(), + requires_grad=False) + # TODO: This API is experimental in pytorch (as of Feb 2022) and + # this might break in future pytorch releases. We chose this API + # as opposed to torch.distributed.all_gather for efficiency reasons. 
+ # This API calls directly NCCL all-gather versus the former does + # internal copies and can potentially cause slow down. + torch.distributed._all_gather_base(gathered, tensor, + group=parallel_state.get_tensor_model_parallel_group()) + return gathered + + +class VocabUtility: + """ Split the vocabulary into `world_size` chunks and return the first + and last index of the vocabulary belonging to the `rank` + partition: Note that indices in [fist, last) + + """ + + @staticmethod + def vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size: int, rank, world_size: int + ) -> Sequence[int]: + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f, index_l + + @staticmethod + def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]: + per_partition_vocab_size = divide(global_vocab_size, world_size) + return VocabUtility.vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size, rank, world_size + ) diff --git a/cacheflow/parallel_utils/utils.py b/cacheflow/parallel_utils/utils.py new file mode 100644 index 000000000000..4b2b76130246 --- /dev/null +++ b/cacheflow/parallel_utils/utils.py @@ -0,0 +1,120 @@ +"""Utility functions used throughout Megatron core""" +from functools import reduce +import operator + +import torch + +from cacheflow.parallel_utils import parallel_state + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, "{} is not divisible by {}".format( + numerator, denominator + ) + + +def divide(numerator, denominator): + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +class GlobalMemoryBuffer: + """Global buffer to avoid dynamic memory allocations. + Caller should ensure that buffers of the same name + are not used concurrently.""" + + def __init__(self): + self.buffer = {} + + def get_tensor(self, tensor_shape, dtype, name): + required_len = reduce(operator.mul, tensor_shape, 1) + if self.buffer.get((name, dtype), None) is None or \ + self.buffer[(name, dtype)].numel() < required_len: + self.buffer[(name, dtype)] = \ + torch.empty(required_len, + dtype=dtype, + device=torch.cuda.current_device(), + requires_grad=False) + + return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) + +def _kernel_make_viewless_tensor(inp, requires_grad): + '''Make a viewless tensor. + + View tensors have the undesirable side-affect of retaining a reference + to the originally-viewed tensor, even after manually setting the '.data' + field. This method creates a new tensor that links to the old tensor's + data, without linking the viewed tensor, referenced via the '._base' + field. + ''' + out = torch.empty( + (1,), + dtype = inp.dtype, + device = inp.device, + requires_grad = requires_grad, + ) + out.data = inp.data + return out + +class MakeViewlessTensor(torch.autograd.Function): + ''' + Autograd function to make a viewless tensor. + + This function should be used in cases where the computation graph needs + to be propagated, but we only want a viewless tensor (e.g., + ParallelTransformer's hidden_states). Call this function by passing + 'keep_graph = True' to 'make_viewless_tensor()'. 
+ ''' + @staticmethod + def forward(ctx, inp, requires_grad): + return _kernel_make_viewless_tensor(inp, requires_grad) + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + +def make_viewless_tensor(inp, requires_grad, keep_graph): + ''' + Entry-point for creating viewless tensors. + + This method should be used, rather than calling 'MakeViewlessTensor' + or '_kernel_make_viewless_tensor' directly. This method acts as a + switch for determining if an autograd function or a regular method + should be used to create the tensor. + ''' + + # return tensor as-is, if not a 'view' + if inp._base is None: + return inp + + # create viewless tensor + if keep_graph: + return MakeViewlessTensor.apply(inp, requires_grad) + else: + return _kernel_make_viewless_tensor(inp, requires_grad) + +def assert_viewless_tensor(tensor, extra_msg = None): + '''Assert that a tensor is not a view (i.e., its '._base' field is + not set).''' + if isinstance(tensor, list): + [ assert_viewless_tensor(t) for t in tensor ] + return tensor + if not isinstance(tensor, torch.Tensor): + return tensor + assert tensor._base is None, ( + "Ensure tensor._base is None before setting tensor.data or storing " + "tensor to memory buffer. Otherwise, a memory leak will occur (and " + "likely accumulate over iterations). %s" + ) % extra_msg + return tensor + +def safely_set_viewless_tensor_data(tensor, new_data_tensor): + '''Safely set tensor's '.data' field. + + Check first that the tensor is viewless (i.e., '._base' not set). If not, + raise an exception. + ''' + assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape)) + tensor.data = new_data_tensor diff --git a/cacheflow/sequence.py b/cacheflow/sequence.py index fb9e9daba01e..471052fbd5a9 100644 --- a/cacheflow/sequence.py +++ b/cacheflow/sequence.py @@ -158,3 +158,9 @@ def __repr__(self) -> str: f'parent_seq_id={self.parent_seq_id}, ' f'output_token={self.output_token}), ' f'logprobs={self.logprobs}') + + def __eq__(self, other: 'SequenceOutputs') -> bool: + return (self.seq_id == other.seq_id and + self.parent_seq_id == other.parent_seq_id and + self.output_token == other.output_token and + self.logprobs == other.logprobs) diff --git a/cacheflow/utils.py b/cacheflow/utils.py index fff6b7a86871..db8eb8aaba4c 100644 --- a/cacheflow/utils.py +++ b/cacheflow/utils.py @@ -1,4 +1,11 @@ import enum +import random + +import numpy as np +import torch + +from cacheflow.parallel_utils.parallel_state import model_parallel_is_initialized +from cacheflow.parallel_utils.tensor_parallel import model_parallel_cuda_manual_seed class Device(enum.Enum): @@ -18,3 +25,13 @@ def __next__(self) -> int: def reset(self) -> None: self.counter = 0 + +def set_random_seed(seed: int): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + if model_parallel_is_initialized(): + model_parallel_cuda_manual_seed(seed) diff --git a/cacheflow/worker/cache_engine.py b/cacheflow/worker/cache_engine.py index cf509abb647b..164b2a2a60fd 100644 --- a/cacheflow/worker/cache_engine.py +++ b/cacheflow/worker/cache_engine.py @@ -11,7 +11,6 @@ class CacheEngine: def __init__( self, worker_id: int, - gpu_id: int, num_layers: int, num_heads: int, head_size: int, @@ -25,7 +24,6 @@ def __init__( f'head_size ({head_size}) must be a multiple of 16.') self.worker_id = worker_id - self.gpu_id = 
gpu_id self.num_layers = num_layers self.num_heads = num_heads self.head_size = head_size @@ -39,8 +37,8 @@ def __init__( self.cpu_cache = self.allocate_cpu_cache() # Initialize the stream for caching operations. - self.cache_stream = torch.cuda.Stream(device=gpu_id) - assert self.cache_stream != torch.cuda.current_stream(device=gpu_id) + self.cache_stream = torch.cuda.Stream() + assert self.cache_stream != torch.cuda.current_stream() # Initialize the events for stream synchronization. self.events = [torch.cuda.Event() for _ in range(num_layers)] @@ -69,12 +67,12 @@ def allocate_gpu_cache(self) -> List[KVCache]: key_blocks = torch.empty( size=(self.num_gpu_blocks, *key_block_shape), dtype=self.dtype, - device=self.gpu_id, + device="cuda", ) value_blocks = torch.empty( size=(self.num_gpu_blocks, *value_block_shape), dtype=self.dtype, - device=self.gpu_id, + device="cuda", ) gpu_cache.append((key_blocks, value_blocks)) return gpu_cache diff --git a/cacheflow/worker/controller.py b/cacheflow/worker/controller.py index 31bd03e0c20b..bb357b132665 100644 --- a/cacheflow/worker/controller.py +++ b/cacheflow/worker/controller.py @@ -1,45 +1,62 @@ -from typing import Dict, List, Union +from typing import Dict, List, Union, Tuple + +import ray from cacheflow.master.scheduler import Scheduler from cacheflow.sequence import SequenceGroupInputs from cacheflow.worker.worker import Worker +DeviceID = Tuple[int, str, int] # rank, node resource (node IP), device id + + class Controller: def __init__( self, - node_id: int, - num_workers: int, + stage_id: int, + stage_devices: List[DeviceID], + world_size: int, + tensor_parallel_size: int, + pipeline_parallel_size: int, + distributed_init_method: str, model_name: str, block_size: int, num_gpu_blocks: int, num_cpu_blocks: int, dtype: str, seed: int, + model_path: str, ) -> None: - self.node_id = node_id - self.num_workers = num_workers + self.stage_id = stage_id + self.stage_devices = stage_devices self.model_name = model_name self.block_size = block_size self.num_gpu_blocks = num_gpu_blocks self.num_cpu_blocks = num_cpu_blocks # Which pipeline stage is this node assigned to? - self.is_first_stage = node_id == 0 + self.is_first_stage = stage_id == 0 self.is_last_stage = False self.workers: List[Worker] = [] - for i in range(num_workers): - worker = Worker( - worker_id=node_id + i, - gpu_id=i, + for rank, node_resource, device_id in stage_devices: + worker_cls = ray.remote(num_cpus=0, + num_gpus=1, + resources={node_resource: 1e-5})(Worker) + worker = worker_cls.remote( model_name=model_name, block_size=block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, dtype=dtype, seed=seed, + distributed_init_method=distributed_init_method, + rank=rank, + world_size=world_size, + tensor_parallel_size=tensor_parallel_size, + pipeline_parallel_size=pipeline_parallel_size, + model_path=model_path, ) self.workers.append(worker) @@ -57,15 +74,21 @@ def execute_stage( blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], ) -> None: - # FIXME: Support tensor parallelism. - assert len(self.workers) == 1 - worker = self.workers[0] - output = worker.execute_stage( - input_seq_groups, - blocks_to_swap_in, - blocks_to_swap_out, - blocks_to_copy, - ) + futures = [] + for worker in self.workers: + future = worker.execute_stage.remote( + input_seq_groups, + blocks_to_swap_in, + blocks_to_swap_out, + blocks_to_copy, + ) + futures.append(future) + + all_outputs = ray.get(futures) + # Make sure all workers have the same results. 
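
The actor-placement pattern above (a zero-CPU, single-GPU Ray actor pinned to a node via its node resource, then a fan-out of remote calls collected with ray.get) can be exercised on its own. Everything below is a hypothetical sketch with placeholder resource names and a toy worker class, not part of the patch, and it requires a running Ray cluster with a GPU on the named node.

```python
import ray


class EchoWorker:
    """Plain class; it is turned into a pinned actor below, like Worker above."""
    def execute(self, x):
        return x


ray.init(address="auto")
node_resource = "node:10.0.0.1"   # placeholder; real names come from ray.nodes()
worker_cls = ray.remote(num_cpus=0, num_gpus=1,
                        resources={node_resource: 1e-5})(EchoWorker)
worker = worker_cls.remote()

futures = [worker.execute.remote(i) for i in range(3)]
print(ray.get(futures))           # [0, 1, 2]
```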
+ output = all_outputs[0] + for other_output in all_outputs[1:]: + assert output == other_output if self.is_last_stage: self.next_node.post_step(output) diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py index 3f4b451934c9..f309cf12f672 100644 --- a/cacheflow/worker/worker.py +++ b/cacheflow/worker/worker.py @@ -3,49 +3,58 @@ import torch from cacheflow.models import get_model -from cacheflow.models import set_seed from cacheflow.models import InputMetadata from cacheflow.sampling_params import SamplingParams from cacheflow.sequence import SequenceGroupInputs from cacheflow.sequence import SequenceOutputs from cacheflow.worker.cache_engine import CacheEngine +from cacheflow.parallel_utils.parallel_state import ( + initialize_model_parallel, get_tensor_model_parallel_world_size) +from cacheflow.utils import set_random_seed class Worker: def __init__( self, - worker_id: int, - gpu_id: int, model_name: str, block_size: int, num_gpu_blocks: int, num_cpu_blocks: int, dtype: str, seed: int, + distributed_init_method: str, + rank: int, + world_size: int, + model_path: str, + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, ) -> None: - self.worker_id = worker_id - self.gpu_id = gpu_id + self.init_distributed_environment(distributed_init_method, + rank, + world_size, + tensor_parallel_size, + pipeline_parallel_size) + self.worker_id = rank self.block_size = block_size - - self.device = torch.device('cuda', index=gpu_id) + set_random_seed(seed) # Initialize the model. - # FIXME(woosuk): This is a hack. - self.model = get_model(model_name, dtype=dtype).to(device=self.device) + self.model, self.dtype = get_model(model_name, dtype=dtype, path=model_path) + self.model = self.model.cuda() + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) self.num_layers = self.model.config.num_hidden_layers - self.num_heads = self.model.config.num_attention_heads - self.head_size = self.model.config.hidden_size // self.num_heads - self.dtype = self.model.dtype + assert self.model.config.num_attention_heads % tensor_model_parallel_world_size == 0 + self.num_heads = self.model.config.num_attention_heads // tensor_model_parallel_world_size + self.head_size = self.model.config.hidden_size // (self.num_heads * tensor_model_parallel_world_size) - # Set the seed. - # We set the seed after initializing the model to ensure that + # We reset the seed after initializing the model to ensure that # the random state is not affected by the model initialization. - set_seed(seed) + set_random_seed(seed) self.cache_engine = CacheEngine( - worker_id=worker_id, - gpu_id=gpu_id, + worker_id=self.worker_id, num_layers=self.num_layers, num_heads=self.num_heads, head_size=self.head_size, @@ -57,6 +66,26 @@ def __init__( self.cache_events = self.cache_engine.events self.gpu_cache = self.cache_engine.gpu_cache + + def init_distributed_environment(self, + distributed_init_method: str, + rank: int, + world_size: int, + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1) -> None: + """Initialize the distributed environment.""" + torch.distributed.init_process_group( + backend='nccl', + init_method=distributed_init_method, + world_size=world_size, + rank=rank, + ) + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cuda()) + initialize_model_parallel(tensor_parallel_size, + pipeline_parallel_size) + + def prepare_inputs( self, input_seq_groups: List[SequenceGroupInputs], @@ -142,18 +171,18 @@ def prepare_inputs( # Convert to tensors. 
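
The head-partitioning arithmetic in Worker.__init__ above is easy to check with concrete numbers; an OPT-13B-like configuration is assumed here purely for illustration.

```python
# hidden_size and num_attention_heads roughly match OPT-13B; values are illustrative.
tensor_model_parallel_world_size = 4
num_attention_heads = 40
hidden_size = 5120

assert num_attention_heads % tensor_model_parallel_world_size == 0
num_heads = num_attention_heads // tensor_model_parallel_world_size          # 10 heads per GPU
head_size = hidden_size // (num_heads * tensor_model_parallel_world_size)    # 5120 // 40 = 128
assert num_heads * tensor_model_parallel_world_size * head_size == hidden_size
```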
tokens_tensor = torch.tensor( - input_tokens, dtype=torch.long, device=self.device) + input_tokens, dtype=torch.long, device='cuda') positions_tensor = torch.tensor( - input_positions, dtype=torch.long, device=self.device) + input_positions, dtype=torch.long, device='cuda') slot_mapping_tensor = torch.tensor( - slot_mapping, dtype=torch.int, device=self.device) + slot_mapping, dtype=torch.int, device='cuda') context_lens_tensor = torch.tensor( - context_lens, dtype=torch.int, device=self.device) + context_lens, dtype=torch.int, device='cuda') padded_block_tables = [ _pad_to_max(block_table, max_num_blocks_per_seq) for block_table in generation_block_tables] block_tables_tensor = torch.tensor( - padded_block_tables, dtype=torch.int, device=self.device) + padded_block_tables, dtype=torch.int, device='cuda') input_metadata = InputMetadata( seq_groups=seq_groups, diff --git a/server.py b/server.py index ccd6f8f6e3f8..5838f439f53e 100644 --- a/server.py +++ b/server.py @@ -1,30 +1,99 @@ import argparse -from typing import List +import random +from typing import List, Tuple, Dict + +import ray from cacheflow.master.frontend import Frontend from cacheflow.master.scheduler import Scheduler from cacheflow.models import get_memory_analyzer -from cacheflow.worker.controller import Controller - -parser = argparse.ArgumentParser(description='CacheFlow server') -parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name') -parser.add_argument('--num-nodes', type=int, default=1, help='number of nodes') -parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node') -parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size') -# NOTE(woosuk): If FlashAttention is used, the float data type is not supported. -parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type') -# TODO(woosuk): Support fine-grained seeds (e.g., seed per request). -parser.add_argument('--seed', type=int, default=0, help='random seed') -parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU') -parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens') -args = parser.parse_args() - - -def main(): +from cacheflow.worker.controller import Controller, DeviceID + + +def initialize_ray_cluster( + address: str = 'auto', + pipeline_parallel_size: int = 1, + tensor_parallel_size: int = 1, +) -> Tuple[int, int, str, List[List[DeviceID]]]: + # Connect to a ray cluster. + ray.init(address=address) + + # Assume we have a uniform cluster that each node has the same number of + # GPUs for now. 
+ valid_node_resources = [] + num_devices_per_node = None + for node in ray.nodes(): + if (not node['Alive']) or node['Resources']['GPU'] <= 0: + continue + if num_devices_per_node is None: + num_devices_per_node = node['Resources']['GPU'] + else: + assert num_devices_per_node == node['Resources']['GPU'], ( + "The number of GPUs per node is not uniform.") + for key in node['Resources']: + if key.startswith('node:'): + valid_node_resources.append(key) + + num_nodes = len(valid_node_resources) + + assert (pipeline_parallel_size * tensor_parallel_size + <= num_nodes * num_devices_per_node), ( + "The number of required GPUs exceeds the total number of " + "available GPUs.") + if tensor_parallel_size >= num_devices_per_node: + assert tensor_parallel_size % num_devices_per_node == 0, ( + "The number of tensor parallelism is not divisible by the " + "number of GPUs per node.") + else: + assert num_devices_per_node % tensor_parallel_size == 0, ( + "The number of GPUs per node is not divisible by the number " + "of tensor parallelism.") + + # Assign GPUs to pipeline stages. + rank = 0 + current_node_id = 0 + current_device_id = 0 + distributed_init_method = None + all_stage_devices = [] + + for i in range(pipeline_parallel_size): + stage_devices = [] + for j in range(tensor_parallel_size): + node_resource = valid_node_resources[current_node_id] + stage_devices.append((rank, node_resource, current_device_id)) + if distributed_init_method is None: + ip = node_resource.split("node:")[-1] + port = random.randint(10000, 20000) + distributed_init_method = f"tcp://{ip}:{port}" + rank += 1 + current_device_id += 1 + if current_device_id >= num_devices_per_node: + current_node_id += 1 + current_device_id = 0 + all_stage_devices.append(stage_devices) + + return (num_nodes, num_devices_per_node, distributed_init_method, + all_stage_devices) + + +def main(args: argparse.Namespace): + # TODO(zhuohan): Support pipeline parallelism. + assert args.pipeline_parallel_size == 1, ( + 'Pipeline parallelism is not supported yet.') + + (num_nodes, num_devices_per_node, distributed_init_method, + all_stage_devices) = ( + initialize_ray_cluster( + pipeline_parallel_size=args.pipeline_parallel_size, + tensor_parallel_size=args.tensor_parallel_size)) + + world_size = args.pipeline_parallel_size * args.tensor_parallel_size + memory_analyzer = get_memory_analyzer( model_name=args.model, block_size=args.block_size, dtype=args.dtype, + tensor_parallel_size=args.tensor_parallel_size, ) num_gpu_blocks = memory_analyzer.get_max_num_gpu_blocks( max_num_batched_tokens=args.max_batch_size) @@ -32,18 +101,23 @@ def main(): swap_space=args.swap_space) print(f'# GPU blocks: {num_gpu_blocks}, # CPU blocks: {num_cpu_blocks}') - # Create a controller for each node. + # Create a controller for each pipeline stage. 
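
To see what the assignment loop above produces, here is a standalone re-run of the same logic for a hypothetical cluster of 2 nodes with 2 GPUs each and pipeline_parallel_size = tensor_parallel_size = 2; node names are placeholders.

```python
# Standalone sketch of the rank/device assignment; not part of the patch.
nodes = ["node:10.0.0.1", "node:10.0.0.2"]
num_devices_per_node = 2
pipeline_parallel_size, tensor_parallel_size = 2, 2

rank, node_id, device_id = 0, 0, 0
all_stage_devices = []
for _ in range(pipeline_parallel_size):
    stage = []
    for _ in range(tensor_parallel_size):
        stage.append((rank, nodes[node_id], device_id))
        rank += 1
        device_id += 1
        if device_id >= num_devices_per_node:
            node_id += 1
            device_id = 0
    all_stage_devices.append(stage)

# [[(0, 'node:10.0.0.1', 0), (1, 'node:10.0.0.1', 1)],
#  [(2, 'node:10.0.0.2', 0), (3, 'node:10.0.0.2', 1)]]
print(all_stage_devices)
```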
controllers: List[Controller] = [] - for i in range(args.num_nodes): + for i in range(args.pipeline_parallel_size): controller = Controller( - node_id=i, - num_workers=args.num_workers, + stage_id=i, + stage_devices=all_stage_devices[i], + world_size=world_size, + pipeline_parallel_size=args.pipeline_parallel_size, + tensor_parallel_size=args.tensor_parallel_size, + distributed_init_method=distributed_init_method, model_name=args.model, block_size=args.block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, dtype=args.dtype, seed=args.seed, + model_path=args.model_path, ) controllers.append(controller) @@ -83,4 +157,22 @@ def main(): if __name__ == '__main__': - main() + parser = argparse.ArgumentParser(description='CacheFlow server') + # Model arguments + parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name') + parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights', + help='model path to download and load the weights') + # Parallel arguments + parser.add_argument('--pipeline-parallel-size', type=int, default=1, help='number of pipeline stages') + parser.add_argument('--tensor-parallel-size', type=int, default=1, help='number of tensor parallel replicas') + # KV cache arguments + parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size') + # NOTE(woosuk): If FlashAttention is used, the float data type is not supported. + parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type') + # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). + parser.add_argument('--seed', type=int, default=0, help='random seed') + parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU') + parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens') + args = parser.parse_args() + + main(args)
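
To illustrate how the new flags compose, here is a hypothetical programmatic invocation of the parser above; the flag values are placeholders and an already-running Ray cluster is still required for main to succeed.

```python
# Equivalent to passing the same flags on the command line.
args = parser.parse_args([
    '--model', 'facebook/opt-13b',
    '--tensor-parallel-size', '2',
    '--block-size', '8',
    '--dtype', 'half',
    '--seed', '0',
])
assert args.pipeline_parallel_size == 1   # required by the assert in main()
main(args)                                # needs a live Ray cluster to connect to
```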