Merged
15 changes: 14 additions & 1 deletion tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
@@ -13,7 +13,8 @@
from ...model_config import ModelConfig
from ...utils import Fp4QuantizedTensor
from .fused_moe_cutlass import CutlassFusedMoE
from .quantization import MoEWeightLoadingMode
from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
MoEWeightLoadingMode, UnquantizedFusedMoEMethod)
from .routing import BaseMoeRoutingMethod


@@ -340,6 +341,18 @@ def __init__(
layer_idx=layer_idx,
)

def _get_quant_method(self):
if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
exclude_kv_cache=True):
if self.quant_config.layer_quant_mode.has_fp8_block_scales():
return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
else:
raise ValueError(
f"Unsupported quantization mode: {self.quant_config.quant_mode}"
)
else:
return UnquantizedFusedMoEMethod()

@nvtx_range("[DG] forward")
def forward_chunk(
self,
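The new `_get_quant_method` override follows the usual dispatch pattern: pick a quantization-method object from the layer's quant mode and fall back to the unquantized path when nothing is quantized. A minimal, self-contained sketch of that pattern, where `QuantConfig`, `FP8BlockScalesMethod`, and `UnquantizedMethod` are placeholder names rather than the TensorRT-LLM classes:

```python
# Illustrative sketch of the quant-method dispatch added in _get_quant_method.
# QuantConfig, FP8BlockScalesMethod and UnquantizedMethod are placeholders,
# not the TensorRT-LLM classes.
from dataclasses import dataclass
from typing import Optional


@dataclass
class QuantConfig:
    has_quant: bool = False
    fp8_block_scales: bool = False


class FP8BlockScalesMethod:
    """Stand-in for DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm."""


class UnquantizedMethod:
    """Stand-in for UnquantizedFusedMoEMethod."""


def get_quant_method(quant_config: Optional[QuantConfig]):
    # No config, or no quantization requested: fall back to the unquantized path.
    if quant_config is None or not quant_config.has_quant:
        return UnquantizedMethod()
    # The DeepGEMM backend only handles FP8 block scales; anything else errors out.
    if quant_config.fp8_block_scales:
        return FP8BlockScalesMethod()
    raise ValueError(f"Unsupported quantization mode: {quant_config}")


assert isinstance(get_quant_method(None), UnquantizedMethod)
assert isinstance(
    get_quant_method(QuantConfig(has_quant=True, fp8_block_scales=True)),
    FP8BlockScalesMethod)
```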
81 changes: 44 additions & 37 deletions tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -629,45 +629,8 @@ def create_weights(self, module: torch.nn.Module):

def load_weights(self, module: torch.nn.Module, weights: List[Dict],
weight_loading_mode: MoEWeightLoadingMode):

if get_sm_version() == 100:
expert_ids = set(module.initial_local_expert_ids)
if self.need_load_shared_weights(module):
expert_ids.update(
module.layer_load_balancer.get_load_expert_ids())
for name in list(weights.keys()):
if name.endswith("weight_scale_inv"):
if int(name.split(".")[0]) not in expert_ids:
continue
weight_name = name.replace("weight_scale_inv", "weight")
logger.debug(f"Resmoothing {weight_name}")
weight = weights[weight_name][:]
scale = weights[name][:]
weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
weight, scale)
super().load_weights(module, weights, weight_loading_mode)

if get_sm_version() == 100:
transfromed_w3_w1_scale = transform_sf_into_required_layout(
module.quant_scales[0],
mn=module.w3_w1_weight.shape[1],
k=module.w3_w1_weight.shape[2],
recipe=(1, 128, 128),
num_groups=module.w3_w1_weight.shape[0],
is_sfa=False)
module.w3_w1_weight_scaling_factor = nn.Parameter(
transfromed_w3_w1_scale, requires_grad=False)
transfromed_w2_scale = transform_sf_into_required_layout(
module.quant_scales[1],
mn=module.w2_weight.shape[1],
k=module.w2_weight.shape[2],
recipe=(1, 128, 128),
num_groups=module.w3_w1_weight.shape[0],
is_sfa=False)
module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale,
requires_grad=False)
self.setup_quant_scales(module)

def setup_quant_scales(self, module: torch.nn.Module):
module.quant_scales = FusedMoEQuantScalesDeepSeekFP8BlockScales(
fc_weight_scales=module.w3_w1_weight_scaling_factor,
@@ -765,6 +728,50 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
})


class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
DeepSeekFP8BlockScalesFusedMoEMethod):

def load_weights(self, module: torch.nn.Module, weights: List[Dict],
weight_loading_mode: MoEWeightLoadingMode):
if get_sm_version() == 100:
expert_ids = set(module.initial_local_expert_ids)
if self.need_load_shared_weights(module):
expert_ids.update(
module.layer_load_balancer.get_load_expert_ids())
for name in list(weights.keys()):
if name.endswith("weight_scale_inv"):
if int(name.split(".")[0]) not in expert_ids:
continue
weight_name = name.replace("weight_scale_inv", "weight")
logger.debug(f"Resmoothing {weight_name}")
weight = weights[weight_name][:]
scale = weights[name][:]
weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
weight, scale)
super().load_weights(module, weights, weight_loading_mode)

if get_sm_version() == 100:
transfromed_w3_w1_scale = transform_sf_into_required_layout(
module.quant_scales[0],
mn=module.w3_w1_weight.shape[1],
k=module.w3_w1_weight.shape[2],
recipe=(1, 128, 128),
num_groups=module.w3_w1_weight.shape[0],
is_sfa=False)
module.w3_w1_weight_scaling_factor = nn.Parameter(
transfromed_w3_w1_scale, requires_grad=False)
transfromed_w2_scale = transform_sf_into_required_layout(
module.quant_scales[1],
mn=module.w2_weight.shape[1],
k=module.w2_weight.shape[2],
recipe=(1, 128, 128),
num_groups=module.w3_w1_weight.shape[0],
is_sfa=False)
module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale,
requires_grad=False)
self.setup_quant_scales(module)


class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):

def create_weights(self, module: torch.nn.Module):
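The refactor moves the SM100 (Blackwell) resmoothing and scale-layout transform out of the base DeepSeek FP8 block-scales method and into a DeepGEMM-specific subclass. The shape is plain template-method subclassing: preprocess the raw weights, delegate to the parent loader, then post-process the loaded scales. A hedged sketch of that shape, where every class and helper is a placeholder rather than the real implementation:

```python
# Minimal sketch of the load_weights override pattern used by the new
# DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm class. All names below are
# placeholders, not the TensorRT-LLM types or helpers.
from typing import Dict


class BaseBlockScalesMethod:

    def load_weights(self, module, weights: Dict) -> None:
        # The real base class copies weights and scales into module parameters.
        module.quant_scales = (weights["w3_w1_scale"], weights["w2_scale"])


def is_sm100() -> bool:
    return True  # placeholder for get_sm_version() == 100


def resmooth(weight, scale):
    return weight, scale  # placeholder for resmooth_to_fp8_e8m0


def to_deepgemm_layout(scale):
    return scale  # placeholder for transform_sf_into_required_layout


class DeepGemmBlockScalesMethod(BaseBlockScalesMethod):

    def load_weights(self, module, weights: Dict) -> None:
        if is_sm100():
            # Pre-step: resmooth each (weight, scale) pair before loading.
            weights["w3_w1"], weights["w3_w1_scale"] = resmooth(
                weights["w3_w1"], weights["w3_w1_scale"])
            weights["w2"], weights["w2_scale"] = resmooth(
                weights["w2"], weights["w2_scale"])
        super().load_weights(module, weights)
        if is_sm100():
            # Post-step: convert the loaded scales into the layout the
            # DeepGEMM kernels expect and refresh module.quant_scales.
            module.quant_scales = tuple(
                to_deepgemm_layout(s) for s in module.quant_scales)


class _Module:
    pass


m = _Module()
DeepGemmBlockScalesMethod().load_weights(
    m, {"w3_w1": 0, "w3_w1_scale": 1, "w2": 0, "w2_scale": 1})
assert m.quant_scales == (1, 1)
```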
9 changes: 6 additions & 3 deletions tensorrt_llm/_torch/modules/gated_mlp.py
@@ -27,7 +27,8 @@ def __init__(self,
config: Optional[ModelConfig] = None,
overridden_tp_size: Optional[int] = None,
reduce_output: bool = True,
layer_idx: Optional[int] = None):
layer_idx: Optional[int] = None,
use_cute_dsl_blockscaling_mm: bool = False):
super().__init__()
self.layer_idx = layer_idx
self.hidden_size = hidden_size
@@ -64,7 +65,8 @@ def __init__(self,
reduce_output=False,
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm)

self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
[self.hidden_size])
@@ -81,7 +83,8 @@ def __init__(self,
skip_create_weights_in_init=config.skip_create_weights_in_init,
lora=self.down_lora,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm)

# These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used,
# but never both at the same time. splitted_gate_up_lora handles gate and up separately while fused_gate_up_lora
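The gated_mlp.py change is pure plumbing: a new `use_cute_dsl_blockscaling_mm` flag on the MLP constructor that is forwarded verbatim to both projection layers. A small sketch of that forwarding pattern, with illustrative stand-ins for the real `Linear` and `GatedMLP` modules:

```python
# Sketch of flag plumbing from a parent module into its child layers.
# ToyLinear/ToyGatedMLP are illustrative stand-ins, not the actual modules.
class ToyLinear:

    def __init__(self, in_features: int, out_features: int,
                 use_cute_dsl_blockscaling_mm: bool = False):
        self.in_features = in_features
        self.out_features = out_features
        # The flag only selects a matmul backend; it does not change weights.
        self.use_cute_dsl_blockscaling_mm = use_cute_dsl_blockscaling_mm


class ToyGatedMLP:

    def __init__(self, hidden_size: int, intermediate_size: int,
                 use_cute_dsl_blockscaling_mm: bool = False):
        # Forward the flag unchanged to both projections, as the PR does for
        # the gate_up and down projections.
        self.gate_up_proj = ToyLinear(
            hidden_size, 2 * intermediate_size,
            use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm)
        self.down_proj = ToyLinear(
            intermediate_size, hidden_size,
            use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm)


mlp = ToyGatedMLP(4096, 11008, use_cute_dsl_blockscaling_mm=True)
assert mlp.gate_up_proj.use_cute_dsl_blockscaling_mm
assert mlp.down_proj.use_cute_dsl_blockscaling_mm
```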
30 changes: 20 additions & 10 deletions tensorrt_llm/_torch/modules/linear.py
@@ -583,21 +583,29 @@ def apply(self, module: Linear, input: torch.Tensor,
assert input.dtype == torch.bfloat16

if get_sm_version() == 100:
from tensorrt_llm import deep_gemm
a, a_sf = fp8_utils.per_token_quant_and_transform(input)
output = torch.empty((input.shape[0], module.weight.shape[0]),
device=input.device,
dtype=torch.bfloat16)
deep_gemm.fp8_gemm_nt((a, a_sf),
(module.weight, module.weight_scale),
output,
disable_ue8m0_cast=True)
if module.use_cute_dsl_blockscaling_mm:
# TODO (@lmin): replace with cute_dsl gemm
act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
input)
output = torch.ops.trtllm.fp8_block_scaling_gemm(
act_input_fp8, module.weight, act_input_sf,
module.weight_scale)
else:
from tensorrt_llm import deep_gemm
a, a_sf = fp8_utils.per_token_quant_and_transform(input)
output = torch.empty((input.shape[0], module.weight.shape[0]),
device=input.device,
dtype=torch.bfloat16)
deep_gemm.fp8_gemm_nt((a, a_sf),
(module.weight, module.weight_scale),
output,
disable_ue8m0_cast=True)
else:
act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
input)

output = torch.ops.trtllm.fp8_block_scaling_gemm(
act_input_fp8, module.weight, act_input_sf, module.weight_scale)

if bias is not None:
output = output + bias
return output
@@ -1488,6 +1496,7 @@ def __init__(
lora: Optional[LoraLayer] = None,
allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO,
force_dynamic_quantization: bool = False,
use_cute_dsl_blockscaling_mm: bool = False,
):
from ..distributed import AllReduce

@@ -1504,6 +1513,7 @@
self.tp_mode = tensor_parallel_mode
self.gather_output = gather_output
self.force_dynamic_quantization = force_dynamic_quantization
self.use_cute_dsl_blockscaling_mm = use_cute_dsl_blockscaling_mm

local_in_features = in_features
local_out_features = out_features
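The linear.py change adds a second SM100 code path: when `use_cute_dsl_blockscaling_mm` is set, the module keeps using `fp8_quantize_1x128` plus `fp8_block_scaling_gemm` (with a TODO to swap in a CuTe DSL GEMM later) instead of the DeepGEMM `fp8_gemm_nt` path. The dispatch reduces to a small decision table; the sketch below mirrors just the selection logic, with the backends as plain strings rather than real kernel calls:

```python
# Backend selection mirrored from the apply() diff above. The strings stand
# in for the actual kernel paths; no TensorRT-LLM ops are called here.
def select_fp8_block_scaling_backend(sm_version: int,
                                     use_cute_dsl_blockscaling_mm: bool) -> str:
    if sm_version == 100:
        if use_cute_dsl_blockscaling_mm:
            # Blackwell with the flag set: 1x128 activation quantization plus
            # block-scaling GEMM (to be replaced by a CuTe DSL GEMM per the TODO).
            return "fp8_quantize_1x128 + fp8_block_scaling_gemm"
        # Blackwell default: DeepGEMM fp8_gemm_nt with per-token quantization.
        return "deep_gemm.fp8_gemm_nt"
    # Pre-Blackwell: the existing block-scaling GEMM path, unchanged.
    return "fp8_quantize_1x128 + fp8_block_scaling_gemm"


assert select_fp8_block_scaling_backend(100, True).startswith("fp8_quantize")
assert select_fp8_block_scaling_backend(100, False) == "deep_gemm.fp8_gemm_nt"
assert select_fp8_block_scaling_backend(90, False).startswith("fp8_quantize")
```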
3 changes: 2 additions & 1 deletion tensorrt_llm/evaluate/lm_eval.py
@@ -25,6 +25,7 @@

try:
from lm_eval.api.model import TemplateLM
from lm_eval.tasks import TaskManager
except ImportError:
TemplateLM = object

@@ -147,7 +148,7 @@ def __init__(self,
self.dataset_path = dataset_path
self.num_samples = num_samples

task_manager = lm_eval.tasks.TaskManager(
task_manager = TaskManager(
include_path=f"{os.path.dirname(__file__)}/lm_eval_tasks")
with self._patch_lm_eval():
self.task_dict = lm_eval.tasks.get_task_dict(
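The lm_eval.py change moves the `TaskManager` import into the existing guarded-import block so the module still imports when `lm_eval` is absent, and then uses the imported name directly instead of the `lm_eval.tasks.TaskManager` attribute path. A generic sketch of that guarded-import pattern, using a placeholder package name rather than `lm_eval`:

```python
# Guarded optional-dependency import, as in lm_eval.py. "some_optional_pkg"
# and "Helper" are placeholders; the real code imports TemplateLM and
# TaskManager from lm_eval.
try:
    from some_optional_pkg import Helper  # optional dependency
except ImportError:
    Helper = object  # fallback keeps this module importable


def make_helper():
    # Fail with a clear message only when the optional feature is actually used.
    if Helper is object:
        raise RuntimeError("some_optional_pkg is required for this feature")
    return Helper()
```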
4 changes: 2 additions & 2 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -1021,7 +1021,7 @@ def test_fp8_block_scales(self, mtp, fp8kv, attention_dp, cuda_graph,
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@skip_no_hopper
@skip_pre_blackwell
@parametrize_with_ids("torch_compile", [False])
@parametrize_with_ids(
"fp8kv,attention_dp,cuda_graph,overlap_scheduler",
@@ -1171,7 +1171,7 @@ def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
task.evaluate(llm)

@pytest.mark.skip_less_device(4)
@skip_no_hopper
@skip_pre_blackwell
@parametrize_with_ids("torch_compile", [False])
@parametrize_with_ids(
"fp8kv,attention_dp,cuda_graph,overlap_scheduler",
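The test changes only swap the hardware gate on two FP8 block-scales tests from a Hopper-only marker to a pre-Blackwell skip, matching the SM100-specific paths added above. Markers like this are typically thin wrappers over `pytest.mark.skipif` keyed on the detected compute capability; the sketch below is a hypothetical definition, not the test suite's actual helper:

```python
# Hypothetical definition of a skip_pre_blackwell-style marker; the real
# helper in the test suite may be defined differently.
import pytest
import torch


def _sm_version() -> int:
    # Report the SM version of device 0, or 0 if no GPU is visible.
    if not torch.cuda.is_available():
        return 0
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor


skip_pre_blackwell_sketch = pytest.mark.skipif(
    _sm_version() < 100, reason="requires Blackwell (SM100) or newer")


@skip_pre_blackwell_sketch
def test_needs_blackwell():
    assert _sm_version() >= 100
```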