
Commit d80cabe

refactor MoE loading logic.
Signed-off-by: Mindy Li <[email protected]>
1 parent 9cd1f27 commit d80cabe

4 files changed, +21 -18 lines changed


tensorrt_llm/_torch/models/modeling_deepseekv3.py
Lines changed: 3 additions & 0 deletions

@@ -1344,6 +1344,9 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
         params_map = {'gate_up_proj': ['gate_proj', 'up_proj']}
         all_named_modules = dict(self.named_modules())
 
+        # moe_backend: cute_dsl_group_gemm
+        # use_cute_dsl_gemm, use_cute_dsl_bmm; use_cute_dsl
+        # attention/mla, gated_mlp, linear
         if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
         ) and get_sm_version() == 100:
             for name in list(weights.keys()):
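
The comment added here flags the cute_dsl MoE backend and related flags next to the guard below it: the weight-key rewriting only runs for FP8 block-scale checkpoints on SM 100 GPUs. A minimal sketch of that guard in isolation (the helper name and arguments are hypothetical, not the real API):

def rewrites_weight_keys(has_fp8_block_scales: bool, sm_version: int) -> bool:
    # Mirrors the condition in the diff: FP8 block scales and Blackwell (SM 100).
    return has_fp8_block_scales and sm_version == 100

assert rewrites_weight_keys(True, 100)
assert not rewrites_weight_keys(True, 90)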

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
Lines changed: 1 addition & 14 deletions

@@ -9,8 +9,7 @@
 from ...model_config import ModelConfig
 from ...utils import Fp4QuantizedTensor
 from .fused_moe_cutlass import CutlassFusedMoE
-from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodCuteDsl,
-                           MoEWeightLoadingMode, UnquantizedFusedMoEMethod)
+from .quantization import MoEWeightLoadingMode
 from .routing import BaseMoeRoutingMethod
 
 
@@ -140,18 +139,6 @@ def __init__(
             layer_idx=layer_idx,
         )
 
-    def _get_quant_method(self):
-        if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
-                exclude_kv_cache=True):
-            if self.quant_config.layer_quant_mode.has_fp8_block_scales():
-                return DeepSeekFP8BlockScalesFusedMoEMethodCuteDsl()
-            else:
-                raise ValueError(
-                    f"Unsupported quantization mode: {self.quant_config.quant_mode}"
-                )
-        else:
-            return UnquantizedFusedMoEMethod()
-
     def forward_chunk(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
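
With the override deleted, the CuteDsl module presumably resolves _get_quant_method through its parent CutlassFusedMoE, which is why the CuteDsl-specific quant-method import can be dropped. A minimal sketch of that fall-through, using stand-in class names (the string return value is a placeholder for the real method object):

class Base:
    def _get_quant_method(self):
        # Assumed shared default; stands in for CutlassFusedMoE's dispatch.
        return "DeepSeekFP8BlockScalesFusedMoEMethod"

class CuteDslLike(Base):
    # No _get_quant_method override: resolution falls through to Base.
    pass

assert CuteDslLike()._get_quant_method() == "DeepSeekFP8BlockScalesFusedMoEMethod"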

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
Lines changed: 14 additions & 1 deletion

@@ -13,7 +13,8 @@
 from ...model_config import ModelConfig
 from ...utils import Fp4QuantizedTensor
 from .fused_moe_cutlass import CutlassFusedMoE
-from .quantization import MoEWeightLoadingMode
+from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
+                           MoEWeightLoadingMode, UnquantizedFusedMoEMethod)
 from .routing import BaseMoeRoutingMethod
 
 
@@ -340,6 +341,18 @@ def __init__(
             layer_idx=layer_idx,
         )
 
+    def _get_quant_method(self):
+        if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
+                exclude_kv_cache=True):
+            if self.quant_config.layer_quant_mode.has_fp8_block_scales():
+                return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
+            else:
+                raise ValueError(
+                    f"Unsupported quantization mode: {self.quant_config.quant_mode}"
+                )
+        else:
+            return UnquantizedFusedMoEMethod()
+
     @nvtx_range("[DG] forward")
     def forward_chunk(
         self,
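
The DeepGemm module now carries its own _get_quant_method, with the same shape as the one removed from the CuteDsl module but returning the DeepGemm variant. A self-contained sketch of the dispatch, with FakeQuantMode as a hypothetical stand-in for quant_config.layer_quant_mode:

from typing import Optional

class FakeQuantMode:
    """Hypothetical stand-in for quant_config.layer_quant_mode."""

    def __init__(self, fp8_block_scales: bool):
        self.fp8_block_scales = fp8_block_scales

    def has_any_quant(self, exclude_kv_cache: bool = True) -> bool:
        return self.fp8_block_scales

    def has_fp8_block_scales(self) -> bool:
        return self.fp8_block_scales

def get_quant_method(quant_mode: Optional[FakeQuantMode]) -> str:
    # Mirrors the branching added in the diff above; strings stand in
    # for the real method objects.
    if quant_mode is not None and quant_mode.has_any_quant(exclude_kv_cache=True):
        if quant_mode.has_fp8_block_scales():
            return "DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm"
        raise ValueError("Unsupported quantization mode")
    return "UnquantizedFusedMoEMethod"

assert get_quant_method(FakeQuantMode(True)).endswith("DeepGemm")
assert get_quant_method(None) == "UnquantizedFusedMoEMethod"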

tensorrt_llm/_torch/modules/fused_moe/quantization.py
Lines changed: 3 additions & 3 deletions

@@ -430,7 +430,7 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
         module.fc31_input_dequant.data.copy_(max_fc31_input_scale)
 
 
-class DeepSeekFP8BlockScalesFusedMoEMethodCuteDsl(FusedMoEMethodBase):
+class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
 
     def create_weights(self, module: torch.nn.Module):
         weight_dtype = torch.float8_e4m3fn

@@ -553,8 +553,8 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
         })
 
 
-class DeepSeekFP8BlockScalesFusedMoEMethod(
-        DeepSeekFP8BlockScalesFusedMoEMethodCuteDsl):
+class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
+        DeepSeekFP8BlockScalesFusedMoEMethod):
 
     def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                      weight_loading_mode: MoEWeightLoadingMode):
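
Net effect of the rename: the generic name now sits at the base of the FP8 block-scale hierarchy and the DeepGemm variant subclasses it, reversing the previous arrangement where the CuteDsl-suffixed class was the base. A sketch of the resulting relationships (method bodies elided):

class FusedMoEMethodBase:
    pass

# Shared FP8 block-scale implementation (create_weights, load_quant_scales).
class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
    pass

# DeepGemm-specific variant: overrides load_weights only.
class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
        DeepSeekFP8BlockScalesFusedMoEMethod):
    pass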
