181 changes: 115 additions & 66 deletions cpp/tensorrt_llm/thop/moeOp.cpp
@@ -94,15 +94,40 @@ class FusedMoeRunner : public torch::CustomClassHolder
}
};

template <typename TypeAct>
std::unique_ptr<kernels::CutlassMoeFCRunnerInterface> create_weight_quant_runner()
{
if (isInt8Quant())
{
return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, uint8_t>>();
}
else if (isInt4Quant())
{
#ifdef ENABLE_FP8
if (mUseW4GroupScaling)
{
return std::make_unique<
kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, TypeAct, TypeAct>>();
}
#endif
return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, cutlass::uint4b_t>>();
}
else
{
C10_THROW_ERROR_FORMATTED(Error, "Unsupported weight quantization type");
}
}
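
For orientation, here is a minimal Python sketch of the dispatch this factory performs; the returned strings merely name the C++ specializations, and the dtype mapping follows the isInt8Quant/isInt4Quant predicates defined later in this file.

```python
# Sketch only: mirrors create_weight_quant_runner's dispatch; not a real API.
import torch

def runner_specialization(weight_dtype: torch.dtype,
                          use_w4_group_scaling: bool,
                          act: str = "half") -> str:
    if weight_dtype == torch.int8:       # isInt8Quant(): c10::ScalarType::Char
        return f"CutlassMoeFCRunner<{act}, uint8_t>"
    if weight_dtype == torch.quint4x2:   # isInt4Quant(): c10::ScalarType::QUInt4x2
        if use_w4_group_scaling:         # W4A8 path; requires ENABLE_FP8 in C++
            return f"CutlassMoeFCRunner<__nv_fp8_e4m3, uint4b_t, {act}, {act}>"
        return f"CutlassMoeFCRunner<{act}, uint4b_t>"
    raise ValueError("Unsupported weight quantization type")
```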

FusedMoeRunner(c10::ScalarType activation_dtype, c10::ScalarType weight_dtype, c10::ScalarType output_dtype,
bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_mxfp8_act_scaling,
bool use_fused_finalize)
bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_int8_woq_per_channel,
bool use_mxfp8_act_scaling, bool use_fused_finalize)
{
mActivationDtype = activation_dtype;
mWeightDtype = weight_dtype;
mOutputDtype = output_dtype;
mUseDeepSeekFP8BlockScaling = use_deepseek_fp8_block_scale;
mUseW4GroupScaling = use_w4_group_scaling;
mUseINT8WoqPerChannel = use_int8_woq_per_channel;
mUseMxfp8ActScaling = use_mxfp8_act_scaling;
mUseFusedFinalize = use_fused_finalize;
mInnerDimMultiplier = 1;
@@ -137,7 +162,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
mInnerDimMultiplier = 16; // 16 FP4 -> 1 LONG
mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp4_e2m1>(mOutputDtype);
}

if (isNvfp4Quant())
{
mInnerDimMultiplier = 16; // 16 FP4 -> 1 LONG
@@ -152,7 +176,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
default: mKernelRunner = switch_output_type<__nv_fp4_e2m1, __nv_fp4_e2m1, false>(mOutputDtype);
}
}

if (isWFP4A16Quant())
{
mInnerDimMultiplier = 2;
@@ -167,45 +190,19 @@
}
#endif
}

#endif
if (isInt4Quant())
if (isIntWeightOnlyQuant())
{
mInnerDimMultiplier = 2; // 2 INT4 -> 1 INT8
if (mActivationDtype == c10::ScalarType::Half)
if (isInt4Quant())
{
#ifdef ENABLE_FP8
if (mUseW4GroupScaling)
{
mKernelRunner
= std::make_unique<kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>>();
}
else
{
mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<half, cutlass::uint4b_t>>();
}
#else
mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<half, cutlass::uint4b_t>>();
#endif
mInnerDimMultiplier = 2; // 2 INT4 -> 1 INT8
}
#ifdef ENABLE_BF16
else if (mActivationDtype == c10::ScalarType::BFloat16)
switch (mActivationDtype)
{
#ifdef ENABLE_FP8
if (mUseW4GroupScaling)
{
mKernelRunner = std::make_unique<
kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>>();
}
else
{
mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>>();
}
#else
mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>>();
#endif
case c10::ScalarType::Half: mKernelRunner = create_weight_quant_runner<half>(); break;
case c10::ScalarType::BFloat16: mKernelRunner = create_weight_quant_runner<__nv_bfloat16>(); break;
default: C10_THROW_ERROR_FORMATTED(Error, "Unsupported activation type for int-type weight");
}
#endif
}
if (!mKernelRunner)
{
@@ -310,13 +307,31 @@
}
TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc2_expert_weights.sizes()[0],
"fc1_expert_weights and fc2_expert_weights must have the same number of experts.");
TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
"fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");

if (mUseINT8WoqPerChannel)
{
// Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
// [num_experts, inter_size, hidden_size]
TORCH_CHECK(fc1_expert_weights.sizes()[2] == fc2_expert_weights.sizes()[1] * mInnerDimMultiplier * 2,
"fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
}
else
{
TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
"fc1_expert_weights inter size must be fc2_expert_weights inter size.");
}

int experts_per_token = token_selected_experts.sizes()[1];
int64_t num_rows = input.sizes()[0];
int64_t hidden_size = fc2_expert_weights.sizes()[1];
int64_t inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
if (mUseINT8WoqPerChannel)
{
// Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
// [num_experts, inter_size, hidden_size]
hidden_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
inter_size = fc2_expert_weights.sizes()[1];
}
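
To make the two layouts concrete, a small sketch with illustrative sizes (not taken from the PR):

```python
import torch

num_experts, hidden_size, inter_size = 8, 256, 1024  # illustrative only

# Default layout: [num_experts, hidden_size, inter_size / packing], where the
# last dim is packed by mInnerDimMultiplier storage elements where applicable.
fc2_default = torch.empty(num_experts, hidden_size, inter_size, dtype=torch.float16)

# INT8 weight-only per-channel layout: [num_experts, inter_size, hidden_size].
fc2_int8_woq = torch.empty(num_experts, inter_size, hidden_size, dtype=torch.int8)

# Size recovery mirrors the C++ above (mInnerDimMultiplier == 1 for INT8):
assert fc2_int8_woq.shape[2] * 1 == hidden_size
assert fc2_int8_woq.shape[1] == inter_size
```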

if (isWMxfp4AMxfp8Quant() || isWMxfp4AFp8Quant())
{
@@ -593,8 +608,15 @@
}

int64_t const num_rows = input.sizes()[0];
int64_t const hidden_size = fc2_expert_weights.sizes()[1];
int64_t const inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
int64_t hidden_size = fc2_expert_weights.sizes()[1];
int64_t inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
if (mUseINT8WoqPerChannel)
{
// Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
// [num_experts, inter_size, hidden_size]
hidden_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
inter_size = fc2_expert_weights.sizes()[1];
}
int64_t const group_size_
= isInt4Quant() ? TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size : -1;
int64_t const group_size = isWFP4A16Quant()
@@ -677,6 +699,7 @@

bool mUseDeepSeekFP8BlockScaling = false;
bool mUseW4GroupScaling = false;
bool mUseINT8WoqPerChannel = false;
bool mUseMxfp8ActScaling = false;
bool mUseFusedFinalize = true;

@@ -891,7 +914,6 @@
TORCH_CHECK(false, "MXFP8 x MXFP4 quantization is not supported in OSS Cutlass Moe Gemm");
#endif
}

else if (isNvfp4Quant())
{
TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for nvfp4 quantization");
@@ -966,8 +988,8 @@
}
else if (isWFP4A16Quant())
{
TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for W4 quantization");
TORCH_CHECK(quant_scales.value().size() == 2, "Expecting 8 quant scales for W4A16 quantization");
TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for weight only quantization");
TORCH_CHECK(quant_scales.value().size() == 2, "Expecting 2 quant scales for W4A16 quantization");

auto& fc1_weight_scales = quant_scales.value()[0];
auto& fc2_weight_scales = quant_scales.value()[1];
@@ -976,28 +998,45 @@
static_cast<void const*>(fc2_weight_scales.data_ptr()), nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr);
}
else if (isInt4Quant())
else if (isIntWeightOnlyQuant())
{
TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for W4 quantization");
TORCH_CHECK(quant_scales.value().size() == 8, "Expecting 8 quant scales for W4A8 quantization");

auto& fc1_weight_scales = quant_scales.value()[0];
auto& fc2_weight_scales = quant_scales.value()[1];
auto& fc1_act_scales = quant_scales.value()[2];
auto& fc2_act_scales = quant_scales.value()[3];
auto& fc1_weight_zeros = quant_scales.value()[4];
auto& fc2_weight_zeros = quant_scales.value()[5];
auto& fc1_alpha = quant_scales.value()[6];
auto& fc2_alpha = quant_scales.value()[7];
int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size;
return kernels::QuantParams::GroupWise(group_size, static_cast<void const*>(fc1_weight_scales.data_ptr()),
static_cast<void const*>(fc2_weight_scales.data_ptr()),
static_cast<void const*>(fc1_act_scales.numel() > 0 ? fc1_act_scales.data_ptr() : nullptr),
static_cast<void const*>(fc2_act_scales.numel() > 0 ? fc2_act_scales.data_ptr() : nullptr),
static_cast<void const*>(fc1_weight_zeros.numel() > 0 ? fc1_weight_zeros.data_ptr() : nullptr),
static_cast<void const*>(fc2_weight_zeros.numel() > 0 ? fc2_weight_zeros.data_ptr() : nullptr),
static_cast<float const*>(fc1_alpha.numel() > 0 ? fc1_alpha.data_ptr() : nullptr),
static_cast<float const*>(fc2_alpha.numel() > 0 ? fc2_alpha.data_ptr() : nullptr));
TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for weight only quantization");
if (mUseINT8WoqPerChannel)
{
TORCH_CHECK(
quant_scales.value().size() == 2, "Expecting 2 quant scales for INT8 weight only quantization");
auto& fc1_weight_scales = quant_scales.value()[0];
auto& fc2_weight_scales = quant_scales.value()[1];
return kernels::QuantParams::Int(static_cast<float const*>(fc1_weight_scales.data_ptr()),
static_cast<float const*>(fc2_weight_scales.data_ptr()));
}
else if (isInt4Quant() && mUseW4GroupScaling)
{
TORCH_CHECK(quant_scales.value().size() == 8, "Expecting 8 quant scales for W4A8 quantization");

auto& fc1_weight_scales = quant_scales.value()[0];
auto& fc2_weight_scales = quant_scales.value()[1];
auto& fc1_act_scales = quant_scales.value()[2];
auto& fc2_act_scales = quant_scales.value()[3];
auto& fc1_weight_zeros = quant_scales.value()[4];
auto& fc2_weight_zeros = quant_scales.value()[5];
auto& fc1_alpha = quant_scales.value()[6];
auto& fc2_alpha = quant_scales.value()[7];
int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size;
return kernels::QuantParams::GroupWise(group_size,
static_cast<void const*>(fc1_weight_scales.data_ptr()),
static_cast<void const*>(fc2_weight_scales.data_ptr()),
static_cast<void const*>(fc1_act_scales.numel() > 0 ? fc1_act_scales.data_ptr() : nullptr),
static_cast<void const*>(fc2_act_scales.numel() > 0 ? fc2_act_scales.data_ptr() : nullptr),
static_cast<void const*>(fc1_weight_zeros.numel() > 0 ? fc1_weight_zeros.data_ptr() : nullptr),
static_cast<void const*>(fc2_weight_zeros.numel() > 0 ? fc2_weight_zeros.data_ptr() : nullptr),
static_cast<float const*>(fc1_alpha.numel() > 0 ? fc1_alpha.data_ptr() : nullptr),
static_cast<float const*>(fc2_alpha.numel() > 0 ? fc2_alpha.data_ptr() : nullptr));
}
else
{
TORCH_CHECK(false, "Unsupported weight only quantization");
}
}
else
{
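
A sketch of the two quant_scales layouts the weight-only branch above accepts; tensor shapes here are placeholders, since the diff only fixes the list lengths and ordering:

```python
import torch

# Placeholder tensors: real shapes depend on the checkpoint layout.
fc1_weight_scales = torch.ones(8, 2048)
fc2_weight_scales = torch.ones(8, 1024)
empty = torch.empty(0)  # numel() == 0 is passed through as nullptr

# INT8 weight-only per-channel expects exactly 2 entries:
quant_scales_int8 = [fc1_weight_scales, fc2_weight_scales]

# W4A8 group-wise expects exactly 8 entries, in this order:
quant_scales_w4a8 = [
    fc1_weight_scales, fc2_weight_scales,  # per-group weight scales
    empty, empty,                          # fc1/fc2 activation scales (optional)
    empty, empty,                          # fc1/fc2 weight zero points (optional)
    empty, empty,                          # fc1/fc2 alpha (optional)
]
```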
@@ -1022,6 +1061,11 @@
return mUseW4GroupScaling && mWeightDtype == c10::ScalarType::Byte;
}

bool isInt8Quant() const
{
return mWeightDtype == c10::ScalarType::Char;
}

bool isInt4Quant() const
{
return mWeightDtype == c10::ScalarType::QUInt4x2;
@@ -1032,6 +1076,11 @@
return mActivationDtype == c10::ScalarType::Float8_e4m3fn && isInt4Quant();
}

bool isIntWeightOnlyQuant() const
{
return isInt8Quant() || isInt4Quant();
}

bool isWMxfp4AFp8Quant() const
{
return mActivationDtype == c10::ScalarType::Float8_e4m3fn && mWeightDtype == c10::ScalarType::Long
@@ -1050,7 +1099,7 @@
TORCH_LIBRARY(trtllm, m)
{
m.class_<torch_ext::FusedMoeRunner>("FusedMoeRunner")
.def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool, bool, bool, bool>())
.def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool, bool, bool, bool, bool>())
.def("run_gemm_profile", &torch_ext::FusedMoeRunner::runGemmProfile)
.def("get_tactic_num", &torch_ext::FusedMoeRunner::getTacticNum)
.def("run_moe", &torch_ext::FusedMoeRunner::runMoe)
17 changes: 14 additions & 3 deletions tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -43,6 +43,7 @@ def __init__(
cluster_rank: int,
use_deepseek_fp8_block_scale: bool,
use_w4_group_scaling: bool,
use_int8_woq_per_channel: bool,
use_mxfp8_act_scaling: bool,
min_latency_mode: bool,
use_fused_finalize: bool,
@@ -61,20 +62,22 @@ def __init__(
self.enable_alltoall = False
self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
self.use_w4_group_scaling = use_w4_group_scaling
self.use_int8_woq_per_channel = use_int8_woq_per_channel
self.use_mxfp8_act_scaling = use_mxfp8_act_scaling
self.min_latency_mode = min_latency_mode
self.use_fused_finalize = use_fused_finalize

instance_key = (x_dtype, weight_dtype, output_dtype,
use_deepseek_fp8_block_scale, use_w4_group_scaling,
use_mxfp8_act_scaling)
use_int8_woq_per_channel, use_mxfp8_act_scaling)

if instance_key not in MoERunner.runner_dict:
MoERunner.runner_dict[
instance_key] = torch.classes.trtllm.FusedMoeRunner(
x_dtype, weight_dtype, output_dtype,
use_deepseek_fp8_block_scale, use_w4_group_scaling,
use_mxfp8_act_scaling, use_fused_finalize)
use_int8_woq_per_channel, use_mxfp8_act_scaling,
use_fused_finalize)
self.fused_moe_runner = MoERunner.runner_dict[instance_key]
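
Because the new flag is part of instance_key, configurations that differ only in INT8 WOQ no longer share a cached runner; a quick sketch:

```python
import torch

key_a = (torch.half, torch.int8, torch.half, False, False, True, False)
key_b = (torch.half, torch.int8, torch.half, False, False, False, False)
assert key_a != key_b  # MoERunner.runner_dict caches these separately
```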

def get_valid_tactics(
@@ -138,6 +141,7 @@ def fused_moe(
enable_alltoall: bool = False,
use_deepseek_fp8_block_scale: bool = False,
use_w4_group_scaling: bool = False,
use_int8_woq_per_channel: bool = False,
use_mxfp8_act_scaling: bool = False,
min_latency_mode: bool = False,
use_fused_finalize: bool = True,
@@ -174,6 +178,7 @@
cluster_rank=cluster_rank,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
use_w4_group_scaling=use_w4_group_scaling,
use_int8_woq_per_channel=use_int8_woq_per_channel,
use_mxfp8_act_scaling=use_mxfp8_act_scaling,
min_latency_mode=min_latency_mode,
use_fused_finalize=use_fused_finalize,
@@ -257,13 +262,19 @@ def _(
enable_alltoall: bool = False,
use_deepseek_fp8_block_scale: bool = False,
use_w4_group_scaling: bool = False,
use_int8_woq_per_channel: bool = False,
use_mxfp8_act_scaling: bool = False,
min_latency_mode: bool = False,
use_fused_finalize: bool = True,
tune_max_num_tokens: int = 8192,
):
seq_len = input.shape[0]
hidden_size = fc2_expert_weights.shape[1]
if use_int8_woq_per_channel:
# Note: The weight shape for INT8 weight only quantization is different, i.e.,
# fc2_expert_weights: [num_experts, inter_size, hidden_size]
hidden_size = fc2_expert_weights.shape[2]
else:
hidden_size = fc2_expert_weights.shape[1]

if min_latency_mode:
num_experts_on_rank = fc2_expert_weights.shape[0]