diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index bcdf8d4415e..e659801e001 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -13,7 +13,8 @@ from ...model_config import ModelConfig from ...utils import Fp4QuantizedTensor from .fused_moe_cutlass import CutlassFusedMoE -from .quantization import MoEWeightLoadingMode +from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm, + MoEWeightLoadingMode, UnquantizedFusedMoEMethod) from .routing import BaseMoeRoutingMethod @@ -340,6 +341,18 @@ def __init__( layer_idx=layer_idx, ) + def _get_quant_method(self): + if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant( + exclude_kv_cache=True): + if self.quant_config.layer_quant_mode.has_fp8_block_scales(): + return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm() + else: + raise ValueError( + f"Unsupported quantization mode: {self.quant_config.quant_mode}" + ) + else: + return UnquantizedFusedMoEMethod() + @nvtx_range("[DG] forward") def forward_chunk( self, diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index 1510fac470d..ca373c2ed18 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -629,45 +629,8 @@ def create_weights(self, module: torch.nn.Module): def load_weights(self, module: torch.nn.Module, weights: List[Dict], weight_loading_mode: MoEWeightLoadingMode): - - if get_sm_version() == 100: - expert_ids = set(module.initial_local_expert_ids) - if self.need_load_shared_weights(module): - expert_ids.update( - module.layer_load_balancer.get_load_expert_ids()) - for name in list(weights.keys()): - if name.endswith("weight_scale_inv"): - if int(name.split(".")[0]) not in expert_ids: - continue - weight_name = name.replace("weight_scale_inv", "weight") - logger.debug(f"Resmoothing {weight_name}") - weight = weights[weight_name][:] - scale = weights[name][:] - weights[weight_name], weights[name] = resmooth_to_fp8_e8m0( - weight, scale) super().load_weights(module, weights, weight_loading_mode) - if get_sm_version() == 100: - transfromed_w3_w1_scale = transform_sf_into_required_layout( - module.quant_scales[0], - mn=module.w3_w1_weight.shape[1], - k=module.w3_w1_weight.shape[2], - recipe=(1, 128, 128), - num_groups=module.w3_w1_weight.shape[0], - is_sfa=False) - module.w3_w1_weight_scaling_factor = nn.Parameter( - transfromed_w3_w1_scale, requires_grad=False) - transfromed_w2_scale = transform_sf_into_required_layout( - module.quant_scales[1], - mn=module.w2_weight.shape[1], - k=module.w2_weight.shape[2], - recipe=(1, 128, 128), - num_groups=module.w3_w1_weight.shape[0], - is_sfa=False) - module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale, - requires_grad=False) - self.setup_quant_scales(module) - def setup_quant_scales(self, module: torch.nn.Module): module.quant_scales = FusedMoEQuantScalesDeepSeekFP8BlockScales( fc_weight_scales=module.w3_w1_weight_scaling_factor, @@ -765,6 +728,50 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): }) +class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm( + DeepSeekFP8BlockScalesFusedMoEMethod): + + def load_weights(self, module: torch.nn.Module, weights: List[Dict], + weight_loading_mode: MoEWeightLoadingMode): + if get_sm_version() == 100: + expert_ids = 
set(module.initial_local_expert_ids) + if self.need_load_shared_weights(module): + expert_ids.update( + module.layer_load_balancer.get_load_expert_ids()) + for name in list(weights.keys()): + if name.endswith("weight_scale_inv"): + if int(name.split(".")[0]) not in expert_ids: + continue + weight_name = name.replace("weight_scale_inv", "weight") + logger.debug(f"Resmoothing {weight_name}") + weight = weights[weight_name][:] + scale = weights[name][:] + weights[weight_name], weights[name] = resmooth_to_fp8_e8m0( + weight, scale) + super().load_weights(module, weights, weight_loading_mode) + + if get_sm_version() == 100: + transfromed_w3_w1_scale = transform_sf_into_required_layout( + module.quant_scales[0], + mn=module.w3_w1_weight.shape[1], + k=module.w3_w1_weight.shape[2], + recipe=(1, 128, 128), + num_groups=module.w3_w1_weight.shape[0], + is_sfa=False) + module.w3_w1_weight_scaling_factor = nn.Parameter( + transfromed_w3_w1_scale, requires_grad=False) + transfromed_w2_scale = transform_sf_into_required_layout( + module.quant_scales[1], + mn=module.w2_weight.shape[1], + k=module.w2_weight.shape[2], + recipe=(1, 128, 128), + num_groups=module.w3_w1_weight.shape[0], + is_sfa=False) + module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale, + requires_grad=False) + self.setup_quant_scales(module) + + class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase): def create_weights(self, module: torch.nn.Module): diff --git a/tensorrt_llm/_torch/modules/gated_mlp.py b/tensorrt_llm/_torch/modules/gated_mlp.py index 8b3e314a9ec..d7c20fe8f04 100644 --- a/tensorrt_llm/_torch/modules/gated_mlp.py +++ b/tensorrt_llm/_torch/modules/gated_mlp.py @@ -27,7 +27,8 @@ def __init__(self, config: Optional[ModelConfig] = None, overridden_tp_size: Optional[int] = None, reduce_output: bool = True, - layer_idx: Optional[int] = None): + layer_idx: Optional[int] = None, + use_cute_dsl_blockscaling_mm: bool = False): super().__init__() self.layer_idx = layer_idx self.hidden_size = hidden_size @@ -64,7 +65,8 @@ def __init__(self, reduce_output=False, skip_create_weights_in_init=config.skip_create_weights_in_init, allreduce_strategy=config.allreduce_strategy, - force_dynamic_quantization=config.force_dynamic_quantization) + force_dynamic_quantization=config.force_dynamic_quantization, + use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm) self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H], [self.hidden_size]) @@ -81,7 +83,8 @@ def __init__(self, skip_create_weights_in_init=config.skip_create_weights_in_init, lora=self.down_lora, allreduce_strategy=config.allreduce_strategy, - force_dynamic_quantization=config.force_dynamic_quantization) + force_dynamic_quantization=config.force_dynamic_quantization, + use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm) # These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used, # but never both at the same time. 
splitted_gate_up_lora handles gate and up separately while fused_gate_up_lora diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index a6cb25867ff..3bae20297a8 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -583,21 +583,29 @@ def apply(self, module: Linear, input: torch.Tensor, assert input.dtype == torch.bfloat16 if get_sm_version() == 100: - from tensorrt_llm import deep_gemm - a, a_sf = fp8_utils.per_token_quant_and_transform(input) - output = torch.empty((input.shape[0], module.weight.shape[0]), - device=input.device, - dtype=torch.bfloat16) - deep_gemm.fp8_gemm_nt((a, a_sf), - (module.weight, module.weight_scale), - output, - disable_ue8m0_cast=True) + if module.use_cute_dsl_blockscaling_mm: + # TODO (@lmin): replace with cute_dsl gemm + act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128( + input) + output = torch.ops.trtllm.fp8_block_scaling_gemm( + act_input_fp8, module.weight, act_input_sf, + module.weight_scale) + else: + from tensorrt_llm import deep_gemm + a, a_sf = fp8_utils.per_token_quant_and_transform(input) + output = torch.empty((input.shape[0], module.weight.shape[0]), + device=input.device, + dtype=torch.bfloat16) + deep_gemm.fp8_gemm_nt((a, a_sf), + (module.weight, module.weight_scale), + output, + disable_ue8m0_cast=True) else: act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128( input) - output = torch.ops.trtllm.fp8_block_scaling_gemm( act_input_fp8, module.weight, act_input_sf, module.weight_scale) + if bias is not None: output = output + bias return output @@ -1488,6 +1496,7 @@ def __init__( lora: Optional[LoraLayer] = None, allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO, force_dynamic_quantization: bool = False, + use_cute_dsl_blockscaling_mm: bool = False, ): from ..distributed import AllReduce @@ -1504,6 +1513,7 @@ def __init__( self.tp_mode = tensor_parallel_mode self.gather_output = gather_output self.force_dynamic_quantization = force_dynamic_quantization + self.use_cute_dsl_blockscaling_mm = use_cute_dsl_blockscaling_mm local_in_features = in_features local_out_features = out_features diff --git a/tensorrt_llm/evaluate/lm_eval.py b/tensorrt_llm/evaluate/lm_eval.py index bdddbcbb736..6a24e07f79a 100644 --- a/tensorrt_llm/evaluate/lm_eval.py +++ b/tensorrt_llm/evaluate/lm_eval.py @@ -25,6 +25,7 @@ try: from lm_eval.api.model import TemplateLM + from lm_eval.tasks import TaskManager except ImportError: TemplateLM = object @@ -147,7 +148,7 @@ def __init__(self, self.dataset_path = dataset_path self.num_samples = num_samples - task_manager = lm_eval.tasks.TaskManager( + task_manager = TaskManager( include_path=f"{os.path.dirname(__file__)}/lm_eval_tasks") with self._patch_lm_eval(): self.task_dict = lm_eval.tasks.get_task_dict( diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index ada8352b9ad..16a76dd5ff0 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -1021,7 +1021,7 @@ def test_fp8_block_scales(self, mtp, fp8kv, attention_dp, cuda_graph, task = GSM8K(self.MODEL_NAME) task.evaluate(llm) - @skip_no_hopper + @skip_pre_blackwell @parametrize_with_ids("torch_compile", [False]) @parametrize_with_ids( "fp8kv,attention_dp,cuda_graph,overlap_scheduler", @@ -1171,7 +1171,7 @@ def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn, task.evaluate(llm) 
@pytest.mark.skip_less_device(4) - @skip_no_hopper + @skip_pre_blackwell @parametrize_with_ids("torch_compile", [False]) @parametrize_with_ids( "fp8kv,attention_dp,cuda_graph,overlap_scheduler", diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py index 5b4c74dd97c..b56caa264a8 100644 --- a/tests/unittest/_torch/modules/test_fused_moe.py +++ b/tests/unittest/_torch/modules/test_fused_moe.py @@ -628,7 +628,7 @@ def grouped_gemm(a: torch.Tensor, b: torch.Tensor, a_sf: torch.Tensor, torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1) -@skip_non_hopper_unittest +@skip_pre_blackwell @pytest.mark.parametrize( "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls, WeightLoadingMode", product( @@ -640,13 +640,13 @@ def grouped_gemm(a: torch.Tensor, b: torch.Tensor, a_sf: torch.Tensor, [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.FUSED_GATE_UP_PROJ], ), ) -def test_fused_moe_fp8_blockwise(dtype, - num_experts, - seq_len, - hidden_size, - RoutingMethodCls, - WeightLoadingMode, - mapping=None): +def test_fused_moe_fp8_blockwise_cute_dsl(dtype, + num_experts, + seq_len, + hidden_size, + RoutingMethodCls, + WeightLoadingMode, + mapping=None): SEQ_LEN = seq_len HIDDEN_SIZE = hidden_size INTERMEDIATE_SIZE = 1536 @@ -739,7 +739,128 @@ def test_fused_moe_fp8_blockwise(dtype, fused_moe.cuda() fused_moe.load_weights([weights]) - fused_moe_origin = CutlassFusedMoE( + ref_fused_moe = RefGatedMLPFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + model_config=ModelConfig(quant_config=quant_config), + # Note: use deepgemm mm will cause accuracy error, so we use trtllmgen mm here + use_cute_dsl_blockscaling_mm=True, + ) + ref_fused_moe.load_weights([weights]) + ref_fused_moe.cuda() + + with torch.inference_mode(): + output = fused_moe.forward(x, router_logits) + ref_output = ref_fused_moe.forward(x, router_logits) + + # compare + torch.cuda.synchronize() + torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1) + return True + + +@skip_non_hopper_unittest +@pytest.mark.parametrize( + "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls, WeightLoadingMode", + product( + [torch.bfloat16], + [72], + [128, 256, 384, 512, 1024, 2048, 4096, 8192], + [2560], + [DefaultMoeRoutingMethod], + [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.FUSED_GATE_UP_PROJ], + ), +) +def test_fused_moe_fp8_blockwise_cutlass(dtype, + num_experts, + seq_len, + hidden_size, + RoutingMethodCls, + WeightLoadingMode, + mapping=None): + SEQ_LEN = seq_len + HIDDEN_SIZE = hidden_size + INTERMEDIATE_SIZE = 1536 + NUM_EXPERTS = num_experts + TOP_K = 6 + + routing_method = RoutingMethodCls(top_k=TOP_K) + + mapping = mapping or Mapping() + mapping.rank = mpi_rank() + torch.cuda.set_device(mapping.rank) + torch.manual_seed(0) + torch.cuda.manual_seed(0) + + x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype, device="cuda") + # Note: we use some special values init x and weight, otherwise the test will false positive failed. 
+ set_tensor_value_2(x, SEQ_LEN, HIDDEN_SIZE) + + x = x.cuda() + router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), + dtype=dtype, + device="cuda") + + weights = {} + + if WeightLoadingMode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: + weights['gate_up_proj'] = {} + weights['down_proj'] = {} + weights['gate_up_proj_weight_scale'] = {} + weights['down_proj_weight_scale'] = {} + + for expert_id in range(NUM_EXPERTS): + w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), + dtype=dtype, + device="cuda") + w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE), + dtype=dtype, + device="cuda") + w3_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), + dtype=dtype, + device="cuda") + set_tensor_value_3(w1_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE) + set_tensor_value_4(w2_weight, HIDDEN_SIZE, INTERMEDIATE_SIZE) + set_tensor_value_3(w3_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE) + + w1_weight_fp8, w1_weight_scale = per_block_cast_to_fp8(w1_weight) + w1_weight_fp8 = w1_weight_fp8.view(torch.float8_e4m3fn).cuda() + + w2_weight_fp8, w2_weight_scale = per_block_cast_to_fp8(w2_weight) + w2_weight_fp8 = w2_weight_fp8.view(torch.float8_e4m3fn).cuda() + + w3_weight_fp8, w3_weight_scale = per_block_cast_to_fp8(w3_weight) + w3_weight_fp8 = w3_weight_fp8.view(torch.float8_e4m3fn).cuda() + + weights[f"{expert_id}.w1.weight"] = w1_weight_fp8 + weights[f"{expert_id}.w2.weight"] = w2_weight_fp8 + weights[f"{expert_id}.w3.weight"] = w3_weight_fp8 + weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale + weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale + weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale + + if WeightLoadingMode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: + weights['gate_up_proj'][expert_id] = torch.cat( + [w3_weight_fp8, w1_weight_fp8], + dim=-2).transpose(0, 1).contiguous() + weights['down_proj'][expert_id] = w2_weight_fp8.transpose( + 0, 1).contiguous() + weights['gate_up_proj_weight_scale'][expert_id] = torch.cat( + [w3_weight_scale, w1_weight_scale], + dim=-2).transpose(0, 1).contiguous() + weights['down_proj_weight_scale'][ + expert_id] = w2_weight_scale.transpose(0, 1).contiguous() + elif WeightLoadingMode == MoEWeightLoadingMode.VANILLA: + weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale + weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale + weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale + + quant_config = QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES) + + fused_moe = CutlassFusedMoE( num_experts=NUM_EXPERTS, routing_method=routing_method, hidden_size=HIDDEN_SIZE, @@ -749,8 +870,8 @@ def test_fused_moe_fp8_blockwise(dtype, model_config=ModelConfig(quant_config=quant_config, mapping=mapping), weight_loading_mode=WeightLoadingMode, ) - fused_moe_origin.cuda() - fused_moe_origin.load_weights([weights]) + fused_moe.cuda() + fused_moe.load_weights([weights]) ref_fused_moe = RefGatedMLPFusedMoE( num_experts=NUM_EXPERTS, @@ -765,13 +886,10 @@ def test_fused_moe_fp8_blockwise(dtype, with torch.inference_mode(): output = fused_moe.forward(x, router_logits) - output_origin = fused_moe_origin.forward(x, router_logits) ref_output = ref_fused_moe.forward(x, router_logits) # compare torch.cuda.synchronize() - torch.testing.assert_close(output_origin, output, rtol=1e-2, atol=0.1) - torch.testing.assert_close(output_origin, ref_output, rtol=1e-2, atol=0.1) torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1) return True @@ -781,17 +899,55 @@ def test_fused_moe_fp8_blockwise(dtype, reason="needs 4 GPUs to run this test") 
@pytest.mark.parametrize("ep_size", [1, 2, 4]) @pytest.mark.parametrize("routing_method", [DefaultMoeRoutingMethod]) -def test_fused_moe_fp8_blockwise_multi_gpu(ep_size, routing_method): +@pytest.mark.parametrize( + "weight_loading_mode", + [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.FUSED_GATE_UP_PROJ]) +def test_fused_moe_fp8_blockwise_cutlass_multi_gpu(ep_size, routing_method, + weight_loading_mode): + world_size = 4 + with MPIPoolExecutor(max_workers=world_size) as executor: + results = executor.map( + test_fused_moe_fp8_blockwise_cutlass, + *zip(*[( + torch.bfloat16, + 72, + 384, + 384, + routing_method, + weight_loading_mode, + Mapping( + world_size=world_size, + tp_size=world_size, + moe_ep_size=ep_size, + moe_tp_size=world_size // ep_size, + ), + )] * world_size), + ) + for r in results: + assert r is True + + +@skip_pre_blackwell +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="needs 4 GPUs to run this test") +@pytest.mark.parametrize("ep_size", [1, 2, 4]) +@pytest.mark.parametrize("routing_method", [DefaultMoeRoutingMethod]) +@pytest.mark.parametrize( + "weight_loading_mode", + [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.FUSED_GATE_UP_PROJ]) +def test_fused_moe_fp8_blockwise_cute_dsl_multi_gpu(ep_size, routing_method, + weight_loading_mode): world_size = 4 with MPIPoolExecutor(max_workers=world_size) as executor: results = executor.map( - test_fused_moe_fp8_blockwise, + test_fused_moe_fp8_blockwise_cute_dsl, *zip(*[( torch.bfloat16, 72, 384, 384, routing_method, + weight_loading_mode, Mapping( world_size=world_size, tp_size=world_size, @@ -1453,6 +1609,7 @@ def __init__(self, intermediate_size: int, dtype: Optional[torch.dtype] = None, model_config: ModelConfig = ModelConfig(), + use_cute_dsl_blockscaling_mm: bool = False, bias=False): super().__init__() self.num_experts = num_experts @@ -1471,6 +1628,7 @@ def __init__(self, bias=bias, dtype=self.dtype, config=model_config, + use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm, ) for _ in range(self.num_experts) ])