
add support for multinode tp #751


Merged
52 commits merged on Mar 1, 2025
Commits (52)
65428a1
supporting multinode
jayfeather9 Feb 16, 2025
a26d48d
fix format
jayfeather9 Feb 16, 2025
64ccb4f
add cuda() with device id
jayfeather9 Feb 16, 2025
b3de424
Merge branch 'main' into multinode
Feb 18, 2025
3fd6e48
fix multinode abort
shihaobai Feb 20, 2025
c0b1146
support chunked prefill
shihaobai Feb 20, 2025
855e037
Merge branch 'main' into multinode
Feb 20, 2025
8ad585a
modify dist_utils & remove child_ips
Feb 21, 2025
e0844d3
Merge branch 'multinode' of https://github.com/ModelTC/lightllm into …
Feb 21, 2025
fa3f826
fix chunked_prefill for multinode
shihaobai Feb 23, 2025
7e92ff6
Merge branch 'multinode' of https://github.com/ModelTC/lightllm into …
shihaobai Feb 23, 2025
4a83a6a
merge main
shihaobai Feb 23, 2025
461ec65
fix health
shihaobai Feb 23, 2025
b15a487
update port
shihaobai Feb 24, 2025
4e419df
Adjust rank configuration.
hiworldwzj Feb 26, 2025
8237f19
refactor multinode
shihaobai Feb 27, 2025
b55edd7
Merge branch 'main' into multinode
shihaobai Feb 27, 2025
958b83d
fix get_dp_size
shihaobai Feb 27, 2025
fffb99e
remove tp_rank of get_available_gpu_memory
shihaobai Feb 27, 2025
ecd495c
fix chunked prefill
shihaobai Feb 27, 2025
82a756f
fix dist_utils
shihaobai Feb 27, 2025
814f095
multinode utils
shihaobai Feb 27, 2025
b33b3b4
update router multinode manager
shihaobai Feb 27, 2025
296a579
fix chunked prefill
shihaobai Feb 27, 2025
2afd14b
reformat
shihaobai Feb 27, 2025
7061bfb
fix
hiworldwzj Feb 27, 2025
0bde847
refactor order
shihaobai Feb 28, 2025
39b90bf
fix
shihaobai Feb 28, 2025
429f9c3
update httpserver sync
shihaobai Feb 28, 2025
4377c20
update
shihaobai Feb 28, 2025
4abb4a1
remove cudagraph_step_length
shihaobai Feb 28, 2025
7646d6e
modify the default value of current_waiting_num
shihaobai Feb 28, 2025
8446544
update
shihaobai Feb 28, 2025
877b98f
fix
shihaobai Feb 28, 2025
8ff7ed6
fix visualserver
shihaobai Feb 28, 2025
249dea7
fix
shihaobai Feb 28, 2025
64bdc11
update mem_manager
shihaobai Feb 28, 2025
d68e0d7
fix start rank params.
hiworldwzj Feb 28, 2025
2461069
fix
shihaobai Feb 28, 2025
b3dbecd
fix
hiworldwzj Mar 1, 2025
ea4dc98
fix
hiworldwzj Mar 1, 2025
1cffab4
fix
hiworldwzj Mar 1, 2025
4f68164
update docs
shihaobai Mar 1, 2025
2b3f07a
fix
hiworldwzj Mar 1, 2025
068663a
fix
hiworldwzj Mar 1, 2025
d14fcea
fix
hiworldwzj Mar 1, 2025
46dbbb2
fix
shihaobai Mar 1, 2025
72f4eb3
fix
hiworldwzj Mar 1, 2025
a640f72
fix
hiworldwzj Mar 1, 2025
5cd5dbf
fix
hiworldwzj Mar 1, 2025
5d13dc8
fix
hiworldwzj Mar 1, 2025
ce4d3eb
reformat
shihaobai Mar 1, 2025
23 changes: 23 additions & 0 deletions docs/CN/source/getting_started/quickstart.rst
@@ -56,6 +56,22 @@
.. note::
The ``--model_dir`` parameter in the command above should be changed to the actual model path on your machine.

To deploy the DeepSeek-R1 model on a single H200 node, use the following launch command:

.. code-block:: console

$ LOADWORKER=8 python -m lightllm.server.api_server --model_dir ~/models/DeepSeek-R1 --tp 8 --graph_max_batch_size 100

.. note::
LOADWORKER specifies the number of threads for model loading, which can speed up model loading. --graph_max_batch_size specifies the number of CUDA graphs to be captured, covering batch sizes from 1 to 100.

To deploy the DeepSeek-R1 model across two H100 nodes, use the following launch commands:

.. code-block:: console

$ # Node 0
$ LOADWORKER=8 python -m lightllm.server.api_server --model_dir ~/models/DeepSeek-R1 --tp 16 --graph_max_batch_size 100 --nccl_host master_addr --nnodes 2 --node_rank 0
$ # Node 1
$ LOADWORKER=8 python -m lightllm.server.api_server --model_dir ~/models/DeepSeek-R1 --tp 16 --graph_max_batch_size 100 --nccl_host master_addr --nnodes 2 --node_rank 1

3. (Optional) Test the model service
-------------------------
@@ -75,3 +91,10 @@
$ }'


For the DeepSeek-R1 model, you can run a test with the following script:

.. code-block:: console

$ cd test
$ python benchmark_client.py --num_clients 100 --input_num 2000 --tokenizer_path /nvme/DeepSeek-R1/ --url http://127.0.0.1:8000/generate_stream

10 changes: 9 additions & 1 deletion docs/EN/source/getting_started/quickstart.rst
@@ -53,7 +53,7 @@ After downloading the Llama-2-7b-chat model, use the following command in the te
.. note::
The ``--model_dir`` parameter in the above command should be changed to the actual path of your model on your machine.

For the DeepSeek-R1 model on H200, it can be launched with the following command:
For the DeepSeek-R1 model on a single H200, it can be launched with the following command:

.. code-block:: console

@@ -62,6 +62,14 @@ For the DeepSeek-R1 model on H200, it can be launched with the following command
.. note::
LOADWORKER specifies the number of threads for model loading, which can speed up model loading. The --graph_max_batch_size parameter specifies the number of CUDA graphs to be captured, covering batch sizes from 1 to 100.

For the DeepSeek-R1 model on two H100 nodes, it can be launched with the following commands:

.. code-block:: console

$ # Node 0
$ LOADWORKER=8 python -m lightllm.server.api_server --model_dir ~/models/DeepSeek-R1 --tp 16 --graph_max_batch_size 100 --nccl_host master_addr --nnodes 2 --node_rank 0
$ # Node 1
$ LOADWORKER=8 python -m lightllm.server.api_server --model_dir ~/models/DeepSeek-R1 --tp 16 --graph_max_batch_size 100 --nccl_host master_addr --nnodes 2 --node_rank 1


3. (Optional) Test the Model Service
--------------------------------------
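For a quick scripted check of a running deployment (instead of the curl/benchmark commands above), a minimal streaming client could look like the sketch below. The endpoint path comes from the docs in this PR; the payload fields (`inputs`, `parameters`) and their values are assumptions modeled on lightllm's TGI-style API and may need adjusting to your deployment.

```python
# Minimal streaming test client (sketch, not part of this PR).
# Assumption: the server accepts a TGI-style JSON body with "inputs"
# and "parameters"; adjust the fields to match your deployment.
import requests

URL = "http://127.0.0.1:8000/generate_stream"  # point this at node 0

payload = {
    "inputs": "What is tensor parallelism?",
    "parameters": {"max_new_tokens": 64, "do_sample": False},
}

with requests.post(URL, json=payload, stream=True, timeout=300) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if line:
            # Print each streamed chunk as-is rather than assuming its schema.
            print(line.decode("utf-8"))
```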
3 changes: 2 additions & 1 deletion lightllm/common/basemodel/layer_weights/base_layer_weight.py
@@ -2,6 +2,7 @@
import numpy as np
import threading
from lightllm.common.basemodel.layer_weights.meta_weights import BaseWeight
from lightllm.utils.dist_utils import get_current_device_id


class BaseLayerWeight:
@@ -37,4 +38,4 @@ def _cuda(self, cpu_tensor):
if self.tp_rank_ is None:
return cpu_tensor.contiguous().to(self.data_type_).cuda()
else:
return cpu_tensor.contiguous().to(self.data_type_).cuda(self.tp_rank_)
return cpu_tensor.contiguous().to(self.data_type_).cuda(get_current_device_id())
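The hunk above shows the pattern repeated throughout this PR: `.cuda(self.tp_rank_)` becomes `.cuda(get_current_device_id())`, because with `--nnodes > 1` the global TP rank no longer equals the local GPU index. The sketch below illustrates that mapping; it is not lightllm's actual implementation (the real helper lives in `lightllm/utils/dist_utils.py`).

```python
# Sketch of mapping a global rank to a local CUDA device id in multinode TP.
# Hypothetical helper for illustration only; lightllm's real logic may differ.
import os
from typing import Optional

import torch


def current_device_id(global_rank: int, gpus_per_node: Optional[int] = None) -> int:
    """Return the GPU index to use on this node for a given global rank."""
    if gpus_per_node is None:
        gpus_per_node = torch.cuda.device_count() or 1
    # With 2 nodes x 8 GPUs, global ranks 8..15 must map to devices 0..7 on
    # the second node, so a plain `.cuda(rank)` would be out of range there.
    return global_rank % gpus_per_node


if __name__ == "__main__":
    rank = int(os.environ.get("RANK", "0"))
    print(f"global rank {rank} -> cuda:{current_device_id(rank)}")
```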
4 changes: 2 additions & 2 deletions lightllm/common/basemodel/layer_weights/hf_load_utils.py
@@ -3,14 +3,14 @@
import gc
from safetensors import safe_open
import lightllm.utils.petrel_helper as utils
from lightllm.utils.dist_utils import get_current_device_id


def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_layer_list=None, weight_dir=None):
# fix: during multi-threaded loading, the CUDA device inside each thread falls back to 0; set it explicitly so this bug cannot occur
import torch.distributed as dist

tp_rank = dist.get_rank()
torch.cuda.set_device(tp_rank)
torch.cuda.set_device(get_current_device_id())

if use_safetensors:
weights = safe_open(os.path.join(weight_dir, file_), "pt", "cpu")
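The comment in `load_func` above describes the underlying pitfall: threads spawned for weight loading do not inherit the caller's current CUDA device and default to device 0, so each worker thread must re-pin the device itself. A standalone illustration of that pattern follows (assuming a multi-GPU machine; the file names and device id are made up):

```python
# Sketch: each loader thread must call torch.cuda.set_device itself,
# otherwise tensors it creates land on cuda:0 regardless of the main thread.
from concurrent.futures import ThreadPoolExecutor

import torch


def load_one_file(path: str, device_id: int) -> str:
    # Re-pin the device inside the worker thread; the main thread's
    # torch.cuda.set_device() does not carry over to new threads.
    torch.cuda.set_device(device_id)
    tensor = torch.empty(1, device="cuda")  # stand-in for real weight loading
    return f"{path}: loaded on {tensor.device}"


if __name__ == "__main__":
    device_id = 1  # hypothetical local GPU (requires at least 2 GPUs)
    files = ["model-00001.safetensors", "model-00002.safetensors"]
    with ThreadPoolExecutor(max_workers=len(files)) as pool:
        for msg in pool.map(load_one_file, files, [device_id] * len(files)):
            print(msg)
```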
lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py
@@ -1,7 +1,6 @@
import torch
from abc import ABC, abstractmethod
from lightllm.utils.dist_utils import get_world_size, get_rank
from lightllm.utils.device_utils import get_current_device_id
from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_device_id


class BaseWeight(ABC):
@@ -19,8 +18,8 @@ def verify_load(self):

class BaseWeightTpl(BaseWeight):
def __init__(self):
self.world_size_ = get_world_size()
self.tp_rank_ = get_rank()
self.world_size_ = get_global_world_size()
self.tp_rank_ = get_global_rank()
self.device_id_ = get_current_device_id()

def load_hf_weights(self, weights):
lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py
@@ -5,9 +5,8 @@
from .base_weight import BaseWeight
from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod
from lightllm.common.quantization.quantize_method import QuantizationMethod
from lightllm.utils.dist_utils import get_world_size, get_rank
from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_device_id
from lightllm.common.vllm_kernel import _custom_ops as ops
from lightllm.utils.device_utils import get_current_device_id


class FusedMoeWeight(BaseWeight):
@@ -39,7 +38,7 @@ def __init__(
self.n_routed_experts = n_routed_experts
self.split_inter_size = split_inter_size
self.data_type_ = data_type
self.tp_rank_ = get_rank()
self.tp_rank_ = get_global_rank()
self.experts_up_projs = [None] * self.n_routed_experts
self.experts_gate_projs = [None] * self.n_routed_experts
self.experts_up_proj_scales = [None] * self.n_routed_experts
@@ -159,7 +158,7 @@ def _fuse_weight_scale(self):
delattr(self, "experts_gate_proj_scales")

def _load_hf_weights_etp(self, weights):
world_size_ = get_world_size()
world_size_ = get_global_world_size()
assert self.n_routed_experts % world_size_ == 0
n_expert_ep = self.n_routed_experts // world_size_

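`_load_hf_weights_etp` above divides the routed experts evenly across the global world size, which in the multinode case spans all nodes. A tiny standalone sketch of the slice arithmetic (expert and rank counts are examples, not lightllm defaults):

```python
# Sketch: evenly partitioning routed experts across the global world size,
# as in expert-parallel weight loading. Counts below are illustrative only.
from typing import List


def expert_ids_for_rank(n_routed_experts: int, world_size: int, rank: int) -> List[int]:
    assert n_routed_experts % world_size == 0, "experts must divide evenly across ranks"
    n_expert_ep = n_routed_experts // world_size
    start = rank * n_expert_ep
    return list(range(start, start + n_expert_ep))


if __name__ == "__main__":
    # e.g. 256 experts over 2 nodes x 8 GPUs = 16 global ranks -> 16 experts each
    for rank in (0, 1, 15):
        print(rank, expert_ids_for_rank(256, 16, rank))
```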
36 changes: 17 additions & 19 deletions lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py
@@ -4,6 +4,7 @@
from typing import Optional, Tuple, List, Dict, Any
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
from lightllm.common.quantization.quantize_method import QuantizationMethod
from lightllm.utils.dist_utils import get_current_device_id


def generate_scale_name(name, weight_scale_suffix, act_scale_suffix):
@@ -73,20 +74,17 @@ def _post_load_weights(self) -> None:
and (not self.static_activation or self.input_scale is not None)
):
if self.weight_scale.ndim > 1:
# make the k dim more contiguous; most split-k kernels are likely to run faster
self.weight_scale = self.weight_scale.cuda(self.device_id_).transpose(0, 1)
self.weight_scale = self.weight_scale.transpose(0, 1).cuda(get_current_device_id())
self.weight = [
# make the k dim more contiguous; most split-k kernels are likely to run faster
self.weight.cuda(self.device_id_).transpose(0, 1),
self.weight.cuda(get_current_device_id()).transpose(0, 1),
self.weight_scale,
self.input_scale,
]
else:
self.weight = self.quant_method.quantize(self.weight.to(self.data_type_).cuda(self.device_id_))
self.weight = self.quant_method.quantize(self.weight.to(self.data_type_).cuda(get_current_device_id()))
return

# make the k dim more contiguous; most split-k kernels are likely to run faster
self.weight = self.weight.to(self.data_type_).cuda(self.device_id_).transpose(0, 1)
self.weight = self.weight.to(self.data_type_).cuda(get_current_device_id()).transpose(0, 1)


class MMWeight(MMWeightTpl):
@@ -133,7 +131,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
self.weight = weight[self.start : self.end]
if self.bias_name in weights:
bias = weights[self.bias_name].to(self.data_type_)[self.start : self.end]
self.bias = bias.cuda(self.device_id_)
self.bias = bias.cuda(get_current_device_id())

if self.weight_scale_name is not None and self.weight_scale_name in weights:
block_size = 1
@@ -154,7 +152,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:

if self.act_scale_name is not None and self.act_scale_name in weights:
input_scale = weights[self.act_scale_name].to(torch.float)
self.input_scale = input_scale.cuda()
self.input_scale = input_scale.cuda(get_current_device_id())

if weight is None and weight_scale is None and input_scale is None:
return
@@ -198,7 +196,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
self.weight = weight[:, self.start : self.end]
if self.bias_name in weights:
bias = weights[self.bias_name]
self.bias = (bias / self.world_size_).to(self.data_type_).cuda(self.device_id_)
self.bias = (bias / self.world_size_).to(self.data_type_).cuda(get_current_device_id())

if self.quantized_weight and self.weight_scale_name in weights:
block_size = 1
@@ -216,7 +214,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:

if self.static_activation and self.act_scale_name in weights:
input_scale = weights[self.act_scale_name].to(torch.float)
self.input_scale = input_scale.cuda()
self.input_scale = input_scale.cuda(get_current_device_id())

if weight is None and weight_scale is None and input_scale is None:
return
@@ -294,19 +292,19 @@ def _fuse(self) -> None:
delattr(self, "weights")

if self.weight_scale is None and (None not in self.weight_scales):
self.weight_scale = torch.cat(self.weight_scales, dim=0).cuda()
self.weight_scale = torch.cat(self.weight_scales, dim=0).cuda(get_current_device_id())
self._post_load_weights()
delattr(self, "weight_scales")

if self.static_activation and self.input_scale is None and (None not in self.input_scales):
input_scales = torch.stack(self.input_scales, dim=0)
self.input_scale = torch.max(input_scales).cuda()
self.input_scale = torch.max(input_scales).cuda(get_current_device_id())
self._post_load_weights()
delattr(self, "input_scales")

if self.has_bias:
if self.bias is None and (None not in self.biases):
self.bias = torch.cat(self.biases, dim=0).cuda(self.device_id_)
self.bias = torch.cat(self.biases, dim=0).cuda(get_current_device_id())
delattr(self, "biases")
return self

@@ -449,10 +447,10 @@ def _post_load_weights(self) -> None:
and (not self.static_activation or self.input_scale is not None)
):
if self.weight_scale.ndim > 1:
self.weight_scale = self.weight_scale.cuda(self.device_id_)
self.weight = [self.weight.cuda(self.device_id_), self.weight_scale, self.input_scale]
self.weight_scale = self.weight_scale.cuda(get_current_device_id())
self.weight = [self.weight.cuda(get_current_device_id()), self.weight_scale, self.input_scale]
return
self.weight = self.weight.cuda(self.device_id_)
self.weight = self.weight.cuda(get_current_device_id())


class BMMWeight(BMMWeightTpl):
@@ -518,7 +516,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
self.weight = weight[self.start : self.end]
if self.bias_name in weights:
bias = weights[self.bias_name].to(self.data_type_)[self.start : self.end]
self.bias = bias.cuda(self.device_id_)
self.bias = bias.cuda(get_current_device_id())

if self.weight_scale_name is not None and self.weight_scale_name in weights:
weight_scale = weights[self.weight_scale_name]
@@ -532,7 +530,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:

if self.act_scale_name is not None and self.act_scale_name in weights:
input_scale = weights[self.act_scale_name].to(torch.float)
self.input_scale = input_scale.cuda()
self.input_scale = input_scale.cuda(get_current_device_id())

if weight is None and weight_scale is None and input_scale is None:
return
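Several hunks above keep the `transpose(0, 1)` that follows the comment about making the k dim more contiguous: a row-major `[n, k]` weight, once transposed, exposes k with stride 1, which split-k GEMM kernels generally prefer. The stride effect can be checked directly (illustration, not the PR's code):

```python
# Sketch: after transpose(0, 1), the k dimension of a row-major [n, k] weight
# becomes the stride-1 (contiguous, fast-moving) dimension.
import torch

n, k = 4, 8
w = torch.arange(n * k, dtype=torch.float32).reshape(n, k)  # contiguous [n, k]

print(w.shape, w.stride())    # torch.Size([4, 8]) (8, 1): k is dim 1 with stride 1
wt = w.transpose(0, 1)        # view with shape [k, n], no data copy
print(wt.shape, wt.stride())  # torch.Size([8, 4]) (1, 8): k is now dim 0 with stride 1

# A kernel that splits its work along k therefore walks memory contiguously.
```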
lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py
@@ -1,5 +1,6 @@
import torch
from .base_weight import BaseWeightTpl
from lightllm.utils.dist_utils import get_current_device_id


class NormWeight(BaseWeightTpl):
@@ -13,9 +14,9 @@ def __init__(self, weight_name, data_type, bias_name=None):

def load_hf_weights(self, weights):
if self.weight_name in weights:
self.weight = weights[self.weight_name].to(self.data_type_).cuda(self.device_id_)
self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id())
if self.bias_name in weights:
self.bias = weights[self.bias_name].to(self.data_type_).cuda(self.device_id_)
self.bias = weights[self.bias_name].to(self.data_type_).cuda(get_current_device_id())

def verify_load(self):
load_ok = True
@@ -33,7 +34,7 @@ def __init__(self, weight_name, data_type, bias_name=None):

def load_hf_weights(self, weights):
if self.weight_name in weights:
self.weight = (weights[self.weight_name] + 1).to(self.data_type_).cuda(self.device_id_)
self.weight = (weights[self.weight_name] + 1).to(self.data_type_).cuda(get_current_device_id())


class TpNormWeight(NormWeight):
Expand All @@ -46,6 +47,6 @@ def load_hf_weights(self, weights):
end = self.split_n_embed * (self.tp_rank_ + 1)

if self.weight_name in weights:
self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(self.device_id_)
self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(get_current_device_id())
if self.bias_name in weights:
self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(self.device_id_)
self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(get_current_device_id())