Commit c26bf90

fix correctness issue
Signed-off-by: Yuening Li <[email protected]>
1 parent: bcf6514 · commit: c26bf90

File tree

6 files changed: +213 -56 lines

cpp/tensorrt_llm/thop/moeOp.cpp

Lines changed: 31 additions & 11 deletions

@@ -122,14 +122,15 @@ class FusedMoeRunner : public torch::CustomClassHolder
     }
 
     FusedMoeRunner(c10::ScalarType activation_dtype, c10::ScalarType weight_dtype, c10::ScalarType output_dtype,
-        bool use_deepseek_fp8_block_scale, bool use_w4a8_group_scaling, bool use_woq_group_scaling,
-        bool use_mxfp8_act_scaling)
+        bool use_deepseek_fp8_block_scale, bool use_w4a8_group_scaling, bool use_woq_per_channel,
+        bool use_woq_group_scaling, bool use_mxfp8_act_scaling)
     {
         mActivationDtype = activation_dtype;
         mWeightDtype = weight_dtype;
         mOutputDtype = output_dtype;
         mUseDeepSeekFP8BlockScaling = use_deepseek_fp8_block_scale;
         mUseW4A8GroupScaling = use_w4a8_group_scaling;
+        mUseWoqPerChannel = use_woq_per_channel;
         mUseWoqGroupScaling = use_woq_group_scaling;
         mUseMxfp8ActScaling = use_mxfp8_act_scaling;
         mInnerDimMultiplier = 1;
@@ -276,13 +277,27 @@ class FusedMoeRunner : public torch::CustomClassHolder
         }
         TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc2_expert_weights.sizes()[0],
             "fc1_expert_weights and fc2_expert_weights must have the same number of experts.");
-        TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
-            "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
+
+        if (mUseWoqPerChannel)
+        {
+            TORCH_CHECK(fc1_expert_weights.sizes()[2] == fc2_expert_weights.sizes()[1] * mInnerDimMultiplier * 2,
+                "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
+        }
+        else
+        {
+            TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
+                "fc1_expert_weights inter size must be fc2_expert_weights inter size.");
+        }
 
         int experts_per_token = token_selected_experts.sizes()[1];
         int64_t num_rows = input.sizes()[0];
         int64_t hidden_size = fc2_expert_weights.sizes()[1];
         int64_t inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
+        if (mUseWoqPerChannel)
+        {
+            hidden_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
+            inter_size = fc2_expert_weights.sizes()[1];
+        }
 
         if (isWMxfp4AMxfp8Quant() || isWMxfp4AFp8Quant())
         {
@@ -506,9 +521,14 @@ class FusedMoeRunner : public torch::CustomClassHolder
         }
 
         int64_t const num_rows = input.sizes()[0];
-        int64_t const hidden_size = fc2_expert_weights.sizes()[1];
-        int64_t const inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
-        int64_t const group_size = mUseWoqGroupScaling ? 128 : -1;
+        int64_t hidden_size = fc2_expert_weights.sizes()[1];
+        int64_t inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
+        if (mUseWoqPerChannel)
+        {
+            hidden_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
+            inter_size = fc2_expert_weights.sizes()[1];
+        }
+        int64_t const group_size = mUseWoqGroupScaling or mUseW4A8GroupScaling ? 128 : -1;
         int const num_experts = static_cast<int>(fc2_expert_weights.sizes()[0] * ep_size);
 
         // Get specific profile configs according to the profile_id.
@@ -585,6 +605,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
 
     bool mUseDeepSeekFP8BlockScaling = false;
     bool mUseW4A8GroupScaling = false;
+    bool mUseWoqPerChannel = false;
     bool mUseWoqGroupScaling = false;
     bool mUseMxfp8ActScaling = false;
 
@@ -876,16 +897,15 @@ class FusedMoeRunner : public torch::CustomClassHolder
         else if (isWeightOnlyQuant())
         {
             TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for weight only quantization");
-            if (!mUseWoqGroupScaling)
+            if (mUseWoqPerChannel)
             {
                 TORCH_CHECK(quant_scales.value().size() == 2, "Expecting 2 quant scales for weight only quantization");
                 auto& fc1_weight_scales = quant_scales.value()[0];
                 auto& fc2_weight_scales = quant_scales.value()[1];
                 return kernels::QuantParams::Int(static_cast<float const*>(fc1_weight_scales.data_ptr()),
                     static_cast<float const*>(fc2_weight_scales.data_ptr()));
             }
-            // TODO: support groupwise quantization for int8 weight only
-            else if (isInt4Quant())
+            else if (isInt4Quant() && mUseWoqGroupScaling)
             {
                 TORCH_CHECK(quant_scales.value().size() == 8, "Expecting 8 quant scales for INT4 quantization");
                 auto& fc1_weight_scales = quant_scales.value()[0];
@@ -968,7 +988,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
 TORCH_LIBRARY(trtllm, m)
 {
     m.class_<torch_ext::FusedMoeRunner>("FusedMoeRunner")
-        .def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool, bool, bool, bool>())
+        .def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool, bool, bool, bool, bool>())
         .def("run_gemm_profile", &torch_ext::FusedMoeRunner::runGemmProfile)
         .def("get_tactic_num", &torch_ext::FusedMoeRunner::getTacticNum)
         .def("run_moe", &torch_ext::FusedMoeRunner::runMoe)
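
The dimension handling above is the core of the correctness fix: with per-channel weight-only quantization enabled, the preprocessed fc2 expert weights are laid out as [num_experts, inter_size, hidden_size] rather than [num_experts, hidden_size, inter_size], so hidden_size and inter_size have to be read from swapped dimensions. A minimal Python sketch of the same dimension logic (the helper name and example sizes are illustrative, not part of the commit):

import torch

def moe_dims_from_fc2(fc2_expert_weights: torch.Tensor,
                      use_woq_per_channel: bool,
                      inner_dim_multiplier: int = 1):
    # Mirrors the updated checks in moeOp.cpp. inner_dim_multiplier accounts
    # for sub-byte packing of the innermost dimension (e.g. two int4 values
    # per stored int8 element).
    if use_woq_per_channel:
        # per-channel weight-only layout: [num_experts, inter_size, hidden_size]
        inter_size = fc2_expert_weights.shape[1]
        hidden_size = fc2_expert_weights.shape[2] * inner_dim_multiplier
    else:
        # default layout: [num_experts, hidden_size, inter_size]
        hidden_size = fc2_expert_weights.shape[1]
        inter_size = fc2_expert_weights.shape[2] * inner_dim_multiplier
    return hidden_size, inter_size

# 8 experts, hidden_size = 64, inter_size = 128, per-channel weight-only layout
w2 = torch.empty(8, 128, 64, dtype=torch.int8)
assert moe_dims_from_fc2(w2, use_woq_per_channel=True) == (64, 128)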

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 13 additions & 3 deletions

@@ -40,6 +40,7 @@ def __init__(
         cluster_rank: int,
         use_deepseek_fp8_block_scale: bool,
         use_w4a8_group_scaling: bool,
+        use_woq_per_channel: bool,
         use_woq_group_scaling: bool,
         use_mxfp8_act_scaling: bool,
         min_latency_mode: bool,
@@ -58,19 +59,22 @@ def __init__(
         self.enable_alltoall = False
         self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
         self.use_w4a8_group_scaling = use_w4a8_group_scaling
+        self.use_woq_per_channel = use_woq_per_channel
         self.use_woq_group_scaling = use_woq_group_scaling
         self.use_mxfp8_act_scaling = use_mxfp8_act_scaling
         self.min_latency_mode = min_latency_mode
         instance_key = (x_dtype, weight_dtype, output_dtype,
                         use_deepseek_fp8_block_scale, use_w4a8_group_scaling,
-                        use_woq_group_scaling, use_mxfp8_act_scaling)
+                        use_woq_per_channel, use_woq_group_scaling,
+                        use_mxfp8_act_scaling)
 
         if instance_key not in MoERunner.runner_dict:
             MoERunner.runner_dict[
                 instance_key] = torch.classes.trtllm.FusedMoeRunner(
                     x_dtype, weight_dtype, output_dtype,
                     use_deepseek_fp8_block_scale, use_w4a8_group_scaling,
-                    use_woq_group_scaling, use_mxfp8_act_scaling)
+                    use_woq_per_channel, use_woq_group_scaling,
+                    use_mxfp8_act_scaling)
         self.fused_moe_runner = MoERunner.runner_dict[instance_key]
 
     def get_valid_tactics(
@@ -139,6 +143,7 @@ def fused_moe(
     enable_alltoall: bool = False,
     use_deepseek_fp8_block_scale: bool = False,
     use_w4a8_group_scaling: bool = False,
+    use_woq_per_channel: bool = False,
     use_woq_group_scaling: bool = False,
     use_mxfp8_act_scaling: bool = False,
     min_latency_mode: bool = False,
@@ -176,6 +181,7 @@ def fused_moe(
         cluster_rank=cluster_rank,
         use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         use_w4a8_group_scaling=use_w4a8_group_scaling,
+        use_woq_per_channel=use_woq_per_channel,
         use_woq_group_scaling=use_woq_group_scaling,
         use_mxfp8_act_scaling=use_mxfp8_act_scaling,
         min_latency_mode=min_latency_mode,
@@ -249,13 +255,17 @@ def _(
     enable_alltoall: bool = False,
     use_deepseek_fp8_block_scale: bool = False,
     use_w4a8_group_scaling: bool = False,
+    use_woq_per_channel: bool = False,
    use_woq_group_scaling: bool = False,
     use_mxfp8_act_scaling: bool = False,
     min_latency_mode: bool = False,
     tune_max_num_tokens: int = 8192,
 ):
     seq_len = input.shape[0]
-    hidden_size = fc2_expert_weights.shape[1]
+    if use_woq_per_channel:
+        hidden_size = fc2_expert_weights.shape[2]
+    else:
+        hidden_size = fc2_expert_weights.shape[1]
 
     if min_latency_mode:
         num_experts_on_rank = fc2_expert_weights.shape[0]
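
Two details in this file matter for correctness: the new flag is part of instance_key, so a cached FusedMoeRunner built without per-channel weight-only support is never reused for a per-channel model, and the fake-tensor registration must infer the output hidden size from a different dimension of fc2_expert_weights. A standalone sketch of the latter, assuming a hypothetical function name and example shapes:

import torch

def fake_fused_moe_output(input: torch.Tensor,
                          fc2_expert_weights: torch.Tensor,
                          use_woq_per_channel: bool = False) -> torch.Tensor:
    # Shape-only stand-in for the registered fake op: no kernel runs here, so
    # the output shape has to be derived purely from the argument shapes.
    seq_len = input.shape[0]
    hidden_size = (fc2_expert_weights.shape[2]
                   if use_woq_per_channel else fc2_expert_weights.shape[1])
    return input.new_empty((seq_len, hidden_size))

x = torch.randn(4, 64)
w2 = torch.empty(8, 128, 64, dtype=torch.int8)  # per-channel layout
assert fake_fused_moe_output(x, w2, use_woq_per_channel=True).shape == (4, 64)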

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 14 additions & 8 deletions

@@ -144,6 +144,11 @@ def has_w4afp8(self):
         return self.quant_config and self.quant_config.quant_mode.is_int4_weight_only_per_group(
         )
 
+    @property
+    def has_woq_per_channel(self):
+        return self.quant_config.layer_quant_mode.is_weight_only(
+        ) and not self.quant_config.layer_quant_mode.has_per_group_scaling()
+
     @property
     def has_woq_per_group_scaling(self):
         return self.quant_config.layer_quant_mode.is_weight_only(
@@ -161,9 +166,7 @@ def _get_quant_method(self):
             elif self.quant_config.layer_quant_mode.is_int4_weight_only_per_group(
             ):
                 return WInt4AFP8FusedMoEMethod()
-            elif self.quant_config.layer_quant_mode.is_weight_only(
-            ) and not self.quant_config.layer_quant_mode.has_per_group_scaling(
-            ):
+            elif self.has_woq_per_channel:
                 return WeightOnlyFusedMoEMethod()
             else:
                 raise ValueError(
@@ -234,6 +237,7 @@ def forward_chunk(
         # quantize inputs
         use_deepseek_fp8_block_scale = False
         use_w4a8_group_scaling = False
+        use_woq_per_channel = False
         use_woq_group_scaling = False
         weight_dtype = self.w3_w1_weight.dtype
         x_sf = None
@@ -247,7 +251,8 @@
                 use_w4a8_group_scaling = True
                 use_woq_group_scaling = True
                 weight_dtype = torch.quint4x2
-                # TODO: add support for weight only quantization with per group scaling
+            elif self.has_woq_per_channel:
+                use_woq_per_channel = True
             elif self.has_woq_per_group_scaling:
                 use_woq_group_scaling = True
             elif self.has_nvfp4:
@@ -269,10 +274,10 @@
                 x, x_sf = torch.ops.trtllm.fp4_quantize(
                     x, self.fc31_input_scale, self.scaling_vector_size,
                     False, True)
-            # else:
-            #     raise ValueError(
-            #         f"unsupported quantization mode: {self.quant_config.quant_mode}"
-            #     )
+            else:
+                raise ValueError(
+                    f"unsupported quantization mode: {self.quant_config.quant_mode}"
+                )
 
         # gather inputs for attention dp
         if run_post_quant_allgather:
@@ -312,6 +317,7 @@
             enable_alltoall=self.enable_alltoall,
             use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
             use_w4a8_group_scaling=use_w4a8_group_scaling,
+            use_woq_per_channel=use_woq_per_channel,
             use_woq_group_scaling=use_woq_group_scaling,
             min_latency_mode=False,
             tune_max_num_tokens=self.tune_max_num_tokens,
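
The forward_chunk changes are mostly flag plumbing: exactly one weight-only path may be selected, and an unrecognized quantization mode now raises again instead of being silently ignored (the previously commented-out ValueError is restored). A simplified, self-contained sketch of that selection order, covering only the branches visible in this diff (the real method also handles FP8 block scaling, NVFP4, and MXFP8):

def select_woq_flags(has_w4afp8: bool, has_woq_per_channel: bool,
                     has_woq_per_group_scaling: bool) -> dict:
    # First matching mode wins, mirroring the if/elif chain in forward_chunk.
    flags = {
        "use_w4a8_group_scaling": False,
        "use_woq_per_channel": False,
        "use_woq_group_scaling": False,
    }
    if has_w4afp8:
        flags["use_w4a8_group_scaling"] = True
        flags["use_woq_group_scaling"] = True
    elif has_woq_per_channel:
        flags["use_woq_per_channel"] = True
    elif has_woq_per_group_scaling:
        flags["use_woq_group_scaling"] = True
    else:
        raise ValueError("unsupported quantization mode")
    return flags

assert select_woq_flags(False, True, False)["use_woq_per_channel"]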

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 34 additions & 33 deletions

@@ -671,9 +671,6 @@ def create_weights(self, module: torch.nn.Module):
                                  requires_grad=False)
         module.register_parameter("fc2_weight_scale", fc2_weight_scale)
 
-        print(f"fc31_weight_scale.shape: {fc31_weight_scale.shape}")
-        print(f"fc2_weight_scale.shape: {fc2_weight_scale.shape}")
-
         fc31_alpha = nn.Parameter(torch.empty(module.expert_size_per_partition,
                                               1,
                                               dtype=torch.float32),
@@ -866,27 +863,31 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
 
 
 class WeightOnlyFusedMoEMethod(FusedMoEMethodBase):
-    """
-    Base class for Weight Only Quantization fused MoE methods.
-    """
 
     def create_weights(self, module: torch.nn.Module):
+        module.sm_version = get_sm_version()
+        module.sm_version = 80 if module.sm_version >= 90 else module.sm_version
+        module.preprocessor = preprocess_weights_for_mixed_gemm
+
         weight_dtype = torch.int8
         # int4 weight are packed into int8
         if module.quant_config.layer_quant_mode.is_int8_weight_only():
-            weight_id = 1
+            pass
         elif module.quant_config.layer_quant_mode.is_int4_weight_only():
-            weight_id = 2
+            pass
         else:
             raise NotImplementedError(
                 f"Weight Only Quantization is unsupported on {module.quant_config.layer_quant_mode}."
             )
 
+        # notice the weight shape for weight-only is different from the original shape,
+        # since the quantized weights have their own layout
         w3_w1_weight_shape = (module.expert_size_per_partition,
-                              module.intermediate_size_per_partition * 2,
-                              module.hidden_size // weight_id)
-        w2_weight_shape = (module.expert_size_per_partition, module.hidden_size,
-                           module.intermediate_size_per_partition // weight_id)
+                              module.hidden_size,
+                              module.intermediate_size_per_partition * 2)
+        w2_weight_shape = (module.expert_size_per_partition,
+                           module.intermediate_size_per_partition,
+                           module.hidden_size)
 
         fc31_weight_scale = nn.Parameter(torch.empty(
             module.expert_size_per_partition,
@@ -938,22 +939,22 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,
         w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard], dim=0)
 
         # preprocess the weights for mixed gemm
-        preprocessor = preprocess_weights_for_mixed_gemm
         if module.quant_config.layer_quant_mode.is_int8_weight_only():
             weight_dtype = torch.int8
-        elif module.quant_config.layer_quant_mode.is_int4_weight_only():
-            weight_dtype = torch.quint4x2
-            packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4
-            unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8
-            w31_weight_shard = packer(
-                unpacker(w31_weight_shard.cpu()).T.contiguous()).to(
-                    w31_weight_shard.device)
+        # elif module.quant_config.layer_quant_mode.is_int4_weight_only():
+        #     weight_dtype = torch.quint4x2
+        #     packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4
+        #     unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8
+        #     w31_weight_shard = packer(
+        #         unpacker(w31_weight_shard.cpu()).T.contiguous()).to(
+        #             w31_weight_shard.device)
 
         assert module.dtype in [torch.float16, torch.bfloat16], \
             f"activation dtype should be float16 or bfloat16, got {module.dtype}"
-        w31_weight_shard = preprocessor(w31_weight_shard, weight_dtype,
-                                        module.dtype).view(
-                                            dst_w3_w1_weight.shape)
+
+        w31_weight_shard = module.preprocessor(w31_weight_shard.T.contiguous(),
+                                               weight_dtype, module.dtype,
+                                               module.sm_version).contiguous()
         dst_w3_w1_weight.copy_(w31_weight_shard.view(dst_w3_w1_weight.dtype),
                                non_blocking=True)
 
@@ -968,22 +969,22 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
                               TensorParallelMode.ROW)
 
         # preprocess the weights for mixed gemm
-        preprocessor = preprocess_weights_for_mixed_gemm
         if module.quant_config.layer_quant_mode.is_int8_weight_only():
             weight_dtype = torch.int8
-        elif module.quant_config.layer_quant_mode.is_int4_weight_only():
-            weight_dtype = torch.quint4x2
-            packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4
-            unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8
-            w2_weight_shard = packer(
-                unpacker(w2_weight_shard.cpu()).T.contiguous()).to(
-                    w2_weight_shard.device)
+        # elif module.quant_config.layer_quant_mode.is_int4_weight_only():
+        #     weight_dtype = torch.quint4x2
+        #     packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4
+        #     unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8
+        #     w31_weight_shard = packer(
+        #         unpacker(w31_weight_shard.cpu()).T.contiguous()).to(
+        #             w31_weight_shard.device)
 
         assert module.dtype in [torch.float16, torch.bfloat16], \
             f"activation dtype should be float16 or bfloat16, got {module.dtype}"
-        w2_weight_shard = preprocessor(w2_weight_shard, weight_dtype,
-                                       module.dtype).view(dst_w2_weight.shape)
 
+        w2_weight_shard = module.preprocessor(w2_weight_shard.T.contiguous(),
+                                              weight_dtype, module.dtype,
+                                              module.sm_version).contiguous()
         dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype),
                             non_blocking=True)
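
The weight-loading change is the other half of the fix: the destination parameters now keep the [hidden_size, 2 * inter_size] and [inter_size, hidden_size] per-expert shapes, and each expert shard is transposed and passed through preprocess_weights_for_mixed_gemm together with an SM version that is capped at 80 for SM90+ GPUs. A minimal sketch of that per-expert load path with the preprocessor passed in explicitly (this free-function form is illustrative; the commit attaches it to the module as module.preprocessor and module.sm_version):

import torch

def load_woq_expert_shard(dst_weight: torch.Tensor,
                          weight_shard: torch.Tensor,
                          preprocessor,  # e.g. preprocess_weights_for_mixed_gemm
                          activation_dtype: torch.dtype,
                          sm_version: int) -> None:
    # Per the diff, SM90+ reuses the SM80 preprocessing path.
    sm_version = 80 if sm_version >= 90 else sm_version
    # Transpose the shard into the layout the mixed-gemm preprocessor expects,
    # then copy the processed bytes into the destination parameter.
    processed = preprocessor(weight_shard.T.contiguous(), torch.int8,
                             activation_dtype, sm_version).contiguous()
    dst_weight.copy_(processed.view(dst_weight.dtype), non_blocking=True)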

tests/unittest/_torch/helpers.py

Lines changed: 12 additions & 0 deletions

@@ -75,6 +75,18 @@ def calc_diff(x, y):
     return 1 - sim
 
 
+def calc_woq_tolerence(x: torch.Tensor, weight_dtype: torch.dtype):
+    if weight_dtype == torch.int8:
+        bits_in_type = 8
+    elif weight_dtype == torch.quint4x2:
+        bits_in_type = 4
+    quant_range_scale = 1.0 / float(1 << (bits_in_type - 1))
+    max_val = torch.max(abs(x)).item()
+    atol = (max_val * quant_range_scale) * 1.5  # allow for rounding
+
+    return atol
+
+
 def reference_moe_torch(x: torch.Tensor, selected_experts: torch.Tensor,
                         final_scales: torch.Tensor, num_experts: int,
                         weights: Dict[str, torch.Tensor]) -> torch.Tensor:
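
calc_woq_tolerence turns the quantization step of the largest-magnitude reference value into an absolute tolerance: step = max|x| / 2^(bits - 1), relaxed by 1.5x to allow for rounding. A typical usage sketch in a test, assuming calc_woq_tolerence from the hunk above is in scope and using hypothetical reference/output tensors:

import torch

# calc_woq_tolerence is defined in tests/unittest/_torch/helpers.py (see the diff above);
# the import path depends on how the test suite is laid out.

ref = torch.randn(16, 64)
output = ref + 1e-3 * torch.randn_like(ref)     # stand-in for the fused MoE result

atol = calc_woq_tolerence(ref, torch.quint4x2)  # INT4: max|ref| / 8, relaxed by 1.5x
torch.testing.assert_close(output, ref, atol=atol, rtol=0.0)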
