diff --git a/cpp/tensorrt_llm/thop/moeOp.cpp b/cpp/tensorrt_llm/thop/moeOp.cpp
index 299e302ec53..328cce3d014 100644
--- a/cpp/tensorrt_llm/thop/moeOp.cpp
+++ b/cpp/tensorrt_llm/thop/moeOp.cpp
@@ -94,15 +94,40 @@ class FusedMoeRunner : public torch::CustomClassHolder
         }
     };
 
+    template <typename TypeAct>
+    std::unique_ptr<kernels::CutlassMoeFCRunnerInterface> create_weight_quant_runner()
+    {
+        if (isInt8Quant())
+        {
+            return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, uint8_t>>();
+        }
+        else if (isInt4Quant())
+        {
+#ifdef ENABLE_FP8
+            if (mUseW4GroupScaling)
+            {
+                return std::make_unique<
+                    kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, TypeAct, TypeAct>>();
+            }
+#endif
+            return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, cutlass::uint4b_t>>();
+        }
+        else
+        {
+            C10_THROW_ERROR_FORMATTED(Error, "Unsupported weight quantization type");
+        }
+    }
+
     FusedMoeRunner(c10::ScalarType activation_dtype, c10::ScalarType weight_dtype, c10::ScalarType output_dtype,
-        bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_mxfp8_act_scaling,
-        bool use_fused_finalize)
+        bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_int8_woq_per_channel,
+        bool use_mxfp8_act_scaling, bool use_fused_finalize)
     {
         mActivationDtype = activation_dtype;
         mWeightDtype = weight_dtype;
         mOutputDtype = output_dtype;
         mUseDeepSeekFP8BlockScaling = use_deepseek_fp8_block_scale;
         mUseW4GroupScaling = use_w4_group_scaling;
+        mUseINT8WoqPerChannel = use_int8_woq_per_channel;
         mUseMxfp8ActScaling = use_mxfp8_act_scaling;
         mUseFusedFinalize = use_fused_finalize;
         mInnerDimMultiplier = 1;
@@ -137,7 +162,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
             mInnerDimMultiplier = 16; // 16 FP4 -> 1 LONG
             mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp4_e2m1>(mOutputDtype);
         }
-
         if (isNvfp4Quant())
         {
             mInnerDimMultiplier = 16; // 16 FP4 -> 1 LONG
@@ -152,7 +176,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
             default: mKernelRunner = switch_output_type<__nv_fp4_e2m1, __nv_fp4_e2m1, false>(mOutputDtype);
             }
         }
-
         if (isWFP4A16Quant())
         {
             mInnerDimMultiplier = 2;
@@ -167,45 +190,19 @@ class FusedMoeRunner : public torch::CustomClassHolder
             }
 #endif
         }
-
 #endif
-        if (isInt4Quant())
+        if (isIntWeightOnlyQuant())
         {
-            mInnerDimMultiplier = 2; // 2 INT4 -> 1 INT8
-            if (mActivationDtype == c10::ScalarType::Half)
+            if (isInt4Quant())
             {
-#ifdef ENABLE_FP8
-                if (mUseW4GroupScaling)
-                {
-                    mKernelRunner
-                        = std::make_unique<kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>>();
-                }
-                else
-                {
-                    mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<half, cutlass::uint4b_t>>();
-                }
-#else
-                mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<half, cutlass::uint4b_t>>();
-#endif
+                mInnerDimMultiplier = 2; // 2 INT4 -> 1 INT8
             }
-#ifdef ENABLE_BF16
-            else if (mActivationDtype == c10::ScalarType::BFloat16)
+            switch (mActivationDtype)
             {
-#ifdef ENABLE_FP8
-                if (mUseW4GroupScaling)
-                {
-                    mKernelRunner = std::make_unique<
-                        kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>>();
-                }
-                else
-                {
-                    mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>>();
-                }
-#else
-                mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>>();
-#endif
+            case c10::ScalarType::Half: mKernelRunner = create_weight_quant_runner<half>(); break;
+            case c10::ScalarType::BFloat16: mKernelRunner = create_weight_quant_runner<__nv_bfloat16>(); break;
+            default: C10_THROW_ERROR_FORMATTED(Error, "Unsupported activation type for int-type weight");
             }
-#endif
         }
         if (!mKernelRunner)
         {
@@ -310,13 +307,31 @@ class FusedMoeRunner : public torch::CustomClassHolder
         }
         TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc2_expert_weights.sizes()[0],
             "fc1_expert_weights and fc2_expert_weights must have the same number of experts.");
-        TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
"fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size."); + + if (mUseINT8WoqPerChannel) + { + // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights: + // [num_experts, inter_size, hidden_size] + TORCH_CHECK(fc1_expert_weights.sizes()[2] == fc2_expert_weights.sizes()[1] * mInnerDimMultiplier * 2, + "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size."); + } + else + { + TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2, + "fc1_expert_weights inter size must be fc2_expert_weights inter size."); + } int experts_per_token = token_selected_experts.sizes()[1]; int64_t num_rows = input.sizes()[0]; int64_t hidden_size = fc2_expert_weights.sizes()[1]; int64_t inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier; + if (mUseINT8WoqPerChannel) + { + // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights: + // [num_experts, inter_size, hidden_size] + hidden_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier; + inter_size = fc2_expert_weights.sizes()[1]; + } if (isWMxfp4AMxfp8Quant() || isWMxfp4AFp8Quant()) { @@ -593,8 +608,15 @@ class FusedMoeRunner : public torch::CustomClassHolder } int64_t const num_rows = input.sizes()[0]; - int64_t const hidden_size = fc2_expert_weights.sizes()[1]; - int64_t const inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier; + int64_t hidden_size = fc2_expert_weights.sizes()[1]; + int64_t inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier; + if (mUseINT8WoqPerChannel) + { + // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights: + // [num_experts, inter_size, hidden_size] + hidden_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier; + inter_size = fc2_expert_weights.sizes()[1]; + } int64_t const group_size_ = isInt4Quant() ? 
            TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size : -1;
         int64_t const group_size = isWFP4A16Quant()
@@ -677,6 +699,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
 
     bool mUseDeepSeekFP8BlockScaling = false;
     bool mUseW4GroupScaling = false;
+    bool mUseINT8WoqPerChannel = false;
     bool mUseMxfp8ActScaling = false;
     bool mUseFusedFinalize = true;
 
@@ -891,7 +914,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
             TORCH_CHECK(false, "MXFP8 x MXFP4 quantization is not supported in OSS Cutlass Moe Gemm");
 #endif
         }
-
         else if (isNvfp4Quant())
         {
             TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for nvfp4 quantization");
@@ -966,8 +988,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         }
         else if (isWFP4A16Quant())
         {
-            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for W4 quantization");
-            TORCH_CHECK(quant_scales.value().size() == 2, "Expecting 8 quant scales for W4A16 quantization");
+            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for weight only quantization");
+            TORCH_CHECK(quant_scales.value().size() == 2, "Expecting 2 quant scales for W4A16 quantization");
 
             auto& fc1_weight_scales = quant_scales.value()[0];
             auto& fc2_weight_scales = quant_scales.value()[1];
@@ -976,28 +998,45 @@ class FusedMoeRunner : public torch::CustomClassHolder
                 static_cast(fc2_weight_scales.data_ptr()), nullptr, nullptr, nullptr, nullptr, nullptr, nullptr);
         }
-        else if (isInt4Quant())
+        else if (isIntWeightOnlyQuant())
        {
-            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for W4 quantization");
-            TORCH_CHECK(quant_scales.value().size() == 8, "Expecting 8 quant scales for W4A8 quantization");
-
-            auto& fc1_weight_scales = quant_scales.value()[0];
-            auto& fc2_weight_scales = quant_scales.value()[1];
-            auto& fc1_act_scales = quant_scales.value()[2];
-            auto& fc2_act_scales = quant_scales.value()[3];
-            auto& fc1_weight_zeros = quant_scales.value()[4];
-            auto& fc2_weight_zeros = quant_scales.value()[5];
-            auto& fc1_alpha = quant_scales.value()[6];
-            auto& fc2_alpha = quant_scales.value()[7];
-            int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size;
-            return kernels::QuantParams::GroupWise(group_size, static_cast(fc1_weight_scales.data_ptr()),
-                static_cast(fc2_weight_scales.data_ptr()),
-                static_cast(fc1_act_scales.numel() > 0 ? fc1_act_scales.data_ptr() : nullptr),
-                static_cast(fc2_act_scales.numel() > 0 ? fc2_act_scales.data_ptr() : nullptr),
-                static_cast(fc1_weight_zeros.numel() > 0 ? fc1_weight_zeros.data_ptr() : nullptr),
-                static_cast(fc2_weight_zeros.numel() > 0 ? fc2_weight_zeros.data_ptr() : nullptr),
-                static_cast(fc1_alpha.numel() > 0 ? fc1_alpha.data_ptr() : nullptr),
-                static_cast(fc2_alpha.numel() > 0 ?
fc2_alpha.data_ptr() : nullptr));
+            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for weight only quantization");
+            if (mUseINT8WoqPerChannel)
+            {
+                TORCH_CHECK(
+                    quant_scales.value().size() == 2, "Expecting 2 quant scales for INT8 weight only quantization");
+                auto& fc1_weight_scales = quant_scales.value()[0];
+                auto& fc2_weight_scales = quant_scales.value()[1];
+                return kernels::QuantParams::Int(static_cast(fc1_weight_scales.data_ptr()),
+                    static_cast(fc2_weight_scales.data_ptr()));
+            }
+            else if (isInt4Quant() && mUseW4GroupScaling)
+            {
+                TORCH_CHECK(quant_scales.value().size() == 8, "Expecting 8 quant scales for W4A8 quantization");
+
+                auto& fc1_weight_scales = quant_scales.value()[0];
+                auto& fc2_weight_scales = quant_scales.value()[1];
+                auto& fc1_act_scales = quant_scales.value()[2];
+                auto& fc2_act_scales = quant_scales.value()[3];
+                auto& fc1_weight_zeros = quant_scales.value()[4];
+                auto& fc2_weight_zeros = quant_scales.value()[5];
+                auto& fc1_alpha = quant_scales.value()[6];
+                auto& fc2_alpha = quant_scales.value()[7];
+                int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size;
+                return kernels::QuantParams::GroupWise(group_size,
+                    static_cast(fc1_weight_scales.data_ptr()),
+                    static_cast(fc2_weight_scales.data_ptr()),
+                    static_cast(fc1_act_scales.numel() > 0 ? fc1_act_scales.data_ptr() : nullptr),
+                    static_cast(fc2_act_scales.numel() > 0 ? fc2_act_scales.data_ptr() : nullptr),
+                    static_cast(fc1_weight_zeros.numel() > 0 ? fc1_weight_zeros.data_ptr() : nullptr),
+                    static_cast(fc2_weight_zeros.numel() > 0 ? fc2_weight_zeros.data_ptr() : nullptr),
+                    static_cast(fc1_alpha.numel() > 0 ? fc1_alpha.data_ptr() : nullptr),
+                    static_cast(fc2_alpha.numel() > 0 ? fc2_alpha.data_ptr() : nullptr));
+            }
+            else
+            {
+                TORCH_CHECK(false, "Unsupported weight only quantization");
+            }
         }
         else
         {
@@ -1022,6 +1061,11 @@ class FusedMoeRunner : public torch::CustomClassHolder
         return mUseW4GroupScaling && mWeightDtype == c10::ScalarType::Byte;
     }
 
+    bool isInt8Quant() const
+    {
+        return mWeightDtype == c10::ScalarType::Char;
+    }
+
     bool isInt4Quant() const
     {
         return mWeightDtype == c10::ScalarType::QUInt4x2;
@@ -1032,6 +1076,11 @@ class FusedMoeRunner : public torch::CustomClassHolder
         return mActivationDtype == c10::ScalarType::Float8_e4m3fn && isInt4Quant();
     }
 
+    bool isIntWeightOnlyQuant() const
+    {
+        return isInt8Quant() || isInt4Quant();
+    }
+
     bool isWMxfp4AFp8Quant() const
     {
         return mActivationDtype == c10::ScalarType::Float8_e4m3fn && mWeightDtype == c10::ScalarType::Long
@@ -1050,7 +1099,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
 TORCH_LIBRARY(trtllm, m)
 {
     m.class_<torch_ext::FusedMoeRunner>("FusedMoeRunner")
-        .def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool, bool, bool, bool>())
+        .def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool, bool, bool, bool, bool>())
         .def("run_gemm_profile", &torch_ext::FusedMoeRunner::runGemmProfile)
         .def("get_tactic_num", &torch_ext::FusedMoeRunner::getTacticNum)
         .def("run_moe", &torch_ext::FusedMoeRunner::runMoe)
diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
index c8c103557f6..a323bb4f553 100644
--- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -43,6 +43,7 @@ def __init__(
         cluster_rank: int,
         use_deepseek_fp8_block_scale: bool,
         use_w4_group_scaling: bool,
+        use_int8_woq_per_channel: bool,
         use_mxfp8_act_scaling: bool,
         min_latency_mode: bool,
         use_fused_finalize: bool,
@@ -61,20 +62,22 @@ def __init__(
         self.enable_alltoall = False
         self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
self.use_w4_group_scaling = use_w4_group_scaling + self.use_int8_woq_per_channel = use_int8_woq_per_channel self.use_mxfp8_act_scaling = use_mxfp8_act_scaling self.min_latency_mode = min_latency_mode self.use_fused_finalize = use_fused_finalize instance_key = (x_dtype, weight_dtype, output_dtype, use_deepseek_fp8_block_scale, use_w4_group_scaling, - use_mxfp8_act_scaling) + use_int8_woq_per_channel, use_mxfp8_act_scaling) if instance_key not in MoERunner.runner_dict: MoERunner.runner_dict[ instance_key] = torch.classes.trtllm.FusedMoeRunner( x_dtype, weight_dtype, output_dtype, use_deepseek_fp8_block_scale, use_w4_group_scaling, - use_mxfp8_act_scaling, use_fused_finalize) + use_int8_woq_per_channel, use_mxfp8_act_scaling, + use_fused_finalize) self.fused_moe_runner = MoERunner.runner_dict[instance_key] def get_valid_tactics( @@ -138,6 +141,7 @@ def fused_moe( enable_alltoall: bool = False, use_deepseek_fp8_block_scale: bool = False, use_w4_group_scaling: bool = False, + use_int8_woq_per_channel: bool = False, use_mxfp8_act_scaling: bool = False, min_latency_mode: bool = False, use_fused_finalize: bool = True, @@ -174,6 +178,7 @@ def fused_moe( cluster_rank=cluster_rank, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, use_w4_group_scaling=use_w4_group_scaling, + use_int8_woq_per_channel=use_int8_woq_per_channel, use_mxfp8_act_scaling=use_mxfp8_act_scaling, min_latency_mode=min_latency_mode, use_fused_finalize=use_fused_finalize, @@ -257,13 +262,19 @@ def _( enable_alltoall: bool = False, use_deepseek_fp8_block_scale: bool = False, use_w4_group_scaling: bool = False, + use_int8_woq_per_channel: bool = False, use_mxfp8_act_scaling: bool = False, min_latency_mode: bool = False, use_fused_finalize: bool = True, tune_max_num_tokens: int = 8192, ): seq_len = input.shape[0] - hidden_size = fc2_expert_weights.shape[1] + if use_int8_woq_per_channel: + # Note: The weight shape for INT8 weight only quantization is different, i.e., + # fc2_expert_weights: [num_experts, inter_size, hidden_size] + hidden_size = fc2_expert_weights.shape[2] + else: + hidden_size = fc2_expert_weights.shape[1] if min_latency_mode: num_experts_on_rank = fc2_expert_weights.shape[0] diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py index 00a1c494d2a..34bb61a7ab0 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py @@ -17,8 +17,9 @@ from .quantization import ( DeepSeekFP8BlockScalesFusedMoEMethod, FP8QDQFusedMoEMethod, MoEWeightLoadingMode, NVFP4CutlassFusedMoEMethod, UnquantizedFusedMoEMethod, - W4A8MXFP4FP8CutlassFusedMoEMethod, W4A8MXFP4MXFP8CutlassFusedMoEMethod, - WFP4A16FusedMoEMethod, WInt4AFP8FusedMoEMethod) + INT8WoqPerChannelFusedMoEMethod, W4A8MXFP4FP8CutlassFusedMoEMethod, + W4A8MXFP4MXFP8CutlassFusedMoEMethod, WFP4A16FusedMoEMethod, + WInt4AFP8FusedMoEMethod) # isort: on from .routing import BaseMoeRoutingMethod @@ -167,8 +168,7 @@ def _check_configs(self): if not (self.quant_config.quant_mode.has_nvfp4() | self.quant_config.quant_mode.has_fp8_block_scales() | self.quant_config.quant_mode.has_fp8_qdq() - | self.quant_config.quant_mode. 
- is_int4_weight_only_per_group() + | self.quant_config.quant_mode.is_weight_only() | self.quant_config.quant_mode.has_w4a8_mxfp4_fp8() | self.quant_config.quant_mode.has_w4a16_mxfp4() | self.quant_config.quant_mode.has_w4a8_mxfp4_mxfp8()): @@ -182,6 +182,11 @@ def has_w4afp8(self): return self.quant_config and self.quant_config.quant_mode.is_int4_weight_only_per_group( ) + @property + def has_int8_woq_per_channel(self): + return self.quant_config.layer_quant_mode.is_int8_weight_only( + ) and not self.quant_config.layer_quant_mode.has_per_group_scaling() + @cached_property def enable_alltoall(self): return (self.mapping.moe_ep_size > self.routing_method.experts_per_token @@ -204,6 +209,8 @@ def _get_quant_method(self): elif self.quant_config.layer_quant_mode.is_int4_weight_only_per_group( ): return WInt4AFP8FusedMoEMethod() + elif self.has_int8_woq_per_channel: + return INT8WoqPerChannelFusedMoEMethod() elif self.quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8(): return W4A8MXFP4FP8CutlassFusedMoEMethod() elif self.quant_config.layer_quant_mode.has_w4a16_mxfp4(): @@ -261,6 +268,7 @@ def forward_chunk( # quantize inputs use_deepseek_fp8_block_scale = False use_w4_group_scaling = False + use_int8_woq_per_channel = False use_mxfp8_act_scaling = False weight_dtype = self.w3_w1_weight.dtype x_sf = None @@ -279,9 +287,10 @@ def forward_chunk( pad_size = self.hidden_size - x.shape[1] original_hidden_size = x.shape[1] x = torch.nn.functional.pad(x, (0, pad_size)) - use_w4_group_scaling = True weight_dtype = torch.uint8 + elif self.has_int8_woq_per_channel: + use_int8_woq_per_channel = True elif self.has_nvfp4: if run_post_quant_allgather or self.enable_alltoall: if isinstance(x, Fp4QuantizedTensor): @@ -421,6 +430,7 @@ def forward_chunk( enable_alltoall=self.enable_alltoall, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, use_w4_group_scaling=use_w4_group_scaling, + use_int8_woq_per_channel=use_int8_woq_per_channel, use_mxfp8_act_scaling=use_mxfp8_act_scaling, min_latency_mode=False, use_fused_finalize=self.use_fused_finalize, diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index a12aa200dde..734a7240e4e 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -8,6 +8,8 @@ import tensorrt_llm.logger as trtllm_logger from tensorrt_llm._utils import get_sm_version +from tensorrt_llm.quantization.functional import \ + preprocess_weights_for_mixed_gemm from tensorrt_llm.quantization.utils.fp4_utils import ( float4_sf_dtype, get_reorder_rows_for_gated_act_gemm_row_indices, get_shuffle_matrix_a_row_indices, get_shuffle_matrix_sf_a_row_indices) @@ -63,6 +65,11 @@ class FusedMoEQuantScalesW4A8(NamedTuple): alpha_2: torch.Tensor +class FusedMoEQuantScalesINT8WoqPerChannel(NamedTuple): + fc31_weight_scale: torch.Tensor + fc2_weight_scale: torch.Tensor + + class FusedMoEQuantScalesW4A16MXFP4(NamedTuple): scale_1_interleaved: torch.Tensor scale_2_interleaved: torch.Tensor @@ -336,7 +343,7 @@ def apply(self, module: torch.nn.Module, input: torch.Tensor, *args, **kwargs) -> torch.Tensor: """ Apply the quantization method to the input tensor. - This isn’t necessary for all quantization methods, but it’s useful for + This isn't necessary for all quantization methods, but it's useful for certain backends that can encapsulate the MoE forward function. 
""" raise NotImplementedError @@ -775,6 +782,146 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict], self.setup_quant_scales(module) +class INT8WoqPerChannelFusedMoEMethod(FusedMoEMethodBase): + + def create_weights(self, module: torch.nn.Module): + module.sm_version = get_sm_version() + module.sm_version = 80 if module.sm_version >= 90 else module.sm_version + module.preprocessor = preprocess_weights_for_mixed_gemm + + weight_dtype = torch.int8 + if not module.quant_config.layer_quant_mode.is_int8_weight_only(): + raise NotImplementedError( + f"Weight Only Quantization currently only supports INT8. Got: {module.quant_config.layer_quant_mode}." + ) + + # notice the weight shape for int8 weight-only is different from the original shape, + # since the quantized weights have their own layout + w3_w1_weight_shape = (module.expert_size_per_partition, + module.hidden_size, + module.intermediate_size_per_partition * 2) + w2_weight_shape = (module.expert_size_per_partition, + module.intermediate_size_per_partition, + module.hidden_size) + + fc31_weight_scale = nn.Parameter(torch.empty( + module.expert_size_per_partition, + module.intermediate_size_per_partition * 2, + dtype=module.dtype), + requires_grad=False) + module.register_parameter("fc31_weight_scale", fc31_weight_scale) + + fc2_weight_scale = nn.Parameter(torch.empty( + module.expert_size_per_partition, + module.hidden_size, + dtype=module.dtype), + requires_grad=False) + module.register_parameter("fc2_weight_scale", fc2_weight_scale) + + super().create_weights(module, weight_dtype, w3_w1_weight_shape, + w2_weight_shape) + self.setup_quant_scales(module) + + def setup_quant_scales(self, module: torch.nn.Module): + module.quant_scales = FusedMoEQuantScalesINT8WoqPerChannel( + fc31_weight_scale=module.fc31_weight_scale, + fc2_weight_scale=module.fc2_weight_scale, + ) + + def get_quant_scales(self, module: torch.nn.Module, slot_start, + slot_end) -> tuple[torch.Tensor, ...]: + assert module.smart_router + return FusedMoEQuantScalesINT8WoqPerChannel( + fc31_weight_scale=module.fc31_weight_scale.narrow( + 0, slot_start, slot_end - slot_start), + fc2_weight_scale=module.fc2_weight_scale.narrow( + 0, slot_start, slot_end - slot_start), + ) + + def load_expert_w3_w1_weight(self, module: torch.nn.Module, + w1_weight: torch.Tensor, + w3_weight: torch.Tensor, + dst_w3_w1_weight: torch.Tensor): + """ + Load w1 and w3 weights for each expert. + """ + w1_weight_shard = load_weight_shard(w1_weight, module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN) + w3_weight_shard = load_weight_shard(w3_weight, module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN) + w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard], dim=0) + + weight_dtype = torch.int8 + + assert module.dtype in [torch.float16, torch.bfloat16], \ + f"activation dtype should be float16 or bfloat16, got {module.dtype}" + if not module.quant_config.layer_quant_mode.is_int8_weight_only(): + raise NotImplementedError( + f"weight dtype should be INT8. Got: {module.quant_config.layer_quant_mode}." + ) + # preprocess the weights for mixed gemm + w31_weight_shard = module.preprocessor(w31_weight_shard.T.contiguous(), + weight_dtype, module.dtype, + module.sm_version).contiguous() + dst_w3_w1_weight.copy_(w31_weight_shard.view(dst_w3_w1_weight.dtype), + non_blocking=True) + + def load_expert_w2_weight(self, module: torch.nn.Module, + w2_weight: torch.Tensor, + dst_w2_weight: torch.Tensor): + """ + Load w2 weight for each expert. 
+ """ + w2_weight_shard = load_weight_shard(w2_weight, module.tp_size, + module.tp_rank, + TensorParallelMode.ROW) + + weight_dtype = torch.int8 + if not module.quant_config.layer_quant_mode.is_int8_weight_only(): + raise NotImplementedError( + f"Weight Only Quantization currently only supports INT8. Got: {module.quant_config.layer_quant_mode}." + ) + + # preprocess the weights for mixed gemm + w2_weight_shard = module.preprocessor(w2_weight_shard.T.contiguous(), + weight_dtype, module.dtype, + module.sm_version).contiguous() + dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype), + non_blocking=True) + + def load_quant_scales(self, module: torch.nn.Module, weights: Dict): + # fc31 scales + all_w3_scales = [ + load_weight_shard(weights[f"{expert_id}.w3.weight_scale"], + module.tp_size, module.tp_rank, + TensorParallelMode.COLUMN) + for expert_id in module.initial_local_expert_ids + ] + all_w1_scales = [ + load_weight_shard(weights[f"{expert_id}.w1.weight_scale"], + module.tp_size, module.tp_rank, + TensorParallelMode.COLUMN) + for expert_id in module.initial_local_expert_ids + ] + w3_w1_scales = torch.cat( + [torch.stack(all_w3_scales), + torch.stack(all_w1_scales)], dim=-1) + w3_w1_scales = w3_w1_scales.to(module.dtype) + module.fc31_weight_scale.data.copy_(w3_w1_scales.contiguous()) + + # fc2 scales + all_w2_scales = [ + load_weight_shard(weights[f"{expert_id}.w2.weight_scale"], + module.tp_size, module.tp_rank, + TensorParallelMode.ROW) + for expert_id in module.initial_local_expert_ids + ] + w2_scales = torch.stack(all_w2_scales).to(module.dtype) + module.fc2_weight_scale.data.copy_(w2_scales.contiguous()) + + class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase): def create_weights(self, module: torch.nn.Module): @@ -914,9 +1061,7 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module, unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8 # SM89 if module.sm_version == 89: - import tensorrt_llm.quantization.functional as trtllm_f - - preprocessor = trtllm_f.preprocess_weights_for_mixed_gemm + preprocessor = preprocess_weights_for_mixed_gemm w31_weight_shard = packer( unpacker(w31_weight_shard.cpu()).T.contiguous()).to( @@ -963,9 +1108,7 @@ def load_expert_w2_weight(self, module: torch.nn.Module, packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4 unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8 if module.sm_version == 89: - import tensorrt_llm.quantization.functional as trtllm_f - - preprocessor = trtllm_f.preprocess_weights_for_mixed_gemm + preprocessor = preprocess_weights_for_mixed_gemm w2_weight_shard = packer( unpacker(w2_weight_shard.cpu()).T.contiguous()).to( diff --git a/tensorrt_llm/layers/moe.py b/tensorrt_llm/layers/moe.py index 59d3a315e5d..3db89d52a6f 100755 --- a/tensorrt_llm/layers/moe.py +++ b/tensorrt_llm/layers/moe.py @@ -777,7 +777,7 @@ def __init__(self, self.use_int8_weight = use_int8_weight self.group_size = group_size - if self.use_int8_weight: + if self.use_int8_weight and self.group_size > 0: raise NotImplementedError("INT8-GPTQ is not implemented for MoE.") self.static_routing = static_routing diff --git a/tests/unittest/_torch/helpers.py b/tests/unittest/_torch/helpers.py index 5e9f2ba1a26..4b63769735d 100644 --- a/tests/unittest/_torch/helpers.py +++ b/tests/unittest/_torch/helpers.py @@ -75,6 +75,19 @@ def calc_diff(x, y): return 1 - sim +def calc_woq_tolerence(x: torch.Tensor, weight_dtype: torch.dtype): + # align with woq_assert_near_eq function in tests/unittest/trt/quantization/_utils.py + if weight_dtype == 
torch.int8: + bits_in_type = 8 + elif weight_dtype == torch.quint4x2: + bits_in_type = 4 + quant_range_scale = 1.0 / float(1 << (bits_in_type - 1)) + max_val = torch.max(abs(x)).item() + atol = (max_val * quant_range_scale) * 1.5 # allow for rounding + + return atol + + def reference_moe_torch(x: torch.Tensor, selected_experts: torch.Tensor, final_scales: torch.Tensor, num_experts: int, weights: Dict[str, torch.Tensor]) -> torch.Tensor: diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py index 0a50f530517..147c0ab2868 100644 --- a/tests/unittest/_torch/modules/test_fused_moe.py +++ b/tests/unittest/_torch/modules/test_fused_moe.py @@ -9,7 +9,8 @@ import pytest import torch import torch.nn as nn -from _torch.helpers import (per_block_cast_to_fp8, per_block_cast_to_fp8_e8m0, +from _torch.helpers import (calc_woq_tolerence, per_block_cast_to_fp8, + per_block_cast_to_fp8_e8m0, per_token_cast_to_fp8_e8m0) from mpi4py import MPI from mpi4py.futures import MPIPoolExecutor @@ -1903,6 +1904,117 @@ def mxfp4_to_fp32(tensor, scales): check_accuracy(output, ref_output, rtol=0.6, atol=0.6, percent=0.945) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("weight_dtype", [torch.int8]) +def test_fused_moe_int8_woq_per_channel(dtype, weight_dtype): + + mapping = Mapping() + mapping.rank = mpi_rank() + + with torch.device(f'cuda:{mapping.rank}'): + SEQ_LEN = 4 + HIDDEN_SIZE = 768 + INTERMEDIATE_SIZE = 640 + NUM_EXPERTS = 3 + TOP_K = 2 + routing_method = RenormalizeMoeRoutingMethod(top_k=TOP_K) + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype, device="cuda") + + router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), + dtype=dtype, + device="cuda") + + weight_id = 1 # 1 for w8a16, 2 for w4a16 + quant_config = QuantConfig(quant_algo=QuantAlgo.W8A16) + weights = {} + for expert_id in range(NUM_EXPERTS): + w1_weight = torch.randint( + -128, + 127, (INTERMEDIATE_SIZE, HIDDEN_SIZE // weight_id), + dtype=torch.int8).cuda() + w2_weight = torch.randint( + -128, + 127, (HIDDEN_SIZE, INTERMEDIATE_SIZE // weight_id), + dtype=torch.int8).cuda() + w3_weight = torch.randint( + -128, + 127, (INTERMEDIATE_SIZE, HIDDEN_SIZE // weight_id), + dtype=torch.int8).cuda() + + w1_scale = torch.randn( + (INTERMEDIATE_SIZE), dtype=dtype, device="cuda") / HIDDEN_SIZE + w2_scale = torch.randn( + (HIDDEN_SIZE), dtype=dtype, device="cuda") / INTERMEDIATE_SIZE + w3_scale = torch.randn( + (INTERMEDIATE_SIZE), dtype=dtype, device="cuda") / HIDDEN_SIZE + + weights[f"{expert_id}.w1.weight"] = w1_weight + weights[f"{expert_id}.w2.weight"] = w2_weight + weights[f"{expert_id}.w3.weight"] = w3_weight + weights[f"{expert_id}.w1.weight_scale"] = w1_scale + weights[f"{expert_id}.w2.weight_scale"] = w2_scale + weights[f"{expert_id}.w3.weight_scale"] = w3_scale + + fused_moe = CutlassFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=False, + model_config=ModelConfig(quant_config=quant_config)) + fused_moe.load_weights([weights]) + fused_moe.cuda() + + def ref(): + results = torch.zeros_like(x) + selected_experts, final_scales = routing_method.apply(router_logits) + for e_idx in range(NUM_EXPERTS): + mask = selected_experts == e_idx + activated_tokens = mask.sum(1).bool() + act = x[activated_tokens, :] + if act.shape[0] == 0: + continue + final_scale = (final_scales * + 
mask).sum(1)[activated_tokens].unsqueeze(1) + # weights + w1 = weights[f"{e_idx}.w1.weight"].T.contiguous().cuda() + w2 = weights[f"{e_idx}.w2.weight"].T.contiguous().cuda() + w3 = weights[f"{e_idx}.w3.weight"].T.contiguous().cuda() + w3_w1 = torch.cat([w3, w1], dim=-1) + # scales + s1 = weights[f"{e_idx}.w1.weight_scale"].cuda() + s2 = weights[f"{e_idx}.w2.weight_scale"].cuda() + s3 = weights[f"{e_idx}.w3.weight_scale"].cuda() + s3_s1 = torch.cat([s3, s1], dim=-1) + # calculation + w3_w1 = (w3_w1.float() * s3_s1).to(dtype) + fc1 = torch.matmul(act, w3_w1) + fc1, gate = fc1.chunk(2, dim=-1) + act = fc1 * torch.nn.functional.silu(gate) + w2 = (w2.float() * s2).to(dtype) + fc2 = torch.matmul(act, w2) + results[activated_tokens, :] += (fc2 * final_scale).to( + results.dtype) + return results + + AutoTuner.get().clear_cache() + with torch.inference_mode(), autotune(): + fused_moe.forward(x, router_logits) + + torch.cuda.synchronize() + with torch.inference_mode(): + output = fused_moe.forward(x, router_logits) + ref_output = ref() + + # compare + torch.cuda.synchronize() + atol = calc_woq_tolerence(ref_output, weight_dtype) + torch.testing.assert_close(output, ref_output, rtol=1e-7, atol=atol) + + class RefGatedMLPFusedMoE(nn.Module): def __init__(self, diff --git a/tests/unittest/_torch/thop/test_weight_only_quant_gemm.py b/tests/unittest/_torch/thop/test_weight_only_quant_gemm.py index fab60be84bc..447c807a8cc 100644 --- a/tests/unittest/_torch/thop/test_weight_only_quant_gemm.py +++ b/tests/unittest/_torch/thop/test_weight_only_quant_gemm.py @@ -15,7 +15,7 @@ import pytest import torch -from _torch.helpers import calc_diff +from _torch.helpers import calc_diff, calc_woq_tolerence def weight_only_quant_gemm_reference(a, b, b_scales): @@ -29,18 +29,6 @@ def weight_only_quant_gemm_reference(a, b, b_scales): return ref.to(dtype=a_dtype) -def woq_tolerence_calculate(output, output_ref, b_dtype): - if b_dtype == torch.int8: - bits_in_type = 8 - elif b_dtype == torch.quint4x2: - bits_in_type = 4 - quant_range_scale = 1.0 / float(1 << (bits_in_type - 1)) - max_val = torch.max(abs(output_ref)).item() - atol = (max_val * quant_range_scale) * 1.5 # allow for rounding - - return atol - - @pytest.mark.parametrize( "k, n", [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (1024, 1024)], @@ -79,5 +67,5 @@ def test_weight_only_quant_gemm(a_dtype, b_dtype, m, k, n): # check accuracy diff = calc_diff(output, output_ref) assert diff < 1e-3, f"Difference {diff} >= 1e-3" - atol = woq_tolerence_calculate(output, output_ref, b_dtype) + atol = calc_woq_tolerence(output_ref, b_dtype) torch.testing.assert_close(output_ref, output, atol=atol, rtol=1e-7) diff --git a/tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py b/tests/unittest/trt/quantization/test_moe_weight_only_quant_matmul.py similarity index 70% rename from tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py rename to tests/unittest/trt/quantization/test_moe_weight_only_quant_matmul.py index e295fa01285..d741978c2a3 100644 --- a/tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py +++ b/tests/unittest/trt/quantization/test_moe_weight_only_quant_matmul.py @@ -36,10 +36,11 @@ from . 
import _utils -class TestMoEWeightOnlyGroupWiseQuantMatmul(unittest.TestCase): +class TestMoEWeightOnlyQuantMatmul(unittest.TestCase): def setUp(self): torch.manual_seed(0) + torch.cuda.manual_seed(0) tensorrt_llm.logger.set_level('error') def create_trt_session( @@ -67,12 +68,15 @@ def create_trt_session( network = builder.create_network() dtype = str_dtype_to_trt(str_dtype) norm_mode = MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE - quant_mode = QuantMode.use_weight_only(True, True) k = act.shape[1] - n = fc2_prequant_scale.shape[ - -1] # get the original n from prequant scale because either weight or scale could be interleaved + if has_pre_quant: + n = fc2_prequant_scale.shape[-1] + else: + n = weight_scaling_factor_1.shape[-1] // 2 num_experts = weight_scaling_factor_1.shape[0] + use_int8 = True if self.quant_mode.is_int8_weight_only() else False + with tensorrt_llm.net_guard(network): trt_key = Tensor(name='input_hidden_states', shape=act.shape, @@ -91,18 +95,25 @@ def create_trt_session( hidden_act="swiglu", bias=False, dtype=dtype, - quant_mode=quant_mode, + quant_mode=self.quant_mode, pre_quant_scale=has_pre_quant, zero=has_zero, use_w4a8_awq=has_alpha, + use_int8_weight=use_int8, group_size=group_size) moe.router.weight.value = torch_to_numpy(router.cpu()) moe.fc.weight.value = torch_to_numpy(fc1_weights.cpu()) moe.proj.weight.value = torch_to_numpy(fc2_weights.cpu()) - moe.fc.weights_scaling_factor.value = torch_to_numpy( - weight_scaling_factor_1.cpu()) - moe.proj.weights_scaling_factor.value = torch_to_numpy( - weight_scaling_factor_2.cpu()) + if group_size != -1: + moe.fc.weights_scaling_factor.value = torch_to_numpy( + weight_scaling_factor_1.cpu()) + moe.proj.weights_scaling_factor.value = torch_to_numpy( + weight_scaling_factor_2.cpu()) + else: + moe.fc.per_channel_scale.value = torch_to_numpy( + weight_scaling_factor_1.cpu()) + moe.proj.per_channel_scale.value = torch_to_numpy( + weight_scaling_factor_2.cpu()) if has_pre_quant: moe.fc.prequant_scaling_factor.value = torch_to_numpy( fc1_prequant_scale.cpu()) @@ -121,8 +132,8 @@ def create_trt_session( session = create_session(builder, network, precision=trt_dtype_to_str(dtype), - int8=False, - quant_mode=quant_mode) + int8=use_int8, + quant_mode=self.quant_mode) return session def _woq_moe_groupwise_matmul(self, @@ -206,8 +217,8 @@ def _woq_moe_groupwise_matmul(self, ref_weight_1 += zero_1.repeat_interleave(group_size, dim=1) ref_weight_2 += zero_2.repeat_interleave(group_size, dim=1) activation_type = torch.float8_e4m3fn if has_alpha else activation_dtype - do_weight_interleave = get_sm_version( - ) != 90 or not has_alpha # Hopper w4a8 does not interleave weight + # Hopper w4a8 does not interleave weight + do_weight_interleave = get_sm_version() != 90 or not has_alpha cuda_q_weight_1 = preprocessor( unprocessed_weight_1.cpu(), quantized_weight_dtype, @@ -298,6 +309,97 @@ def interleave_scales(scales: torch.Tensor, interleave_dim: int): ref = results.view(*inputs.shape) _utils.woq_assert_near_eq(ref, out, 2) + def _woq_moe_matmul_per_channel(self, + m, + n, + k, + num_experts, + activation_dtype_str, + quantized_weight_dtype, + top_k=2): + + activation_dtype = tensorrt_llm._utils.str_dtype_to_torch( + activation_dtype_str) + activation = torch.randn(m, k, dtype=activation_dtype, device="cuda") + router = torch.randn((num_experts, k), + dtype=torch.float32, + device="cuda") + + num_weights_in_32_bits = 4 + + assert n % num_weights_in_32_bits == 0, f"n must be a multiple of {num_weights_in_32_bits}" + 
unprocessed_int_weight_1 = torch.randint( + -2**31, + 2**31, (num_experts, k, n * 2 // num_weights_in_32_bits), + dtype=torch.int32, + device="cuda") + unprocessed_int_weight_2 = torch.randint( + -2**31, + 2**31, (num_experts, n, k // num_weights_in_32_bits), + dtype=torch.int32, + device="cuda") + unprocessed_weight_1 = unprocessed_int_weight_1.view(torch.int8) + unprocessed_weight_2 = unprocessed_int_weight_2.view(torch.int8) + + scale_1 = torch.randn( + num_experts, 1, n * 2, dtype=activation_dtype, device="cuda") / k + scale_2 = torch.randn( + num_experts, 1, k, dtype=activation_dtype, device="cuda") / n + + ref_weight_1 = unprocessed_weight_1 * scale_1 + ref_weight_2 = unprocessed_weight_2 * scale_2 + scale_1 = scale_1.squeeze(1) + scale_2 = scale_2.squeeze(1) + + preprocessor = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm + + cuda_q_weight_1 = preprocessor(unprocessed_weight_1.cpu(), + quantized_weight_dtype, + activation_dtype).cpu() + cuda_q_weight_2 = preprocessor(unprocessed_weight_2.cpu(), + quantized_weight_dtype, + activation_dtype).cpu() + + session = self.create_trt_session(activation_dtype_str, activation, + router, None, None, cuda_q_weight_1, + cuda_q_weight_2, scale_1, scale_2, + None, None, None, None, top_k, False, + False, False, -1) + + inputs = {"input_hidden_states": activation} + outputs = run_session(session, inputs) + out = outputs['output'].float() + + # ref + inputs = activation.cuda().float() + inputs_merged = inputs.view(-1, inputs.shape[-1]) + routing = torch.matmul(inputs_merged, router.T.float()) + router_probs = torch.softmax(routing, 1, dtype=inputs.dtype) + topk = torch.topk(router_probs, top_k) + results = torch.zeros_like(inputs_merged) + for i, (scales, experts) in enumerate(zip(topk.values, topk.indices)): + scales /= sum(scales) + input = inputs_merged[i, :] + for scale, expert in zip(scales, experts): + input = inputs_merged[i, :] + fc1_qd = ref_weight_1[expert].cuda().float() + fc1 = torch.matmul(input, fc1_qd) + fc1, gate = fc1.chunk(2, dim=-1) + fc1 = fc1 * torch.nn.functional.silu(gate) + + fc2_qd = ref_weight_2[expert].cuda().float() + final = torch.matmul(fc1, fc2_qd) + results[i] += scale * final + ref = results.view(*inputs.shape) + _utils.woq_assert_near_eq(ref, out, 1) + + @parameterized.expand([(1, 14336, 4096, 8, "float16"), + (1, 14336, 4096, 8, "bfloat16")], + name_func=unittest_name_func) + def test_moe_w8a16(self, m, n, k, experts, dtype): + self.quant_mode = QuantMode.use_weight_only(False, False) + self._woq_moe_matmul_per_channel(m, n, k, experts, dtype, torch.int8) + @parameterized.expand([(1, 14336, 4096, 8, "float16", True, True), (1, 14336, 4096, 8, "float16", True, False), (1, 14336, 4096, 8, "float16", False, True), @@ -305,7 +407,9 @@ def interleave_scales(scales: torch.Tensor, interleave_dim: int): (1, 14336, 4096, 8, "bfloat16", True, False), (1, 14336, 4096, 8, "bfloat16", False, True)], name_func=unittest_name_func) - def test_moe_w4a16(self, m, n, k, experts, dtype, has_pre_quant, has_zero): + def test_moe_w4a16_groupwise(self, m, n, k, experts, dtype, has_pre_quant, + has_zero): + self.quant_mode = QuantMode.use_weight_only(True, True) self._woq_moe_groupwise_matmul(m, n, k, experts, dtype, torch.quint4x2, has_pre_quant, has_zero, False) @@ -315,8 +419,9 @@ def test_moe_w4a16(self, m, n, k, experts, dtype, has_pre_quant, has_zero): (1, 14336, 4096, 8, "bfloat16", True, True)], name_func=unittest_name_func) @skip_neither_ada_nor_hopper_unittest - def test_moe_w4a8(self, m, n, k, experts, 
dtype, has_pre_quant, has_zero): - + def test_moe_w4a8_groupwise(self, m, n, k, experts, dtype, has_pre_quant, + has_zero): + self.quant_mode = QuantMode.use_weight_only(True, True) self._woq_moe_groupwise_matmul(m, n, k, experts, dtype, torch.quint4x2, has_pre_quant, has_zero, True)
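
Reviewer note: the numerics the new W8A16 per-channel path is expected to reproduce are easiest to see outside the diff. The sketch below condenses the `ref()` closure from the new `test_fused_moe_int8_woq_per_channel` test into a standalone per-expert helper, assuming the unquantized layout used there (`w1`/`w3`: `[inter_size, hidden_size]` int8, `w2`: `[hidden_size, inter_size]` int8, one dequant scale per output channel). The name `ref_int8_woq_expert` is illustrative only and is not part of the patch.

```python
import torch
import torch.nn.functional as F


def ref_int8_woq_expert(x: torch.Tensor,
                        w1: torch.Tensor, w3: torch.Tensor, w2: torch.Tensor,
                        s1: torch.Tensor, s3: torch.Tensor, s2: torch.Tensor) -> torch.Tensor:
    """Per-expert reference for INT8 weight-only per-channel MoE (mirrors the test's ref()).

    x:  [tokens, hidden]          activation (fp16/bf16)
    w1: [inter, hidden] int8      gate projection
    w3: [inter, hidden] int8      up projection
    w2: [hidden, inter] int8      down projection
    s1, s3: [inter]; s2: [hidden] per-output-channel dequant scales
    """
    # One scale per output channel, i.e. per row of the int8 weight.
    w3_w1 = torch.cat([w3, w1], dim=0).t().float()   # [hidden, 2 * inter]
    s3_s1 = torch.cat([s3, s1], dim=-1).float()      # [2 * inter]
    fc1 = x.float() @ (w3_w1 * s3_s1)                # dequant + fused up/gate GEMM
    up, gate = fc1.chunk(2, dim=-1)
    h = up * F.silu(gate)                            # SwiGLU, as in the test
    w2_deq = w2.t().float() * s2.float()             # [inter, hidden]
    return (h @ w2_deq).to(x.dtype)
```

The fused kernel consumes the same weights after `preprocess_weights_for_mixed_gemm`, which is why `INT8WoqPerChannelFusedMoEMethod` registers `w3_w1_weight` as `[num_experts, hidden_size, 2 * intermediate_size]` and `w2_weight` as `[num_experts, intermediate_size, hidden_size]` (transposed relative to the other MoE methods); comparisons against this reference use the `calc_woq_tolerence` bound added to `helpers.py`.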