diff --git a/docs/CN/source/getting_started/quickstart.rst b/docs/CN/source/getting_started/quickstart.rst index eaf664748..e7d303499 100755 --- a/docs/CN/source/getting_started/quickstart.rst +++ b/docs/CN/source/getting_started/quickstart.rst @@ -56,6 +56,23 @@ .. note:: 上面代码中的 ``--model_dir`` 参数需要修改为你本机实际的模型路径。 +单机H200部署 DeepSeek-R1 模型,启动命令如下: + +.. code-block:: console + + $ LOADWORKER=8 python -m lightllm.server.api_server --model_dir ~/models/DeepSeek-R1 --tp 8 --graph_max_batch_size 100 + +.. note:: + LOADWORKER 指定了模型加载使用的线程数,可以提高模型加载的速度。--graph_max_batch_size 指定了要捕获的cudagraph的数量,将捕获从1到100的batch size的图。 + +双机H100部署 DeepSeek-R1 模型,启动命令如下: + +.. code-block:: console + + $ # Node 0 + $ LOADWORKER=8 python -m lightllm.server.api_server --model_dir ~/models/DeepSeek-R1 --tp 16 --graph_max_batch_size 100 --nccl_host master_addr --nnodes 2 --node_rank 0 + $ # Node 1 + $ LOADWORKER=8 python -m lightllm.server.api_server --model_dir ~/models/DeepSeek-R1 --tp 16 --graph_max_batch_size 100 --nccl_host master_addr --nnodes 2 --node_rank 1 3. (可选)测试模型服务 ------------------------- @@ -75,3 +91,10 @@ $ }' +对于DeepSeek-R1模型,可以用如下脚本进行测试: + +.. code-block:: console + + $ cd test + $ python benchmark_client.py --num_clients 100 --input_num 2000 --tokenizer_path /nvme/DeepSeek-R1/ --url http://127.0.0.1:8000/generate_stream + diff --git a/docs/EN/source/getting_started/quickstart.rst b/docs/EN/source/getting_started/quickstart.rst index 7ff2b95c1..f9563a8ab 100755 --- a/docs/EN/source/getting_started/quickstart.rst +++ b/docs/EN/source/getting_started/quickstart.rst @@ -53,7 +53,7 @@ After downloading the Llama-2-7b-chat model, use the following command in the te .. note:: The ``--model_dir`` parameter in the above command should be changed to the actual path of your model on your machine. -For the DeepSeek-R1 model on H200, it can be launched with the following command: +For the DeepSeek-R1 model on a single H200, it can be launched with the following command: .. code-block:: console @@ -62,6 +62,15 @@ For the DeepSeek-R1 model on H200, it can be launched with the following command .. note:: LOADWORKER specifies the thread for model loading, which can enhance the speed of model loading. The --graph_max_batch_size parameter specifies the number of cudagraphs to be captured, which will capture graphs for batch sizes ranging from 1 to 100. +For the DeepSeek-R1 model on two H100 nodes, it can be launched with the following command: + +.. code-block:: console + + $ # Node 0 + $ LOADWORKER=8 python -m lightllm.server.api_server --model_dir ~/models/DeepSeek-R1 --tp 16 --graph_max_batch_size 100 --nccl_host master_addr --nnodes 2 --node_rank 0 + $ # Node 1 + $ LOADWORKER=8 python -m lightllm.server.api_server --model_dir ~/models/DeepSeek-R1 --tp 16 --graph_max_batch_size 100 --nccl_host master_addr --nnodes 2 --node_rank 1 + 3. 
(Optional) Test the Model Service -------------------------------------- diff --git a/lightllm/common/basemodel/layer_weights/base_layer_weight.py b/lightllm/common/basemodel/layer_weights/base_layer_weight.py index 94bb0a8a7..749eac619 100644 --- a/lightllm/common/basemodel/layer_weights/base_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/base_layer_weight.py @@ -2,6 +2,7 @@ import numpy as np import threading from lightllm.common.basemodel.layer_weights.meta_weights import BaseWeight +from lightllm.utils.dist_utils import get_current_device_id class BaseLayerWeight: @@ -37,4 +38,4 @@ def _cuda(self, cpu_tensor): if self.tp_rank_ is None: return cpu_tensor.contiguous().to(self.data_type_).cuda() else: - return cpu_tensor.contiguous().to(self.data_type_).cuda(self.tp_rank_) + return cpu_tensor.contiguous().to(self.data_type_).cuda(get_current_device_id()) diff --git a/lightllm/common/basemodel/layer_weights/hf_load_utils.py b/lightllm/common/basemodel/layer_weights/hf_load_utils.py index b96df55e0..ae44f746c 100755 --- a/lightllm/common/basemodel/layer_weights/hf_load_utils.py +++ b/lightllm/common/basemodel/layer_weights/hf_load_utils.py @@ -3,14 +3,14 @@ import gc from safetensors import safe_open import lightllm.utils.petrel_helper as utils +from lightllm.utils.dist_utils import get_current_device_id def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_layer_list=None, weight_dir=None): # fix bug for 多线程加载的时候,每个线程内部的cuda device 会切回 0, 修改后来保证不会出现bug import torch.distributed as dist - tp_rank = dist.get_rank() - torch.cuda.set_device(tp_rank) + torch.cuda.set_device(get_current_device_id()) if use_safetensors: weights = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py index 762617b71..f284a7774 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py @@ -1,7 +1,6 @@ import torch from abc import ABC, abstractmethod -from lightllm.utils.dist_utils import get_world_size, get_rank -from lightllm.utils.device_utils import get_current_device_id +from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_device_id class BaseWeight(ABC): @@ -19,8 +18,8 @@ def verify_load(self): class BaseWeightTpl(BaseWeight): def __init__(self): - self.world_size_ = get_world_size() - self.tp_rank_ = get_rank() + self.world_size_ = get_global_world_size() + self.tp_rank_ = get_global_rank() self.device_id_ = get_current_device_id() def load_hf_weights(self, weights): diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py index 2a58e5a80..94c16a6d0 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py @@ -5,9 +5,8 @@ from .base_weight import BaseWeight from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod from lightllm.common.quantization.quantize_method import QuantizationMethod -from lightllm.utils.dist_utils import get_world_size, get_rank +from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_device_id from lightllm.common.vllm_kernel import _custom_ops as ops -from lightllm.utils.device_utils import get_current_device_id 
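Note on the device-id hunks above: every `.cuda(self.tp_rank_)` call becomes `.cuda(get_current_device_id())` because, once a tensor-parallel group spans more than one machine, the global rank no longer equals the local CUDA index. A minimal sketch of the mapping this relies on, using hypothetical helper names rather than the actual LightLLM API:

```python
import torch

def local_device_id(global_rank: int, node_world_size: int) -> int:
    # Illustrative assumption: on a 2-node, 16-GPU run, global rank 9 lives on cuda:1 of node 1.
    return global_rank % node_world_size

def to_current_device(cpu_tensor: torch.Tensor, global_rank: int, node_world_size: int) -> torch.Tensor:
    device_id = local_device_id(global_rank, node_world_size)
    torch.cuda.set_device(device_id)  # keep loader threads pinned to the GPU owned by this rank
    return cpu_tensor.contiguous().cuda(device_id)
```

This is also why `hf_load_utils.load_func` now calls `torch.cuda.set_device(get_current_device_id())` instead of using `dist.get_rank()`: the original comment notes that loader threads fall back to device 0, and on a non-zero node the global rank is not even a valid local device index.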
class FusedMoeWeight(BaseWeight): @@ -39,7 +38,7 @@ def __init__( self.n_routed_experts = n_routed_experts self.split_inter_size = split_inter_size self.data_type_ = data_type - self.tp_rank_ = get_rank() + self.tp_rank_ = get_global_rank() self.experts_up_projs = [None] * self.n_routed_experts self.experts_gate_projs = [None] * self.n_routed_experts self.experts_up_proj_scales = [None] * self.n_routed_experts @@ -159,7 +158,7 @@ def _fuse_weight_scale(self): delattr(self, "experts_gate_proj_scales") def _load_hf_weights_etp(self, weights): - world_size_ = get_world_size() + world_size_ = get_global_world_size() assert self.n_routed_experts % world_size_ == 0 n_expert_ep = self.n_routed_experts // world_size_ diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py index ee7c0cf33..2bd5ba9ab 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py @@ -4,6 +4,7 @@ from typing import Optional, Tuple, List, Dict, Any from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.common.quantization.quantize_method import QuantizationMethod +from lightllm.utils.dist_utils import get_current_device_id def generate_scale_name(name, weight_scale_suffix, act_scale_suffix): @@ -73,20 +74,17 @@ def _post_load_weights(self) -> None: and (not self.static_activation or self.input_scale is not None) ): if self.weight_scale.ndim > 1: - # 让 k dim 更连续,大多数split k 算法的算子可能能更快 - self.weight_scale = self.weight_scale.cuda(self.device_id_).transpose(0, 1) + self.weight_scale = self.weight_scale.transpose(0, 1).cuda(get_current_device_id()) self.weight = [ - # 让 k dim 更连续,大多数split k 算法的算子可能能更快 - self.weight.cuda(self.device_id_).transpose(0, 1), + self.weight.cuda(get_current_device_id()).transpose(0, 1), self.weight_scale, self.input_scale, ] else: - self.weight = self.quant_method.quantize(self.weight.to(self.data_type_).cuda(self.device_id_)) + self.weight = self.quant_method.quantize(self.weight.to(self.data_type_).cuda(get_current_device_id())) return - # 让 k dim 更连续,大多数split k 算法的算子可能能更快 - self.weight = self.weight.to(self.data_type_).cuda(self.device_id_).transpose(0, 1) + self.weight = self.weight.to(self.data_type_).cuda(get_current_device_id()).transpose(0, 1) class MMWeight(MMWeightTpl): @@ -133,7 +131,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: self.weight = weight[self.start : self.end] if self.bias_name in weights: bias = weights[self.bias_name].to(self.data_type_)[self.start : self.end] - self.bias = bias.cuda(self.device_id_) + self.bias = bias.cuda(get_current_device_id()) if self.weight_scale_name is not None and self.weight_scale_name in weights: block_size = 1 @@ -154,7 +152,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: if self.act_scale_name is not None and self.act_scale_name in weights: input_scale = weights[self.act_scale_name].to(torch.float) - self.input_scale = input_scale.cuda() + self.input_scale = input_scale.cuda(get_current_device_id()) if weight is None and weight_scale is None and input_scale is None: return @@ -198,7 +196,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: self.weight = weight[:, self.start : self.end] if self.bias_name in weights: bias = weights[self.bias_name] - self.bias = (bias / self.world_size_).to(self.data_type_).cuda(self.device_id_) + self.bias = (bias / 
self.world_size_).to(self.data_type_).cuda(get_current_device_id()) if self.quantized_weight and self.weight_scale_name in weights: block_size = 1 @@ -216,7 +214,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: if self.static_activation and self.act_scale_name in weights: input_scale = weights[self.act_scale_name].to(torch.float) - self.input_scale = input_scale.cuda() + self.input_scale = input_scale.cuda(get_current_device_id()) if weight is None and weight_scale is None and input_scale is None: return @@ -294,19 +292,19 @@ def _fuse(self) -> None: delattr(self, "weights") if self.weight_scale is None and (None not in self.weight_scales): - self.weight_scale = torch.cat(self.weight_scales, dim=0).cuda() + self.weight_scale = torch.cat(self.weight_scales, dim=0).cuda(get_current_device_id()) self._post_load_weights() delattr(self, "weight_scales") if self.static_activation and self.input_scale is None and (None not in self.input_scales): input_scales = torch.stack(self.input_scales, dim=0) - self.input_scale = torch.max(input_scales).cuda() + self.input_scale = torch.max(input_scales).cuda(get_current_device_id()) self._post_load_weights() delattr(self, "input_scales") if self.has_bias: if self.bias is None and (None not in self.biases): - self.bias = torch.cat(self.biases, dim=0).cuda(self.device_id_) + self.bias = torch.cat(self.biases, dim=0).cuda(get_current_device_id()) delattr(self, "biases") return self @@ -449,10 +447,10 @@ def _post_load_weights(self) -> None: and (not self.static_activation or self.input_scale is not None) ): if self.weight_scale.ndim > 1: - self.weight_scale = self.weight_scale.cuda(self.device_id_) - self.weight = [self.weight.cuda(self.device_id_), self.weight_scale, self.input_scale] + self.weight_scale = self.weight_scale.cuda(get_current_device_id()) + self.weight = [self.weight.cuda(get_current_device_id()), self.weight_scale, self.input_scale] return - self.weight = self.weight.cuda(self.device_id_) + self.weight = self.weight.cuda(get_current_device_id()) class BMMWeight(BMMWeightTpl): @@ -518,7 +516,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: self.weight = weight[self.start : self.end] if self.bias_name in weights: bias = weights[self.bias_name].to(self.data_type_)[self.start : self.end] - self.bias = bias.cuda(self.device_id_) + self.bias = bias.cuda(get_current_device_id()) if self.weight_scale_name is not None and self.weight_scale_name in weights: weight_scale = weights[self.weight_scale_name] @@ -532,7 +530,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: if self.act_scale_name is not None and self.act_scale_name in weights: input_scale = weights[self.act_scale_name].to(torch.float) - self.input_scale = input_scale.cuda() + self.input_scale = input_scale.cuda(get_current_device_id()) if weight is None and weight_scale is None and input_scale is None: return diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 8d593b603..7ec672ab8 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -1,5 +1,6 @@ import torch from .base_weight import BaseWeightTpl +from lightllm.utils.dist_utils import get_current_device_id class NormWeight(BaseWeightTpl): @@ -13,9 +14,9 @@ def __init__(self, weight_name, data_type, bias_name=None): def load_hf_weights(self, weights): if 
self.weight_name in weights: - self.weight = weights[self.weight_name].to(self.data_type_).cuda(self.device_id_) + self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id()) if self.bias_name in weights: - self.bias = weights[self.bias_name].to(self.data_type_).cuda(self.device_id_) + self.bias = weights[self.bias_name].to(self.data_type_).cuda(get_current_device_id()) def verify_load(self): load_ok = True @@ -33,7 +34,7 @@ def __init__(self, weight_name, data_type, bias_name=None): def load_hf_weights(self, weights): if self.weight_name in weights: - self.weight = (weights[self.weight_name] + 1).to(self.data_type_).cuda(self.device_id_) + self.weight = (weights[self.weight_name] + 1).to(self.data_type_).cuda(get_current_device_id()) class TpNormWeight(NormWeight): @@ -46,6 +47,6 @@ def load_hf_weights(self, weights): end = self.split_n_embed * (self.tp_rank_ + 1) if self.weight_name in weights: - self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(self.device_id_) + self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(get_current_device_id()) if self.bias_name in weights: - self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(self.device_id_) + self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(get_current_device_id()) diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index cbbb6ae6a..f12227c31 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -8,6 +8,9 @@ from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory from lightllm.common.kv_trans_kernel.kv_trans import kv_trans +from lightllm.utils.dist_utils import get_global_rank +from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args + logger = init_logger(__name__) @@ -32,14 +35,10 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False self.can_use_mem_size = self.size # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 - from torch.distributed.distributed_c10d import _default_pg_init_method - - nccl_port = re.search(r":(\d+)$", _default_pg_init_method).group(1) - assert nccl_port is not None - logger.info(f"mem manger get nccl port: {str(nccl_port)}") + from lightllm.utils.envs_utils import get_unique_server_name - rank_id = dist.get_rank() - self.shared_can_use_token_num = SharedInt(f"{str(nccl_port)}_mem_manger_can_use_token_num_{rank_id}") + rank_id = get_global_rank() + self.shared_can_use_token_num = SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_id}") self.shared_can_use_token_num.set_value(self.can_use_mem_size) self._init_buffers( @@ -56,12 +55,10 @@ def get_cell_size(self): def profile_size(self, mem_fraction): if self.size is not None: return - import torch.distributed as dist - tp_rank = dist.get_rank() world_size = dist.get_world_size() total_memory = get_total_gpu_memory() - available_memory = get_available_gpu_memory(tp_rank, world_size) - total_memory * (1 - mem_fraction) + available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction) cell_size = self.get_cell_size() self.size = int(available_memory * 1024 ** 3 / cell_size) logger.info( @@ -95,7 +92,6 @@ def send_to_decode_node(self, move_tasks: List[KVMoveTask], mem_managers: List[" assert dp_size == 1 # 先将数据发送到指定的一张卡上的buffer,再发送。 - import torch.distributed as dist 
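The `mem_manager.py` hunk above switches the shared-memory key from the NCCL port to `get_unique_server_name()` plus the global rank, so several nodes and several local instances cannot collide while the router still reads the free-token count without any RPC. `SharedInt` is LightLLM's own wrapper; a rough, hypothetical stand-in built on `multiprocessing.shared_memory` to illustrate the pattern:

```python
import struct
from multiprocessing import shared_memory

class SharedCounter:
    """Hypothetical SharedInt-like counter: one 8-byte signed int in named shared memory."""

    def __init__(self, name: str):
        try:
            # the creator gets a zero-filled block; other processes simply attach by name
            self._shm = shared_memory.SharedMemory(name=name, create=True, size=8)
        except FileExistsError:
            self._shm = shared_memory.SharedMemory(name=name)

    def set_value(self, value: int) -> None:
        struct.pack_into("q", self._shm.buf, 0, value)

    def get_value(self) -> int:
        return struct.unpack_from("q", self._shm.buf, 0)[0]

# Usage mirroring the diff (the shared-memory name below is illustrative):
counter = SharedCounter("my_server_mem_manger_can_use_token_num_0")
counter.set_value(16384)
print(counter.get_value())  # a router-side reader would attach with the same name
```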
move_token_indexes = [] for task in move_tasks: @@ -137,7 +133,6 @@ def receive_from_prefill_node( assert dp_size == 1 # 先将数据接受到指定的一张卡上的buffer,再复制到其他的卡上。 - import torch.distributed as dist move_token_indexes = [] for task in move_tasks: @@ -172,7 +167,6 @@ def send_to_decode_node_p2p(self, move_tasks: List[KVMoveTask], mem_managers: Li assert dp_size == 1 # 先将数据发送到指定的一张卡上的buffer,再发送。 - import torch.distributed as dist move_token_indexes = [] for task in move_tasks: @@ -201,7 +195,6 @@ def receive_from_prefill_node_p2p( assert dp_size == 1 # 先将数据接受到指定的一张卡上的buffer,再复制到其他的卡上。 - import torch.distributed as dist move_token_indexes = [] for task in move_tasks: @@ -307,10 +300,23 @@ class ReadOnlyStaticsMemoryManager: 读取一些统计信息 """ - def __init__(self, nccl_port, tp_size) -> None: - self.shared_tp_infos = [ - SharedInt(f"{str(nccl_port)}_mem_manger_can_use_token_num_{tp_index}") for tp_index in range(tp_size) - ] - - def get_unrefed_token_num(self, tp_index: int): - return self.shared_tp_infos[tp_index].get_value() + def __init__(self) -> None: + args = get_env_start_args() + self.global_world_size = args.tp + node_world_size = args.tp // args.nnodes + rank_start = args.node_rank * node_world_size + rank_end = (args.node_rank + 1) * node_world_size + self.shared_tp_infos = { + rank: SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank}") + for rank in range(rank_start, rank_end) + } + + def get_unrefed_token_num(self, dp_rank: int): + args = get_env_start_args() + if args.dp == 1 and args.nnodes > 1: + # 兼容多机 dp size=1 的情况 + rank_id = args.tp // args.nnodes * args.node_rank + return self.shared_tp_infos[rank_id].get_value() + dp_size = args.dp + dp_world_size = self.global_world_size // dp_size + return self.shared_tp_infos[dp_rank * dp_world_size].get_value() diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index 01d339674..b8bcf8a63 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -1,6 +1,7 @@ import torch from abc import ABC, abstractmethod -from lightllm.utils.device_utils import get_current_device_id +from lightllm.utils.dist_utils import get_current_device_id + class QuantizationMethod(ABC): def __init__(self): diff --git a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py index 56947055f..0f1575f8e 100644 --- a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py @@ -3,7 +3,7 @@ import numpy as np import torch.nn.functional as F from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.utils.device_utils import get_current_device_id +from lightllm.utils.dist_utils import get_current_device_id class ViTPreAndPostLayerWeight(PreAndPostLayerWeight): diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index 7f7a9a7bb..f773d90fd 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -11,7 +11,7 @@ MultiROWMMWeight, TpNormWeight, ) -from lightllm.utils.device_utils import get_current_device_id +from lightllm.utils.dist_utils import get_current_device_id class ViTTransformerLayerWeight(TransformerLayerWeight): diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 
511e6c69b..411cc3e80 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -24,7 +24,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--pd_master_ip", type=str, - default="127.0.0.1", + default="0.0.0.0", help="when run_mode set to prefill or decode, you need set this pd_mater_ip", ) parser.add_argument( @@ -93,6 +93,20 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time" ) + parser.add_argument("--nnodes", type=int, default=1, help="the number of nodes") + parser.add_argument("--node_rank", type=int, default=0, help="the rank of the current node") + parser.add_argument( + "--multinode_httpmanager_port", + type=int, + default=12345, + help="the port for multinode http manager, default is 20000", + ) + parser.add_argument( + "--multinode_router_gloo_port", + type=int, + default=20001, + help="the gloo port for multinode router, default is 20001", + ) parser.add_argument("--tp", type=int, default=1, help="model tp parral size, the default is 1") parser.add_argument( "--dp", @@ -105,6 +119,13 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--max_req_total_len", type=int, default=16384, help="the max value for req_input_len + req_output_len" ) + parser.add_argument( + "--nccl_host", + type=str, + default="127.0.0.1", + help="""The nccl_host to build a distributed environment for PyTorch. + When deploying in multi-node manner, the value should be set to the IP of the master node""", + ) parser.add_argument( "--nccl_port", type=int, default=28765, help="the nccl_port to build a distributed environment for PyTorch" ) @@ -223,6 +244,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway" ) parser.add_argument("--disable_cudagraph", action="store_true", help="Disable the cudagraph of the decoding stage") + parser.add_argument( "--graph_max_batch_size", type=int, diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 667ac8cb7..35ac87579 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -54,6 +54,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.server.metrics.manager import MetricClient +from lightllm.utils.envs_utils import get_unique_server_name from dataclasses import dataclass logger = init_logger(__name__) @@ -100,7 +101,7 @@ def set_args(self, args): enable_multimodal=args.enable_multimodal, metric_port=args.metric_port, ) - self.shared_token_load = TokenLoad(f"{str(args.nccl_port)}_shared_token_load", args.dp) + self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", args.dp) g_objs = G_Objs() diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 4a352bf6e..f7bb38c2d 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -10,10 +10,11 @@ from .embed_cache.manager import start_cache_manager from .visualserver.manager import start_visual_process from lightllm.utils.log_utils import init_logger -from lightllm.utils.envs_utils import set_env_start_args, set_unique_server_name +from lightllm.utils.envs_utils import set_env_start_args, set_unique_server_name, get_unique_server_name from .detokenization.manager import start_detokenization_process from .router.manager import start_router_process from lightllm.utils.process_check import 
is_process_active +from lightllm.utils.multinode_utils import send_and_receive_node_ip logger = init_logger(__name__) @@ -58,26 +59,16 @@ def signal_handler(sig, frame): return -def set_env(args): - import os - - if args.static_quant: - os.environ["STATIC_QUANT"] = "1" - set_unique_server_name(args) - set_env_start_args(args) - return - - def normal_or_p_d_start(args): + set_unique_server_name(args) if args.run_mode not in ["normal", "prefill", "decode"]: return assert args.zmq_mode in ["tcp://", "ipc:///tmp/"] - # 确保单机上多实列不冲突 if args.zmq_mode == "ipc:///tmp/": - zmq_mode = f"{args.zmq_mode}_{str(args.nccl_port)}_" + zmq_mode = f"{args.zmq_mode}_{get_unique_server_name()}_" args.zmq_mode = None # args 的参数不能直接设置,只能先设置None,再设置才能成功 args.zmq_mode = zmq_mode logger.info(f"zmq mode head: {args.zmq_mode}") @@ -208,7 +199,8 @@ def normal_or_p_d_start(args): if args.run_mode == "decode": args.router_max_wait_tokens = 0 - set_env(args) + send_and_receive_node_ip(args) # 多机用于收发node ip + set_env_start_args(args) logger.info(f"all start args:{args}") ports_locker.release_port() @@ -280,6 +272,7 @@ def normal_or_p_d_start(args): def pd_master_start(args): + set_unique_server_name(args) if args.run_mode != "pd_master": return @@ -291,7 +284,7 @@ def pd_master_start(args): args.metric_port = metric_port - set_env(args) + set_env_start_args(args) process_manager.start_submodule_processes( start_funcs=[ diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index ccf648a4c..ebb4aa454 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -96,6 +96,14 @@ class Req(ctypes.Structure): ("cumlogprob", ctypes.c_float), ] + def get_str(self): + return ( + f"request_id:{self.request_id}, input_len:{self.input_len}," + f"shm_cur_kv_len:{self.shm_cur_kv_len}," + f"shm_cur_output_len:{self.shm_cur_output_len}," + f"finish_status:{self.finish_status.is_finished()}" + ) + def init( self, request_id: int, diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index b38327e25..ff2224332 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -5,14 +5,16 @@ import uvloop import rpyc import time +import copy import hashlib import datetime import websockets import pickle import ujson as json +import multiprocessing asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -from typing import Union, List, Tuple, Dict +from typing import Union, List, Tuple, Dict, Optional from ..tokenizer import get_tokenizer from ..pd_io_struct import NodeRole, ObjType from ..embed_cache.utils import get_shm_name_data, create_shm @@ -50,6 +52,29 @@ def __init__( self.send_to_router = context.socket(zmq.PUSH) self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{router_port}") + self.multinode_req_manager = None + self.nnodes = args.nnodes + self.node_rank = args.node_rank + self.transfer_lock = asyncio.Lock() # the lock for transfer to next module in multi node mode. 
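The new CLI flags (`--nnodes`, `--node_rank`, `--nccl_host`, plus the multinode HTTP and gloo ports) let one tensor-parallel group span several machines, and the rest of the patch derives every rank-related quantity from `tp`, `nnodes`, `node_rank` and `dp`; for example, `ReadOnlyStaticsMemoryManager` only opens the shared counters of the ranks hosted on the local node. A small sketch of that arithmetic under the same assumptions (the function is illustrative, not a LightLLM API):

```python
def parallel_layout(tp: int, nnodes: int, node_rank: int, dp: int = 1) -> dict:
    """Illustrative bookkeeping for a multi-node launch such as --tp 16 --nnodes 2 --node_rank 1."""
    assert tp % nnodes == 0, "world size must divide evenly across nodes"
    assert tp % dp == 0, "each dp replica must own a whole number of GPUs"
    node_world_size = tp // nnodes
    dp_world_size = tp // dp  # GPUs per data-parallel replica
    global_ranks = list(range(node_rank * node_world_size, (node_rank + 1) * node_world_size))
    return {
        "node_world_size": node_world_size,
        "global_ranks": global_ranks,                                # ranks hosted on this node
        "device_ids": [r % node_world_size for r in global_ranks],   # local cuda index per rank
        "dp_ranks": [r // dp_world_size for r in global_ranks],      # which dp replica each rank serves
    }

# Node 1 of a two-node H100 deployment: global ranks 8..15 map to cuda:0..7.
print(parallel_layout(tp=16, nnodes=2, node_rank=1))
```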
+ self.disable_abort = args.nnodes > 1 and args.dp == 1 # mulitnode dp=1 mode, disable abort + if args.nnodes > 1: + if args.node_rank == 0: + self.multinode_req_manager = [] + for child_ip in args.child_ips: + context = zmq.asyncio.Context(2) + self.multinode_req_manager.append(context.socket(zmq.PUSH)) + self.multinode_req_manager[-1].connect(f"tcp://{child_ip}:{args.multinode_httpmanager_port}") + logger.info( + f"HttpServerManager connected to child node at {child_ip}:{args.multinode_httpmanager_port}" + ) + else: + context = zmq.asyncio.Context(2) + self.multinode_req_manager = context.socket(zmq.PULL) + self.multinode_req_manager.bind(f"tcp://*:{args.multinode_httpmanager_port}") + logger.info( + f"HttpServerManager listening for child node requests on *:{args.multinode_httpmanager_port}" + ) + self.enable_multimodal = enable_multimodal if self.enable_multimodal: self.cache_client = rpyc.connect("localhost", cache_port) @@ -126,31 +151,69 @@ def tokens(self, prompt, kwargs=None): prompt_ids = self.tokenizer.encode(prompt, None, **kwargs) return len(prompt_ids) - async def generate( - self, - prompt: Union[str, List[int]], - sampling_params: SamplingParams, - multimodal_params: MultimodalParams, - request: Request, - ) -> Tuple[int, str, dict, FinishStatus]: - start_time = time.time() + async def loop_for_request(self): + assert self.args.node_rank > 0 + tasks = [] + self.request_order_queue = [] + while True: + ( + prompt, + sampling_params, + multimodal_params, + ) = await self.multinode_req_manager.recv_pyobj() + self.request_order_queue.append(sampling_params.group_request_id) + results_generator = self.generate(prompt, sampling_params, multimodal_params, None) + + async def generate_wrapper(results_generator): + async for _, _, _, _ in results_generator: + pass + + tasks.append(asyncio.create_task(generate_wrapper(results_generator))) + # cleanup + while len(tasks) > 0 and tasks[0].done(): + tasks.pop(0) + + def alloc_req_id(self, sampling_params): # 请求的 id 可以由外部传入,也可以由内部生成,但是由外部传入的时候,要自己保证全局唯一性 # 否则会造成异常问题。目前限制 NORMAL 模式都使用内部id替换, P 和 D 模式按需设置 if self.pd_mode == NodeRole.NORMAL: - group_request_id = self.id_gen.generate_id() + if not (self.nnodes > 1 and self.args.dp == 1): + group_request_id = self.id_gen.generate_id() + else: + if self.node_rank == 0: + group_request_id = self.id_gen.generate_id() + else: + assert sampling_params.group_request_id != -1 + group_request_id = sampling_params.group_request_id sampling_params.group_request_id = group_request_id elif self.pd_mode == NodeRole.P or self.pd_mode == NodeRole.D: assert sampling_params.group_request_id is not None, "p d mode, group_request_id must be setting" group_request_id = sampling_params.group_request_id else: assert False, "dead code path" + return group_request_id + + async def generate( + self, + prompt: Union[str, List[int]], + sampling_params: SamplingParams, + multimodal_params: MultimodalParams, + request: Request, + ) -> Tuple[int, str, dict, FinishStatus]: + start_time = time.time() + request_headers = request.headers if request is not None else {} + group_request_id = self.alloc_req_id(sampling_params) try: + old_multimodal_params = None + if self.nnodes > 1 and self.node_rank == 0 and self.args.dp == 1: + old_multimodal_params = copy.deepcopy(multimodal_params) + if self.pd_mode.is_P_or_NORMAL(): multimodal_params.verify_and_preload() # 记录请求到达的相关信息 - await self._log_req_header(request, group_request_id) + await self._log_req_header(request_headers, group_request_id) # 监控 
self.metric_client.counter_inc("lightllm_request_count") @@ -191,11 +254,17 @@ async def generate( req_status = ReqStatus(group_request_id, multimodal_params, req_objs, start_time) self.req_id_to_out_inf[group_request_id] = req_status - # 将请求转发给其他节点 - await self.transfer_to_next_module(req_status.group_req_objs) + await self.transfer_to_next_module_or_node( + prompt, sampling_params, old_multimodal_params, req_status.group_req_objs + ) results_generator = self._wait_to_token_package( - start_time, prompt_ids, group_request_id, sampling_params, req_status, request + start_time, + prompt_ids, + group_request_id, + sampling_params, + req_status, + request, ) async for sub_req_id, request_output, metadata, finish_status in results_generator: # p d 模式下,将 token 数据放入到转发队列中 @@ -210,10 +279,10 @@ async def generate( raise e return - async def _log_req_header(self, request: Request, group_request_id: int): + async def _log_req_header(self, request_headers, group_request_id: int): - x_request_id = request.headers.get("X-Request-Id", "") if request is not None else "" - x_session_id = request.headers.get("X-Session-Id", "") if request is not None else "" + x_request_id = request_headers.get("X-Request-Id", "") + x_session_id = request_headers.get("X-Session-Id", "") format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S") logger.info( @@ -276,10 +345,44 @@ async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params: return prompt_ids + async def transfer_to_next_module_or_node( + self, + prompt: str, + sampling_params: SamplingParams, + multimodal_params: MultimodalParams, + group_req_objs: Optional[GroupReqObjs] = None, + ): + # 多节点纯tp 运行模式下,保证请求能保持相同的顺序转发到其他节点和当前节点next module. + if self.nnodes > 1 and self.node_rank == 0 and self.args.dp == 1: + async with self.transfer_lock: + for sender in self.multinode_req_manager: + sender.send_pyobj( + (prompt, sampling_params, multimodal_params), + protocol=pickle.HIGHEST_PROTOCOL, + ) + await self.transfer_to_next_module(group_req_objs) + return + + if self.nnodes > 1 and self.node_rank > 0 and self.args.dp == 1: + while True: + if self.request_order_queue and self.request_order_queue[0] != group_req_objs.group_req_id: + await asyncio.sleep(0.002) + continue + else: + async with self.transfer_lock: + await self.transfer_to_next_module(group_req_objs) + self.request_order_queue.pop(0) + break + return + + await self.transfer_to_next_module(group_req_objs) + return + async def transfer_to_next_module( self, - group_req_objs: GroupReqObjs, + group_req_objs: Optional[GroupReqObjs] = None, ): + if self.pd_mode == NodeRole.P: if self.enable_multimodal: self.send_to_visual.send_pyobj( @@ -340,7 +443,7 @@ async def _wait_to_token_package( except asyncio.TimeoutError: pass - if request is not None and await request.is_disconnected(): + if not self.disable_abort and request is not None and await request.is_disconnected(): await self.abort(group_request_id) raise Exception(f"req_id {group_request_id} disconnected") @@ -376,7 +479,6 @@ async def _wait_to_token_package( self.per_token_costs.add(mean_per_token_cost_time_ms) x_request_id = request.headers.get("X-Request-Id", "") if request is not None else "" x_session_id = request.headers.get("X-Session-Id", "") if request is not None else "" - prompt_cache_ratio = prompt_cache_len / prompt_tokens self.metric_client.histogram_observe("lightllm_cache_length", prompt_cache_len) self.metric_client.histogram_observe("lightllm_cache_ratio", prompt_cache_ratio) @@ -461,6 
+563,9 @@ async def handle_loop(self): self.forwarding_queue = AsyncQueue() asyncio.create_task(self.pd_handle_loop()) + if self.args.node_rank > 0: + asyncio.create_task(self.loop_for_request()) + while True: try: await asyncio.wait_for(self.recv_from_detokenization.recv_pyobj(), timeout=0.05) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index b620aaabd..5def2cfd0 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -14,6 +14,7 @@ import zmq import zmq.asyncio import torch.multiprocessing as mp +import torch.distributed as dist import multiprocessing from typing import Dict, List, Optional from .batch import Batch @@ -32,6 +33,7 @@ from lightllm.common.mem_manager import ReadOnlyStaticsMemoryManager from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread +from lightllm.utils.envs_utils import get_unique_server_name logger = init_logger(__name__) @@ -41,18 +43,20 @@ def __init__(self, args, router_port, detokenization_port, model_rpc_ports, metr self.args = args self.model_weightdir = args.model_dir self.world_size = args.tp + self.nnodes = args.nnodes + self.node_rank = args.node_rank self.dp_size = args.dp self.load_way = args.load_way self.mode = args.mode self.max_total_token_num = args.max_total_token_num self.shm_req_manager = ShmReqManager() # 用共享内存进行共享,router 模块读取进行精确的调度估计 - self.read_only_statics_mem_manager = ReadOnlyStaticsMemoryManager(args.nccl_port, args.tp) + self.read_only_statics_mem_manager = ReadOnlyStaticsMemoryManager() # 初始化 radix_cache_client 用于读取 prompt cache 的管理信息 self.radix_cache_client = None # 共享变量,用于存储router端调度分析得到的机器负载信息 - self.shared_token_load = TokenLoad(f"{str(args.nccl_port)}_shared_token_load", self.dp_size) + self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size) for dp_index in range(self.dp_size): self.shared_token_load.set_estimated_peak_token_count(0, dp_index) self.shared_token_load.set_frozened_token_count(0, dp_index) @@ -65,7 +69,6 @@ def __init__(self, args, router_port, detokenization_port, model_rpc_ports, metr self.eos_id = args.eos_id self.has_wait_tokens = 0 self.max_wait_tokens = args.router_max_wait_tokens - context = zmq.asyncio.Context(2) self.recv_from_httpserver = context.socket(zmq.PULL) self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{router_port}") @@ -74,6 +77,14 @@ def __init__(self, args, router_port, detokenization_port, model_rpc_ports, metr self.send_to_detokenization.connect(f"{args.zmq_mode}127.0.0.1:{detokenization_port}") self.model_rpc_ports = model_rpc_ports + if args.nnodes > 1 and args.dp == 1: + self.mulitnode_group = dist.init_process_group( + backend="gloo", + init_method=f"tcp://{args.nccl_host}:{args.multinode_router_gloo_port}", + world_size=args.nnodes, + rank=args.node_rank, + ) + self.is_token_healing = self.args.token_healing_mode self.chunked_prefill_size = args.chunked_prefill_size self.enable_chunked_prefill = args.enable_chunked_prefill @@ -103,13 +114,16 @@ async def wait_to_model_ready(self): self.rpc_event = multiprocessing.Event() self.rpc_finished_event = multiprocessing.Event() - for rank_id in range(self.world_size): + assert (self.world_size % self.nnodes) == 0 + node_world_size = self.world_size // self.nnodes + for rank_id in range(self.node_rank * node_world_size, (self.node_rank + 1) * node_world_size): rpc_model = await start_model_process( args=self.args, - tp_rank=rank_id, + rank=rank_id, + 
rank_in_node=rank_id % node_world_size, + node_world_size=node_world_size, rpc_event=self.rpc_event, rpc_finished_event=self.rpc_finished_event, - world_size=self.world_size, info_queue=self.info_queue, mem_queue=self.mem_queues[rank_id], router_lock=self.router_lock, @@ -134,6 +148,7 @@ async def wait_to_model_ready(self): "mode": self.mode, "max_req_num": self.args.running_max_req_size + 8, "max_seq_length": self.args.max_req_total_len + 8, # 留一点余量 + "nccl_host": self.args.nccl_host, "nccl_port": self.args.nccl_port, "is_first_token_constraint_mode": self.args.first_token_constraint_mode, "enable_chunked_prefill": self.enable_chunked_prefill, @@ -162,7 +177,7 @@ async def wait_to_model_ready(self): self.args.max_total_token_num = self.max_total_token_num if self.args.use_dynamic_prompt_cache: self.radix_cache_client = RadixCacheReadOnlyClient( - str(self.args.nccl_port), self.max_total_token_num, tp_size=self.world_size + get_unique_server_name(), self.max_total_token_num, tp_size=self.world_size ) self.req_queue = build_req_queue(self.args, self, self.dp_size) logger.info(f"use req queue {self.req_queue.__class__.__name__}") @@ -248,10 +263,20 @@ async def get_schedule_result(self, running_batch: Batch): if self.schedule_task is None: def get_new_batch(): + limit_router_queue_length = None + if self.nnodes > 1 and self.args.dp == 1: + # 使用 all_reduce 获取最小值 + limit_router_queue_length = len(self.req_queue.waiting_req_list) + limit_router_queue_length_tensor = torch.tensor( + limit_router_queue_length, dtype=torch.int32, device="cpu" + ) + dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group) + limit_router_queue_length = limit_router_queue_length_tensor.item() + self.overlap_event.wait(timeout=0.020) self.overlap_event.clear() - time.sleep(0.003) # 这里是为了保证能正确进入推理的流程,保证折叠成功。 - new_batch = self.req_queue.generate_new_batch(running_batch) + time.sleep(0.003) + new_batch = self.req_queue.generate_new_batch(running_batch, limit_router_queue_length) return new_batch self.schedule_task = asyncio.get_running_loop().run_in_executor(self.overlap_thread_pool, get_new_batch) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index f9503d474..cc0a97b5c 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -3,6 +3,7 @@ import numpy as np import rpyc import torch +import socket from datetime import timedelta from typing import Dict, List, Tuple from transformers.configuration_utils import PretrainedConfig @@ -37,9 +38,13 @@ from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams from lightllm.server.router.token_load import TokenLoad from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock -from lightllm.utils.device_utils import set_current_device_id +from lightllm.utils.dist_utils import _init_distributed_env +from lightllm.utils.envs_utils import get_unique_server_name from lightllm.server.core.objs import ShmReqManager from lightllm.server.router.model_infer.infer_batch import g_infer_context +from lightllm.utils.dist_utils import get_global_rank, get_global_world_size, get_dp_size +from lightllm.utils.dist_utils import get_dp_world_size, get_current_dp_rank, get_current_rank_in_dp +from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node, get_node_world_size import 
torch.distributed as dist @@ -49,12 +54,13 @@ def __init__(self) -> None: pass def init_model(self, kvargs): - self.args = kvargs.get("args", None) # p d 分离模式下会有特殊的一些初始化, 所以需要传递 # 模式参数到模型的初始化过程中进行控制 self.run_mode = "normal" if self.args is None else self.args.run_mode self.is_multimodal = False + self.nnodes = self.args.nnodes + self.node_rank = self.args.node_rank self.tp_rank = kvargs["rank_id"] self.world_size = kvargs["world_size"] self.dp_size = kvargs.get("dp_size", 1) @@ -71,8 +77,6 @@ def init_model(self, kvargs): self.logger = init_logger(__name__) self.weight_dir = kvargs["weight_dir"] - nccl_port_str = str(kvargs["nccl_port"]) - self.shared_token_load = TokenLoad(f"{nccl_port_str}_shared_token_load", self.dp_size) # p d 分离模式,decode节点才会使用的参数 self.pd_rpyc_ports = kvargs.get("pd_rpyc_ports", None) max_total_token_num = kvargs["max_total_token_num"] @@ -81,19 +85,10 @@ def init_model(self, kvargs): assert self.dp_size == self.world_size, "Currently only self-sustaining dp_size == tp_size" os.environ["ENABLE_DP"] = "1" - torch.cuda.set_device(self.tp_rank) - set_current_device_id(self.tp_rank) + _init_distributed_env(kvargs) + self.init_rank_infos() - dist.init_process_group( - "nccl", - init_method=f'tcp://127.0.0.1:{kvargs["nccl_port"]}', - rank=self.tp_rank, - world_size=self.world_size, - ) - # warmup nccl communicator - _a = torch.zeros([1]).to(f"cuda:{self.tp_rank}") - dist.all_reduce(_a) - del _a + self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size) from lightllm.distributed import custom_comm_ops @@ -102,7 +97,7 @@ def init_model(self, kvargs): # 为 p d 分离模式添加的全局锁管理,用于做一些同步操作。 一定需要在 # init_process_group 之后调用 - g_infer_state_lock.obj = InferStateLock(name=nccl_port_str) + g_infer_state_lock.obj = InferStateLock(name=get_unique_server_name()) g_infer_state_lock.dp_size = self.dp_size self.infer_state_lock = g_infer_state_lock # 防止InferStateLock 中的全局共享信息被重复异常初始化,导致同步异常的问题。 @@ -207,7 +202,7 @@ def init_model(self, kvargs): set_random_seed(2147483647) self.radix_cache = ( RadixCache( - str(kvargs["nccl_port"]), self.model.mem_manager.size, self.tp_rank, mem_manager=self.model.mem_manager + get_unique_server_name(), self.model.mem_manager.size, self.tp_rank, mem_manager=self.model.mem_manager ) if self.use_dynamic_prompt_cache else None @@ -279,3 +274,26 @@ def preload_prompt_cache_kv_buffer(self, model_cfg): self.radix_cache.match_prefix( torch.tensor(model_cfg["prompt_cache_token_ids"], dtype=torch.int64, device="cpu"), update_refs=True ) + + def init_rank_infos(self): + self.node_world_size = get_node_world_size() + self.rank_in_node = get_current_rank_in_node() + self.current_device_id = get_current_device_id() + self.rank_in_dp = get_current_rank_in_dp() + self.dp_rank = get_current_dp_rank() + self.dp_world_size = get_dp_world_size() + self.global_rank = get_global_rank() + self.global_world_size = get_global_world_size() + self.dp_size = get_dp_size() + + if self.nnodes > 1 and self.dp_size == 1: + if self.rank_in_node == 0: + self.is_master_in_dp = True + else: + self.is_master_in_dp = False + else: + if self.rank_in_dp == 0: + self.is_master_in_dp = True + else: + self.is_master_in_dp = False + return diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 3e91c5bf7..34d60cb22 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ 
b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -54,7 +54,7 @@ def post_handel(self, run_reqs: List[InferReq], next_token_ids, next_token_logpr req_obj.cur_kv_len = len(req_obj.get_chuncked_input_token_ids()) if req_obj.cur_kv_len < req_obj.get_cur_total_len(): - if self.tp_rank < self.dp_size: + if self.is_master_in_dp: req_obj.shm_req.shm_cur_kv_len = req_obj.cur_kv_len continue @@ -67,7 +67,7 @@ def post_handel(self, run_reqs: List[InferReq], next_token_ids, next_token_logpr if req_obj.finish_status.is_finished() or req_obj.shm_req.router_aborted: finished_req_ids.append(req_obj.shm_req.request_id) - if self.tp_rank < self.dp_size: + if self.is_master_in_dp: # shm_cur_kv_len shm_cur_output_len 是 router 调度进程需要读的信息 # finish_token_index finish_status candetoken_out_len 是 # detokenization 进程需要的信息,注意这些变量的写入顺序避免异步协同问题。 diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py index 3c09465b8..3884d5fa6 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py @@ -55,7 +55,7 @@ def post_handel(self, run_reqs: List[InferReq], next_token_ids, next_token_logpr if req_obj.finish_status.is_finished() or req_obj.shm_req.router_aborted: finished_req_ids.append(req_obj.shm_req.request_id) - if self.tp_rank < self.dp_size: + if self.is_master_in_dp: # shm_cur_kv_len shm_cur_output_len 是 router 调度进程需要读的信息 # finish_token_index finish_status candetoken_out_len 是 # detokenization 进程需要的信息,注意这些变量的写入顺序避免异步协同问题。 diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 2a63e6b21..1a48f182b 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -32,7 +32,9 @@ class ModelRpcServer: def __init__( self, args, - tp_rank: int, + rank: int, + rank_in_node: int, + node_world_size: int, rpc_event: multiprocessing.Event, rpc_finished_event: multiprocessing.Event, info_queue: mp.Queue, @@ -40,7 +42,7 @@ def __init__( ): super().__init__() self.args = args - self.world_size = self.args.tp + self.node_world_size = node_world_size self.info_queue = info_queue self.mem_queue = mem_queue self.rpc_event = rpc_event @@ -50,10 +52,12 @@ def __init__( self.rpc_shm_params.create_or_link_shm() self.rpc_shm_results = RpcShmResults() self.rpc_shm_results.create_or_link_shm() - self.rpc_shm_sync_status = ShmSyncStatusArray(self.world_size) + self.rpc_shm_sync_status = ShmSyncStatusArray(self.node_world_size) self.rpc_shm_sync_status.create_or_link_shm() - self.tp_rank = tp_rank + self.rank = rank + self.rank_in_node = rank_in_node + logger.info(f"Initialized RPC server for rank {self.rank}.") # 多卡才是跨进程的 if self.args.tp != 1: @@ -69,21 +73,21 @@ def rpc_loop(self): func_name, args = self.rpc_shm_params.read_func_params() ans = getattr(self, func_name)(*args) - if ans is not None and self.tp_rank == 0: + if ans is not None and self.rank_in_node == 0: self.rpc_shm_results.write_func_result(func_name=func_name, ret=ans) # 下面得执行顺序不可随意交换, 否则容易出现同步或者死锁问题。 - self.rpc_shm_sync_status.add_mark(self.tp_rank) + self.rpc_shm_sync_status.add_mark(self.rank_in_node) while not self.rpc_shm_sync_status.run_finished(): pass self.rpc_event.clear() - self.rpc_shm_sync_status.add_mark1(self.tp_rank) + self.rpc_shm_sync_status.add_mark1(self.rank_in_node) while not self.rpc_shm_sync_status.run_finished1(): 
pass - if self.tp_rank == 0: + if self.rank_in_node == 0: self.rpc_finished_event.set() except BaseException as e: @@ -98,7 +102,7 @@ def rpc_loop(self): def init_model(self, kvargs): # 填充真正的 rank_id 参数 - kvargs["rank_id"] = self.tp_rank + kvargs["rank_id"] = self.rank self.world_size = kvargs["world_size"] enable_chunked_prefill = kvargs.get("enable_chunked_prefill", False) return_all_prompt_logprobs = kvargs.get("return_all_prompt_logprobs", False) @@ -258,7 +262,9 @@ async def get_max_total_token_num(self): def _init_env( args, - tp_rank, + rank, + rank_in_node, + node_world_size, info_queue, mem_queue, router_lock, @@ -277,7 +283,9 @@ def _init_env( g_router_lock.obj = router_lock - model_rpc_server = ModelRpcServer(args, tp_rank, rpc_event, rpc_finished_event, info_queue, mem_queue) + model_rpc_server = ModelRpcServer( + args, rank, rank_in_node, node_world_size, rpc_event, rpc_finished_event, info_queue, mem_queue + ) success_event.set() model_rpc_server.loop_thread.join() @@ -286,10 +294,11 @@ def _init_env( async def start_model_process( args, - tp_rank, + rank, + rank_in_node, + node_world_size, rpc_event, rpc_finished_event, - world_size, info_queue: mp.Queue, mem_queue: mp.Queue, router_lock: mp.Queue, @@ -297,13 +306,33 @@ async def start_model_process( import lightllm.utils.rpyc_fix_utils as _ # 单卡时不使用 rpc - if world_size == 1: - return ModelRpcServer(args, tp_rank, rpc_event, rpc_finished_event, info_queue, mem_queue) + if node_world_size == 1 and args.nnodes == 1: + return ModelRpcServer( + args, + rank, + rank_in_node, + node_world_size, + rpc_event, + rpc_finished_event, + info_queue, + mem_queue, + ) success_event = mp.Event() proc = mp.Process( target=_init_env, - args=(args, tp_rank, info_queue, mem_queue, router_lock, rpc_event, rpc_finished_event, success_event), + args=( + args, + rank, + rank_in_node, + node_world_size, + info_queue, + mem_queue, + router_lock, + rpc_event, + rpc_finished_event, + success_event, + ), ) proc.start() success_event.wait(timeout=40) diff --git a/lightllm/server/router/req_queue/base_queue.py b/lightllm/server/router/req_queue/base_queue.py index ff919c2f7..f2d3d35a7 100644 --- a/lightllm/server/router/req_queue/base_queue.py +++ b/lightllm/server/router/req_queue/base_queue.py @@ -69,7 +69,15 @@ def get_batch_dp_req_size(self, current_batch: Batch): return len([req for req in current_batch.reqs if req.sample_params.suggested_dp_index == self.dp_index]) - def generate_new_batch(self, current_batch: Batch): + def generate_new_batch(self, current_batch: Batch, limit_router_queue_length: int = None): + """ + args: + current_batch: current batch + limit_router_queue_length: the least length of waiting list across all nodes. + It only works when nnodes > 1 and dp_size == 1. 
+ return: + new batch + """ raise NotImplementedError() def calcu_batch_token_load(self, current_batch: Batch): diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py index 313d6f68b..9a3d352d9 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py @@ -56,7 +56,7 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens return False, new_batch_first_router_need_tokens # @calculate_time(show=True, min_cost_ms=10) - def generate_new_batch(self, current_batch: Batch): + def generate_new_batch(self, current_batch: Batch, limit_router_queue_length: int = None): # 如果当前已经被调度的请求数量超过了上限,直接不调度新的请求了。 exist_req_num = self.get_batch_dp_req_size(current_batch) + len(self.pause_req_dict) @@ -74,7 +74,13 @@ def generate_new_batch(self, current_batch: Batch): can_run_list = [] abort_req_list = [] aborted_count = 0 - for req in self.waiting_req_list: + + if limit_router_queue_length is None: + waiting_queue = self.waiting_req_list + else: + waiting_queue = self.waiting_req_list[:limit_router_queue_length] + + for req in waiting_queue: if req.is_aborted and not req.is_paused: # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. # 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token的泄漏 diff --git a/lightllm/server/router/req_queue/continues_batch/beam_impl.py b/lightllm/server/router/req_queue/continues_batch/beam_impl.py index d695da9d8..bb0056695 100644 --- a/lightllm/server/router/req_queue/continues_batch/beam_impl.py +++ b/lightllm/server/router/req_queue/continues_batch/beam_impl.py @@ -76,7 +76,7 @@ def _can_add_new_group_reqs(self, cur_handle_group_reqs: List[Req], is_busy, new return False, new_batch_first_router_need_tokens # @calculate_time(show=True, min_cost_ms=10) - def generate_new_batch(self, current_batch: Batch): + def generate_new_batch(self, current_batch: Batch, limit_router_queue_length: int = None): # 如果当前已经被调度的请求数量超过了上限,直接不调度新的请求了。 exist_req_num = self.get_batch_dp_req_size(current_batch) + len(self.pause_req_dict) req_is_full = exist_req_num >= self.running_max_req_size diff --git a/lightllm/server/router/req_queue/continues_batch/impl.py b/lightllm/server/router/req_queue/continues_batch/impl.py index 96eb86efe..cab9759cf 100644 --- a/lightllm/server/router/req_queue/continues_batch/impl.py +++ b/lightllm/server/router/req_queue/continues_batch/impl.py @@ -61,7 +61,7 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens return False, new_batch_first_router_need_tokens # @calculate_time(show=True, min_cost_ms=10) - def generate_new_batch(self, current_batch: Batch): + def generate_new_batch(self, current_batch: Batch, limit_router_queue_length: int = None): # 如果当前已经被调度的请求数量超过了上限,直接不调度新的请求了。 exist_req_num = self.get_batch_dp_req_size(current_batch) + len(self.pause_req_dict) req_is_full = exist_req_num >= self.running_max_req_size @@ -76,7 +76,13 @@ def generate_new_batch(self, current_batch: Batch): abort_req_list = [] new_batch_first_router_need_tokens = 0 # 主要是对 prefill 大块计算时候的token数量限制 aborted_count = 0 - for req in self.waiting_req_list: + + if limit_router_queue_length is None: + waiting_queue = self.waiting_req_list + else: + waiting_queue = self.waiting_req_list[:limit_router_queue_length] + + for req in waiting_queue: if req.is_aborted and not req.is_paused: # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. 
# 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token和管理req对象的泄漏 diff --git a/lightllm/server/router/req_queue/continues_batch/pd_decode_impl.py b/lightllm/server/router/req_queue/continues_batch/pd_decode_impl.py index a4ff1059a..86bc2e3b3 100644 --- a/lightllm/server/router/req_queue/continues_batch/pd_decode_impl.py +++ b/lightllm/server/router/req_queue/continues_batch/pd_decode_impl.py @@ -24,7 +24,7 @@ def _init_cache_list(self, current_batch: Batch, is_busy): return # @calculate_time(show=True, min_cost_ms=10) - def generate_new_batch(self, current_batch: Batch): + def generate_new_batch(self, current_batch: Batch, limit_router_queue_length: int = None): # 如果当前已经被调度的请求数量超过了上限,直接不调度新的请求了。 exist_req_num = self.get_batch_dp_req_size(current_batch) + len(self.pause_req_dict) req_is_full = exist_req_num >= self.running_max_req_size diff --git a/lightllm/server/router/req_queue/dp_base_queue.py b/lightllm/server/router/req_queue/dp_base_queue.py index 84b599634..c987998ee 100644 --- a/lightllm/server/router/req_queue/dp_base_queue.py +++ b/lightllm/server/router/req_queue/dp_base_queue.py @@ -27,7 +27,7 @@ def get_wait_req_num(self): return sum(queue.get_wait_req_num() for queue in self.inner_queues) # @calculate_time(show=True, min_cost_ms=10) - def generate_new_batch(self, current_batch: Batch): + def generate_new_batch(self, current_batch: Batch, limit_router_queue_length: int = None): batches = [self.inner_queues[dp_index].generate_new_batch(current_batch) for dp_index in range(self.dp_size)] return self._merge_batch(batches) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 6679c2a73..5605eac30 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -100,15 +100,13 @@ async def loop_for_fwd(self): if len(self.waiting_reqs) == 0: await asyncio.sleep(0.01) # 10ms else: - cur_batch_size = 0 - reqs_need_infer = [] + processing_group_reqs = [] uuids_need_infer = [] - while cur_batch_size < self.infer_batch_size and len(self.waiting_reqs) > 0: + while len(self.waiting_reqs) > 0: group_req_indexes = self.waiting_reqs.pop(0) shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0]) is_aborted = shm_req.is_aborted self.shm_req_manager.put_back_req_obj(shm_req) - if is_aborted: # 因为连接断开 aborted 掉的请求也需要传输到后续的模块进行处理 # 因为采用 shm 来映射所有的 req 对象以后,引用管理情况复杂了 @@ -118,33 +116,28 @@ async def loop_for_fwd(self): multimodal_params = group_req_indexes.multimodal_params - cur_uuids_need_infer = [] for img in multimodal_params.images: if not self.cache_client.root.get_item_embed(img.uuid): - cur_batch_size += 1 - cur_uuids_need_infer.append(img.uuid) + uuids_need_infer.append(img.uuid) + + if len(uuids_need_infer) == self.infer_batch_size: + await self.infer_imgs(uuids_need_infer) + uuids_need_infer = [] + for _group_req_indexes in processing_group_reqs: + self.send_to_router.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + processing_group_reqs = [] - if len(cur_uuids_need_infer) == 0: + if len(uuids_need_infer) == 0: self.send_to_router.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) else: - uuids_need_infer.extend(cur_uuids_need_infer) - reqs_need_infer.append((group_req_indexes, len(uuids_need_infer) - 1)) - - for start_index in range(0, len(uuids_need_infer), self.infer_batch_size): - await self.infer_imgs(uuids_need_infer[start_index : (start_index + self.infer_batch_size)]) - finished_req_indexes = [ - group_req_indexes - for 
group_req_indexes, mark_index in reqs_need_infer - if mark_index < start_index + self.infer_batch_size - ] - reqs_need_infer = [ - (group_req_indexes, mark_index) - for group_req_indexes, mark_index in reqs_need_infer - if mark_index >= start_index + self.infer_batch_size - ] - - for group_req_indexes in finished_req_indexes: - self.send_to_router.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + processing_group_reqs.append(group_req_indexes) + + if len(uuids_need_infer) > 0: + await self.infer_imgs(uuids_need_infer) + for _group_req_indexes in processing_group_reqs: + self.send_to_router.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + processing_group_reqs = [] + uuids_need_infer = [] async def loop_for_netio_req(self): while True: diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 64ca02f8e..db5337f3d 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -16,7 +16,7 @@ from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end -from lightllm.utils.device_utils import set_current_device_id +from lightllm.utils.dist_utils import _init_vision_distributed_env from lightllm.utils.graceful_utils import graceful_registry @@ -31,20 +31,11 @@ def exposed_init_model(self, kvargs): self.tp_rank_id = kvargs["tp_rank_id"] self.cache_port = kvargs["cache_port"] weight_dir = kvargs["weight_dir"] - visual_gpu_ids = kvargs["visual_gpu_ids"] - visual_nccl_port = kvargs["visual_nccl_port"] self.vit_rank_id = kvargs["vit_rank_id"] self.cache_client = rpyc.connect("localhost", self.cache_port) self.data_type = kvargs["data_type"] - torch.cuda.set_device(visual_gpu_ids[self.vit_rank_id]) - set_current_device_id(visual_gpu_ids[self.vit_rank_id]) - dist.init_process_group( - backend="nccl", - init_method=f"tcp://127.0.0.1:{visual_nccl_port}", - rank=self.tp_rank_id, - world_size=self.vit_tp, - ) + _init_vision_distributed_env(kvargs) model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) try: diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index 382edc4aa..816e1e9a9 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -3,23 +3,6 @@ import subprocess -def set_current_device_id(device_id: int): - os.environ["CURRENT_DEVICE_ID"] = str(device_id) - - -@lru_cache(maxsize=None) -def get_current_device_id(): - import torch - - if torch.cuda.is_available(): - device_id = os.getenv("CURRENT_DEVICE_ID", None) - if device_id is None: - raise RuntimeError("set_current_device_id must called first to set current device") - return int(device_id) - else: - raise RuntimeError("Torch CUDA is not avaliable.") - - @lru_cache(maxsize=None) def get_device_sm_count(): import triton diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py index 5264095b8..199216814 100644 --- a/lightllm/utils/dist_utils.py +++ b/lightllm/utils/dist_utils.py @@ -1,15 +1,148 @@ import torch.distributed as dist +import os +import torch +# 规范 rank 的含义,在 llm 推理的相关代码中下述的 rank 的含义如下: +# global_rank 全局 rank 序列id, 如两节点 8卡,会存在 0 - 15 16个global_rank +# global_world_size 全局的 world size 大小, 如两节点 8 卡,该值为 16 +# dp_size 如果部署形态是一个推理实列包含几个数据并行的推理实列,则 dp size 为整个系统中的 dp 并行数量 +# dp_world_size 每一个dp 数据并行占用的卡数 +# 
dp_rank 指每个dp 数据并行在整个推理实例中dp的rank号, 如果 16卡部署,4 dp size, 则存在 0 - 3 共 4 个dp_rank
+# 值,其中 0-3号卡为 dp_rank 0, 4-7 为 dp_rank 1, 8-11 为 dp_rank 2, 12-15 为 dp_rank 3
+# rank_in_dp 指在一个dp内的rank序号。
+# node_world_size 指一个推理节点使用的卡数,如两机 tp 推理,如果两机器8卡,则 node_world_size 为 8.
+# rank_in_node 指在一个node内的rank序号,如两机8卡推理,每机上的rank序号都是0-7

-def get_world_size():
-    if dist.is_initialized():
-        return dist.get_world_size()
-    else:
-        raise RuntimeError("Distributed package is not initialized.")
+def set_environ(environ_name, value):
+    os.environ[environ_name] = str(value)

-def get_rank():
-    if dist.is_initialized():
-        return dist.get_rank()
-    else:
-        raise RuntimeError("Distributed package is not initialized.")
+
+
+def get_environ(environ_name):
+    value = os.getenv(environ_name, None)
+    if value is None:
+        raise RuntimeError(f"{environ_name} is not set")
+    return value
+
+
+def _init_vision_distributed_env(kvargs):
+    world_size = kvargs["vit_tp"]
+    set_global_rank(kvargs["tp_rank_id"])
+    set_global_world_size(world_size)
+    visual_gpu_ids = kvargs["visual_gpu_ids"]
+    device_id = visual_gpu_ids[kvargs["vit_rank_id"]]
+    set_current_device_id(device_id)
+    torch.cuda.set_device(device_id)
+    dist.init_process_group(
+        "nccl",
+        init_method=f'tcp://127.0.0.1:{kvargs["visual_nccl_port"]}',
+        rank=kvargs["tp_rank_id"],
+        world_size=world_size,
+    )
+    # warmup nccl communicator
+    _a = torch.zeros([1]).to(f"cuda:{device_id}")
+    dist.all_reduce(_a)
+    del _a
+
+
+def _init_distributed_env(kvargs):
+    assert kvargs["world_size"] % kvargs["args"].nnodes == 0, "world_size should be divisible by nnodes"
+    node_world_size = kvargs["world_size"] // kvargs["args"].nnodes
+
+    set_global_rank(kvargs["rank_id"])
+    set_global_world_size(kvargs["world_size"])
+    set_dp_size(kvargs.get("dp_size", 1))
+    set_dp_world_size(get_global_world_size() // get_dp_size())
+    set_current_dp_rank(get_global_rank() // get_dp_world_size())
+    set_current_rank_in_dp(get_global_rank() % get_dp_world_size())
+    set_current_rank_in_node(get_global_rank() % node_world_size)
+    set_node_world_size(node_world_size)
+
+    device_id = kvargs["rank_id"] % get_node_world_size()
+    set_current_device_id(device_id)
+    torch.cuda.set_device(device_id)
+    dist.init_process_group(
+        "nccl",
+        init_method=f'tcp://{kvargs["nccl_host"]}:{kvargs["nccl_port"]}',
+        rank=kvargs["rank_id"],
+        world_size=kvargs["world_size"],
+    )
+    # warmup nccl communicator
+    _a = torch.zeros([1]).to(f"cuda:{device_id}")
+    dist.all_reduce(_a)
+    del _a
+
+
+def set_global_rank(global_rank: int):
+    set_environ("LIGHTLLM_GLOBAL_RANK", global_rank)
+
+
+def get_global_rank():
+    return int(get_environ("LIGHTLLM_GLOBAL_RANK"))
+
+
+def set_global_world_size(world_size: int):
+    set_environ("LIGHTLLM_GLOBAL_WORLD_SIZE", world_size)
+
+
+def get_global_world_size():
+    return int(get_environ("LIGHTLLM_GLOBAL_WORLD_SIZE"))
+
+
+def set_dp_size(dp_size: int):
+    """
+    total dp num
+    """
+    set_environ("LIGHTLLM_DP_SIZE", dp_size)
+
+
+def get_dp_size():
+    return int(get_environ("LIGHTLLM_DP_SIZE"))
+
+
+def set_dp_world_size(world_size: int):
+    set_environ("LIGHTLLM_DP_WORLD_SIZE", world_size)
+
+
+def get_dp_world_size():
+    return int(get_environ("LIGHTLLM_DP_WORLD_SIZE"))
+
+
+def set_current_dp_rank(rank: int):
+    set_environ("LIGHTLLM_CURRENT_DP_RANK", rank)
+
+
+def get_current_dp_rank():
+    return int(get_environ("LIGHTLLM_CURRENT_DP_RANK"))
+
+
+def set_current_rank_in_dp(rank: int):
+    set_environ("LIGHTLLM_CURRENT_RANK_IN_DP", rank)
+
+
+def get_current_rank_in_dp():
+    return int(get_environ("LIGHTLLM_CURRENT_RANK_IN_DP"))
+
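+# Usage sketch (illustrative only; assumes _init_distributed_env or
+# _init_vision_distributed_env has already run in this process and populated the
+# LIGHTLLM_* environment variables above):
+#
+#     from lightllm.utils.dist_utils import get_global_rank, get_current_rank_in_dp
+#
+#     if get_current_rank_in_dp() == 0:
+#         print(f"global rank {get_global_rank()} is the first rank of its dp group")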
+ +def set_current_device_id(device_id: int): + set_environ("LIGHTLLM_CURRENT_DEVICE_ID", device_id) + + +def get_current_device_id(): + return int(get_environ("LIGHTLLM_CURRENT_DEVICE_ID")) + + +def set_current_rank_in_node(rank: int): + set_environ("LIGHTLLM_CURRENT_RANK_IN_NODE", rank) + + +def get_current_rank_in_node(): + return int(get_environ("LIGHTLLM_CURRENT_RANK_IN_NODE")) + + +def set_node_world_size(node_world_size: int): + set_environ("LIGHTLLM_NODE_WORLD_SIZE", node_world_size) + + +def get_node_world_size(): + return int(get_environ("LIGHTLLM_NODE_WORLD_SIZE")) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index c606fd238..d16b298b3 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -8,7 +8,7 @@ def set_unique_server_name(args): - os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.nccl_port) + os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.nccl_port) + "_" + str(args.node_rank) return diff --git a/lightllm/utils/multinode_utils.py b/lightllm/utils/multinode_utils.py new file mode 100644 index 000000000..b356eff0f --- /dev/null +++ b/lightllm/utils/multinode_utils.py @@ -0,0 +1,30 @@ +import zmq +import socket +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def send_and_receive_node_ip(args): + # 传输子node的ip + if args.nnodes > 1: + + if args.node_rank == 0: + args.child_ips = None + args.child_ips = [] + for i in range(1, args.nnodes): + context = zmq.Context(2) + comm_socket = context.socket(zmq.PULL) + comm_socket.bind(f"tcp://*:{args.multinode_httpmanager_port + i + 100}") + logger.info(f"binding port {args.multinode_httpmanager_port + i + 100}") + args.child_ips.append(comm_socket.recv_pyobj()) + comm_socket.close() + logger.info(f"Received child IPs: {args.child_ips}") + else: + local_ip = socket.gethostbyname(socket.gethostname()) + context = zmq.Context(2) + comm_socket = context.socket(zmq.PUSH) + comm_socket.connect(f"tcp://{args.nccl_host}:{args.multinode_httpmanager_port + args.node_rank + 100}") + logger.info(f"connecting to {args.nccl_host}:{args.multinode_httpmanager_port + args.node_rank + 100}") + comm_socket.send_pyobj(local_ip) + comm_socket.close() diff --git a/lightllm/utils/profile_max_tokens.py b/lightllm/utils/profile_max_tokens.py index 7f54d7ecf..e3a62b62e 100644 --- a/lightllm/utils/profile_max_tokens.py +++ b/lightllm/utils/profile_max_tokens.py @@ -5,18 +5,19 @@ from transformers import AutoModelForCausalLM import argparse from lightllm.common.build_utils import repair_config +from lightllm.utils.dist_utils import get_current_device_id data_type_dict = {"float32": 4, "float16": 2, "bfloat16": 2, "fp32": 4, "fp16": 2, "bf16": 2, "int8": 1, "int4": 0.5} -def get_available_gpu_memory(tp_rank, world_size): +def get_available_gpu_memory(world_size): """ Get available memory. 
""" torch.cuda.empty_cache() - free_gpu_memory, _ = torch.cuda.mem_get_info(tp_rank) + free_gpu_memory, _ = torch.cuda.mem_get_info(get_current_device_id()) if world_size > 1: - tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(f"cuda:{tp_rank}") + tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(f"cuda:{get_current_device_id()}") torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN) free_gpu_memory = tensor.item() return free_gpu_memory / (1024 ** 3) diff --git a/test/model/model_infer.py b/test/model/model_infer.py index 7f8084a7b..265ad38e3 100644 --- a/test/model/model_infer.py +++ b/test/model/model_infer.py @@ -39,7 +39,7 @@ def test_model_inference(world_size, model_class, batch_size, input_len, output_ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_len, ans_queue): import torch from lightllm.distributed import custom_comm_ops - from lightllm.utils.device_utils import set_current_device_id + from lightllm.utils.dist_utils import set_current_device_id import torch.distributed as dist