diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py
index 5e2978b62..ee7c0cf33 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py
@@ -73,16 +73,20 @@ def _post_load_weights(self) -> None:
                 and (not self.static_activation or self.input_scale is not None)
             ):
                 if self.weight_scale.ndim > 1:
-                    self.weight_scale = self.weight_scale.transpose(0, 1).cuda(self.device_id_)
+                    # Keep the k dim contiguous; kernels using split-k algorithms are likely to be faster.
+                    self.weight_scale = self.weight_scale.cuda(self.device_id_).transpose(0, 1)
                 self.weight = [
-                    self.weight.transpose(0, 1).cuda(self.device_id_),
+                    # Keep the k dim contiguous; kernels using split-k algorithms are likely to be faster.
+                    self.weight.cuda(self.device_id_).transpose(0, 1),
                     self.weight_scale,
                     self.input_scale,
                 ]
             else:
                 self.weight = self.quant_method.quantize(self.weight.to(self.data_type_).cuda(self.device_id_))
             return
-        self.weight = self.weight.to(self.data_type_).transpose(0, 1).cuda(self.device_id_)
+
+        # Keep the k dim contiguous; kernels using split-k algorithms are likely to be faster.
+        self.weight = self.weight.to(self.data_type_).cuda(self.device_id_).transpose(0, 1)
 
 
 class MMWeight(MMWeightTpl):
diff --git a/lightllm/common/fused_moe/grouped_fused_moe.py b/lightllm/common/fused_moe/grouped_fused_moe.py
index 22a0e99d2..8e258eec7 100644
--- a/lightllm/common/fused_moe/grouped_fused_moe.py
+++ b/lightllm/common/fused_moe/grouped_fused_moe.py
@@ -331,7 +331,7 @@ def grouped_matmul_kernel(
     for step_k in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
         # hint to Triton compiler to do proper loop pipelining
         # tl.multiple_of(a_ptrs, [16, 16])
-        tl.multiple_of(b_ptrs, [16, 16])
+        # tl.multiple_of(b_ptrs, [16, 16])
         if use_fp8_w8a8:
             a = tl.load(a_ptrs, mask=(offs_am[None, :] < cur_m) & (offs_k[:, None] < k))
@@ -464,10 +464,10 @@ def grouped_matmul(
         token_input_scale,
         expert_to_weights_scale,
         expert_to_weights_scale.stride(0)
-        if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 2
+        if expert_to_weights_scale is not None and expert_to_weights_scale.ndim >= 1
        else 0,
         expert_to_weights_scale.stride(1)
-        if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 2
+        if expert_to_weights_scale is not None and expert_to_weights_scale.ndim >= 2
         else 0,
         expert_to_weights_scale.stride(2)
         if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 3
@@ -532,10 +532,10 @@ def grouped_matmul(
         token_input_scale,
         expert_to_weights_scale,
         expert_to_weights_scale.stride(0)
-        if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 2
+        if expert_to_weights_scale is not None and expert_to_weights_scale.ndim >= 1
        else 0,
         expert_to_weights_scale.stride(1)
-        if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 2
+        if expert_to_weights_scale is not None and expert_to_weights_scale.ndim >= 2
         else 0,
         expert_to_weights_scale.stride(2)
         if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 3
diff --git a/lightllm/common/quantization/vllm_quant.py b/lightllm/common/quantization/vllm_quant.py
index dcfed8b69..e2b7309fa 100644
--- a/lightllm/common/quantization/vllm_quant.py
+++ b/lightllm/common/quantization/vllm_quant.py
@@ -198,7 +198,6 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
                 dtype=input_tensor.dtype,
             )
         else:
-            qweight = qweight.t().contiguous().t()
             input_scale = input_scale.t().contiguous().t()
         torch.ops._C.cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias)
         return out
diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
index f2b431cc6..dc558edeb 100644
--- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
@@ -18,6 +18,7 @@
     ROWBMMWeightNoTp,
 )
 from functools import partial
+from ..triton_kernel.weight_dequant import weight_dequant
 
 
 class Deepseek2TransformerLayerWeight(TransformerLayerWeight):
@@ -116,8 +117,15 @@ def _load_vb_scale(self, kv_b_proj_scale_, block_size):
     def load_hf_weights(self, weights):
         if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights:
             kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"]
+            # for deepseek_v3, the bmm operator is not quantized
+            if self.quant_cfg.quantized_weight:
+                kv_b_proj_ = weight_dequant(
+                    kv_b_proj_.cuda(),
+                    weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + self.weight_scale_suffix].cuda(),
+                ).cpu()
             weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = self._load_kb(kv_b_proj_)
             weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = self._load_vb(kv_b_proj_)
+
         if (
             self.quant_cfg.quantized_weight
             and f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + self.weight_scale_suffix in weights
@@ -184,15 +192,11 @@ def _init_qkvo(self):
             f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight",
             self.data_type_,
             split_n_embed=self.tp_q_head_num_,
-            weight_scale_suffix=self.weight_scale_suffix,
-            act_scale_suffix=self.act_scale_suffix,
         )
         self.v_b_proj_ = ROWBMMWeight(
             f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight",
             self.data_type_,
             split_n_embed=self.tp_q_head_num_,
-            weight_scale_suffix=self.weight_scale_suffix,
-            act_scale_suffix=self.act_scale_suffix,
         )
         if self.enable_cc_method:
             self.cc_kv_b_proj_ = ROWMMWeight(
diff --git a/lightllm/models/deepseek2/triton_kernel/weight_dequant.py b/lightllm/models/deepseek2/triton_kernel/weight_dequant.py
new file mode 100644
index 000000000..880aa96d2
--- /dev/null
+++ b/lightllm/models/deepseek2/triton_kernel/weight_dequant.py
@@ -0,0 +1,59 @@
+# adapted from
+# https://github.com/deepseek-ai/DeepSeek-V3/blob/f09f5fa321f5a421704136c0463b1eaca6557712/inference/kernel.py
+import torch
+import triton
+import triton.language as tl
+from triton import Config
+
+
+def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
+    """
+    Dequantizes the given weight tensor using the provided scale tensor.
+
+    Args:
+        x (torch.Tensor): The quantized weight tensor of shape (M, N).
+        s (torch.Tensor): The scale tensor of shape (M, N).
+        block_size (int, optional): The block size to use for dequantization. Defaults to 128.
+
+    Returns:
+        torch.Tensor: The dequantized weight tensor of the same shape as `x`.
+
+    Raises:
+        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
+ """ + assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous" + assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions" + M, N = x.size() + y = torch.empty_like(x, dtype=torch.get_default_dtype()) + grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"])) + weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) + return y.to(torch.bfloat16) + + +@triton.jit +def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): + """ + Dequantizes weights using the provided scaling factors and stores the result. + + Args: + x_ptr (tl.pointer): Pointer to the quantized weights. + s_ptr (tl.pointer): Pointer to the scaling factors. + y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): Size of the block for tiling. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.load(s_ptr + pid_m * n + pid_n) + y = x * s + tl.store(y_ptr + offs, y, mask=mask)