deepseekv3 bmm noquant and fix moe gemm bug. #745

Merged
merged 8 commits into from Feb 22, 2025
@@ -73,16 +73,20 @@ def _post_load_weights(self) -> None:
and (not self.static_activation or self.input_scale is not None)
):
if self.weight_scale.ndim > 1:
- self.weight_scale = self.weight_scale.transpose(0, 1).cuda(self.device_id_)
+ # Keep the k dim more contiguous; most split-k kernels should run faster this way.
+ self.weight_scale = self.weight_scale.cuda(self.device_id_).transpose(0, 1)
self.weight = [
- self.weight.transpose(0, 1).cuda(self.device_id_),
+ # Keep the k dim more contiguous; most split-k kernels should run faster this way.
+ self.weight.cuda(self.device_id_).transpose(0, 1),
self.weight_scale,
self.input_scale,
]
else:
self.weight = self.quant_method.quantize(self.weight.to(self.data_type_).cuda(self.device_id_))
return
- self.weight = self.weight.to(self.data_type_).transpose(0, 1).cuda(self.device_id_)
+ # Keep the k dim more contiguous; most split-k kernels should run faster this way.
+ self.weight = self.weight.to(self.data_type_).cuda(self.device_id_).transpose(0, 1)


class MMWeight(MMWeightTpl):
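A quick way to check what the reordering in _post_load_weights changes is to compare strides for the two call orders. This is a minimal sketch with made-up shapes; per the comment added in the diff, copying to the GPU first and transposing afterwards is meant to keep the k dim more contiguous for split-k kernels.

import torch

# Hypothetical (out_features, in_features) weight; real shapes come from the model config.
w_cpu = torch.randn(512, 1024, dtype=torch.float16)

# Order before this PR: transpose on the host, then copy to the GPU.
w_old = w_cpu.transpose(0, 1).cuda()
# Order after this PR: copy first, then take a transposed view of the GPU buffer.
w_new = w_cpu.cuda().transpose(0, 1)

# Inspect the resulting layouts on your own build; the kernels care about the
# stride along k, i.e. how far apart consecutive k indices sit in memory.
print(w_old.shape, w_old.stride())
print(w_new.shape, w_new.stride())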
10 changes: 5 additions & 5 deletions lightllm/common/fused_moe/grouped_fused_moe.py
@@ -331,7 +331,7 @@ def grouped_matmul_kernel(
for step_k in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
# hint to Triton compiler to do proper loop pipelining
# tl.multiple_of(a_ptrs, [16, 16])
- tl.multiple_of(b_ptrs, [16, 16])
+ # tl.multiple_of(b_ptrs, [16, 16])

if use_fp8_w8a8:
a = tl.load(a_ptrs, mask=(offs_am[None, :] < cur_m) & (offs_k[:, None] < k))
@@ -464,10 +464,10 @@ def grouped_matmul(
token_input_scale,
expert_to_weights_scale,
expert_to_weights_scale.stride(0)
- if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 2
+ if expert_to_weights_scale is not None and expert_to_weights_scale.ndim >= 1
else 0,
expert_to_weights_scale.stride(1)
- if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 2
+ if expert_to_weights_scale is not None and expert_to_weights_scale.ndim >= 2
else 0,
expert_to_weights_scale.stride(2)
if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 3
@@ -532,10 +532,10 @@ def grouped_matmul(
token_input_scale,
expert_to_weights_scale,
expert_to_weights_scale.stride(0)
- if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 2
+ if expert_to_weights_scale is not None and expert_to_weights_scale.ndim >= 1
else 0,
expert_to_weights_scale.stride(1)
- if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 2
+ if expert_to_weights_scale is not None and expert_to_weights_scale.ndim >= 2
else 0,
expert_to_weights_scale.stride(2)
if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 3
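The ndim guards above reduce to "forward a stride only if that dimension exists, otherwise pass 0". A standalone sketch of that rule, using a hypothetical helper name, shows why the old == 2 test was too strict for per-expert (1-D) scale tensors:

import torch

def stride_or_zero(t, dim):
    # Hypothetical helper mirroring the guards in grouped_matmul: forward a real
    # stride only when the tensor actually has that dimension, otherwise 0.
    return t.stride(dim) if t is not None and t.ndim > dim else 0

# Per-expert scalar scales (ndim == 1): the old `ndim == 2` check forwarded 0 for
# stride(0), so the kernel could end up reading the same scale for every expert;
# the `ndim >= 1` check forwards the real stride instead.
per_expert = torch.rand(8)
print(stride_or_zero(per_expert, 0), stride_or_zero(per_expert, 1))   # 1 0

# Per-channel (ndim == 2) and block-wise (ndim == 3) scales behave as before.
block_wise = torch.rand(8, 16, 32)
print([stride_or_zero(block_wise, d) for d in range(3)])              # [512, 32, 1]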
1 change: 0 additions & 1 deletion lightllm/common/quantization/vllm_quant.py
@@ -198,7 +198,6 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
dtype=input_tensor.dtype,
)
else:
- qweight = qweight.t().contiguous().t()
input_scale = input_scale.t().contiguous().t()
torch.ops._C.cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias)
return out
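The line removed here is the usual PyTorch idiom for forcing a column-major copy while keeping the logical shape. A standalone illustration with made-up shapes of what it does, and of why it looks redundant once the weight is already stored as a transposed view as in the _post_load_weights change above:

import torch

q = torch.randint(-128, 127, (1024, 512), dtype=torch.int8)
# .t().contiguous().t() keeps the shape but materializes a column-major copy.
col_major = q.t().contiguous().t()
print(q.stride(), col_major.stride())   # (512, 1) vs (1, 1024)

# A weight built as `w.cuda(...).transpose(0, 1)` already has that layout, so the
# extra copy appears unnecessary; verify the strides of your own qweight before
# relying on this.
w = torch.randn(512, 1024).transpose(0, 1)
print(w.stride())                       # (1, 1024)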
@@ -18,6 +18,7 @@
ROWBMMWeightNoTp,
)
from functools import partial
from ..triton_kernel.weight_dequant import weight_dequant


class Deepseek2TransformerLayerWeight(TransformerLayerWeight):
@@ -116,8 +117,15 @@ def _load_vb_scale(self, kv_b_proj_scale_, block_size):
def load_hf_weights(self, weights):
if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights:
kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"]
# for deepseek_v3, the bmm operator is not quantized
if self.quant_cfg.quantized_weight:
kv_b_proj_ = weight_dequant(
kv_b_proj_.cuda(),
weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + self.weight_scale_suffix].cuda(),
).cpu()
weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = self._load_kb(kv_b_proj_)
weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = self._load_vb(kv_b_proj_)

if (
self.quant_cfg.quantized_weight
and f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + self.weight_scale_suffix in weights
@@ -184,15 +192,11 @@ def _init_qkvo(self):
f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight",
self.data_type_,
split_n_embed=self.tp_q_head_num_,
- weight_scale_suffix=self.weight_scale_suffix,
- act_scale_suffix=self.act_scale_suffix,
)
self.v_b_proj_ = ROWBMMWeight(
f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight",
self.data_type_,
split_n_embed=self.tp_q_head_num_,
- weight_scale_suffix=self.weight_scale_suffix,
- act_scale_suffix=self.act_scale_suffix,
)
if self.enable_cc_method:
self.cc_kv_b_proj_ = ROWMMWeight(
59 changes: 59 additions & 0 deletions lightllm/models/deepseek2/triton_kernel/weight_dequant.py
@@ -0,0 +1,59 @@
# adapted from
# https://github.com/deepseek-ai/DeepSeek-V3/blob/f09f5fa321f5a421704136c0463b1eaca6557712/inference/kernel.py
import torch
import triton
import triton.language as tl
from triton import Config


def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
"""
Dequantizes the given weight tensor using the provided scale tensor.

Args:
x (torch.Tensor): The quantized weight tensor of shape (M, N).
s (torch.Tensor): The per-block scale tensor of shape (ceil(M / block_size), ceil(N / block_size)).
block_size (int, optional): The block size to use for dequantization. Defaults to 128.

Returns:
torch.Tensor: The dequantized weight tensor of the same shape as `x`.

Raises:
AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
"""
assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous"
assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions"
M, N = x.size()
y = torch.empty_like(x, dtype=torch.get_default_dtype())
grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"]))
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
return y.to(torch.bfloat16)


@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
"""
Dequantizes weights using the provided scaling factors and stores the result.

Args:
x_ptr (tl.pointer): Pointer to the quantized weights.
s_ptr (tl.pointer): Pointer to the scaling factors.
y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
M (int): Number of rows in the weight matrix.
N (int): Number of columns in the weight matrix.
BLOCK_SIZE (tl.constexpr): Size of the block for tiling.

Returns:
None
"""
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
n = tl.cdiv(N, BLOCK_SIZE)
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs = offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
s = tl.load(s_ptr + pid_m * n + pid_n)
y = x * s
tl.store(y_ptr + offs, y, mask=mask)
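
A minimal usage sketch for the kernel above; shapes and values are made up, and it assumes a CUDA build of PyTorch and Triton with float8_e4m3fn support, since the call sites in this PR pass block-quantized checkpoint weights (presumably FP8 for DeepSeek-V3):

import torch
from lightllm.models.deepseek2.triton_kernel.weight_dequant import weight_dequant

M, N, block = 512, 1024, 128
w_q = torch.randn(M, N, device="cuda").to(torch.float8_e4m3fn)
# One scale per (block x block) tile, i.e. shape (cdiv(M, block), cdiv(N, block)).
s = torch.rand(M // block, N // block, device="cuda", dtype=torch.float32)

w = weight_dequant(w_q, s, block_size=block)
print(w.shape, w.dtype)   # torch.Size([512, 1024]) torch.bfloat16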