From 2cf31b59246f00cad5d5b06041fc3104a6b4df20 Mon Sep 17 00:00:00 2001
From: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
Date: Mon, 11 Aug 2025 15:00:51 +0800
Subject: [PATCH 1/2] relax tensor device type check to fix wideEP loading and
 fix argument

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
---
 .../_torch/modules/fused_moe/fused_moe_wide_ep.py     | 2 +-
 tensorrt_llm/_torch/modules/fused_moe/quantization.py | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
index 355ca3cccfc..241c276865a 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -533,7 +533,7 @@ def forward_chunk(
                         self.fc31_input_scale,
                         self.scaling_vector_size,
                         sfUseUE8M0=False,
-                        swizzedLayout=False)
+                        isSfSwizzledLayout=False)
                     x_sf = x_sf.view((x_row, -1))
 
             elif self.has_deepseek_fp8_block_scales:
diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
index ca373c2ed18..a7c7aa89af8 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -348,7 +348,7 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,
         Override this method if you need to preprocess the weights differently.
         """
         device = dst_w3_w1_weight.device
-        assert device.type == "cuda"
+        # device don't have to be 'cuda', e.g. 'cpu' for online EPLB
         w1_weight_shard = load_weight_shard(w1_weight,
                                             module.tp_size,
                                             module.tp_rank,
@@ -373,7 +373,7 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
         Override this method if you need to preprocess the weights differently.
         """
         device = dst_w2_weight.device
-        assert device.type == "cuda"
+        # device don't have to be 'cuda', e.g. 'cpu' for online EPLB
         w2_weight_shard = load_weight_shard(w2_weight,
                                             module.tp_size,
                                             module.tp_rank,
@@ -1538,7 +1538,7 @@ def load_expert_w3_w1_weight_scale_nvfp4(
             w3_weight_scale: torch.Tensor,
             dst_w3_w1_weight_scale: torch.Tensor):
         device = dst_w3_w1_weight_scale.device
-        assert device.type == "cuda"
+        # device don't have to be 'cuda', e.g. 'cpu' for online EPLB
         w1_weight_scale = load_weight_shard(w1_weight_scale,
                                             module.tp_size,
                                             module.tp_rank,
@@ -1578,7 +1578,7 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
                                           w2_weight_scale: torch.Tensor,
                                           dst_w2_weight_scale: torch.Tensor):
         device = dst_w2_weight_scale.device
-        assert device.type == "cuda"
+        # device don't have to be 'cuda', e.g. 'cpu' for online EPLB
         w2_weight_scale = load_weight_shard(w2_weight_scale,
                                             module.tp_size,
                                             module.tp_rank,

From 0e8540ced0ddd6010a36c0023e17f023b0192570 Mon Sep 17 00:00:00 2001
From: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
Date: Tue, 12 Aug 2025 10:21:20 +0800
Subject: [PATCH 2/2] move comment forward

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
---
 tensorrt_llm/_torch/modules/fused_moe/quantization.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
index a7c7aa89af8..c1cb71ad81f 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -347,8 +347,8 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,
         Load w1 and w3 weights for each expert.
         Override this method if you need to preprocess the weights differently.
         """
-        device = dst_w3_w1_weight.device
         # device don't have to be 'cuda', e.g. 'cpu' for online EPLB
+        device = dst_w3_w1_weight.device
         w1_weight_shard = load_weight_shard(w1_weight,
                                             module.tp_size,
                                             module.tp_rank,
@@ -372,8 +372,8 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
         Load w2 weight for each expert.
         Override this method if you need to preprocess the weights differently.
         """
-        device = dst_w2_weight.device
         # device don't have to be 'cuda', e.g. 'cpu' for online EPLB
+        device = dst_w2_weight.device
         w2_weight_shard = load_weight_shard(w2_weight,
                                             module.tp_size,
                                             module.tp_rank,
@@ -1537,8 +1537,8 @@ def load_expert_w3_w1_weight_scale_nvfp4(
             self, module: torch.nn.Module, w1_weight_scale: torch.Tensor,
             w3_weight_scale: torch.Tensor,
             dst_w3_w1_weight_scale: torch.Tensor):
-        device = dst_w3_w1_weight_scale.device
         # device don't have to be 'cuda', e.g. 'cpu' for online EPLB
+        device = dst_w3_w1_weight_scale.device
         w1_weight_scale = load_weight_shard(w1_weight_scale,
                                             module.tp_size,
                                             module.tp_rank,
@@ -1577,8 +1577,8 @@ def load_expert_w3_w1_weight_scale_nvfp4(
     def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
                                           w2_weight_scale: torch.Tensor,
                                           dst_w2_weight_scale: torch.Tensor):
-        device = dst_w2_weight_scale.device
         # device don't have to be 'cuda', e.g. 'cpu' for online EPLB
+        device = dst_w2_weight_scale.device
         w2_weight_scale = load_weight_shard(w2_weight_scale,
                                             module.tp_size,
                                             module.tp_rank,
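Illustration (not part of the patches above): both commits come down to trusting the destination tensor's device instead of asserting it is CUDA, since online EPLB stages expert weights in host ('cpu') memory before they reach the GPU. The sketch below shows that pattern in isolation; copy_expert_shard and its arguments are hypothetical stand-ins, not the load_expert_* helpers from quantization.py.

import torch


def copy_expert_shard(shard: torch.Tensor, dst: torch.Tensor) -> None:
    """Copy one expert's weight shard into a preallocated destination slot.

    Hypothetical helper mirroring the relaxed check in this patch: the
    destination may live on 'cuda' (regular weight loading) or on 'cpu'
    (host-side staging used by online EPLB), so we follow dst.device
    rather than asserting device.type == "cuda".
    """
    device = dst.device  # may be 'cuda' or 'cpu'; no assert on device.type
    dst.copy_(shard.to(device=device, dtype=dst.dtype), non_blocking=True)


if __name__ == "__main__":
    shard = torch.randn(128, 64)
    cpu_slot = torch.empty(128, 64)  # online-EPLB-style host buffer
    copy_expert_shard(shard, cpu_slot)
    if torch.cuda.is_available():  # regular GPU destination still works
        gpu_slot = torch.empty(128, 64, device="cuda")
        copy_expert_shard(shard, gpu_slot)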