Commit 0160ff0

Add runtime swap AB for SM100 blockwise GEMM
Signed-off-by: Barry Kang <[email protected]>
1 parent 50e5e72 commit 0160ff0

3 files changed, 144 insertions(+), 40 deletions(-)
tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 29 additions & 21 deletions

@@ -45,8 +45,8 @@
 from tensorrt_llm.llmapi.utils import enable_llm_debug
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantConfig
-from tensorrt_llm.quantization.utils.fp8_utils import (
-    resmooth_to_fp8_e8m0, transform_sf_into_required_layout)
+from tensorrt_llm.quantization.utils.fp8_utils import \
+    transform_sf_into_required_layout
 
 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams

@@ -1384,14 +1384,29 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
          None) is not None and not dequant_kv_b_proj:
     kv_b_proj_scale, k_b_proj_trans_scale = load_kv_b_proj_and_k_b_proj_trans(
         name, is_scale=True)
-    module.weight_scale.copy_(
-        kv_b_proj_scale.reshape(module.weight_scale.shape))
     attn_module.k_b_proj_trans_scale.copy_(
         k_b_proj_trans_scale.reshape(
             attn_module.k_b_proj_trans_scale.shape))
 
-    _, v_b_proj_scale = split_kv_b_proj(
-        module.weight_scale.data, is_scale=True)
+    if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
+    ) and get_sm_version() == 100:
+        _, v_b_proj_scale = split_kv_b_proj(kv_b_proj_scale,
+                                            is_scale=True)
+        kv_b_proj_scale = transform_sf_into_required_layout(
+            kv_b_proj_scale,
+            mn=kv_b_proj.shape[0],
+            k=kv_b_proj.shape[1],
+            recipe=(1, 128, 128),
+            is_sfa=False)
+        module.weight_scale.copy_(
+            kv_b_proj_scale.reshape(
+                module.weight_scale.shape))
+    else:
+        module.weight_scale.copy_(
+            kv_b_proj_scale.reshape(
+                module.weight_scale.shape))
+        _, v_b_proj_scale = split_kv_b_proj(
+            module.weight_scale.data, is_scale=True)
     attn_module.v_b_proj_scale = nn.Parameter(
         v_b_proj_scale, requires_grad=False)
 

@@ -1432,6 +1447,14 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
     fused_a_scale = torch.cat(
         [q_a_proj_scale, fused_a_scale], dim=0)
 
+    if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
+    ) and get_sm_version() == 100:
+        fused_a_scale = transform_sf_into_required_layout(
+            fused_a_scale,
+            mn=fused_a.shape[0],
+            k=fused_a.shape[1],
+            recipe=(1, 128, 128),
+            is_sfa=False)
     module.weight_scale.data.copy_(fused_a_scale)
 
     module.weight.data.copy_(fused_a)
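
Note (illustrative aside, not part of the commit): the transform_sf_into_required_layout call above is driven purely by the fused weight's (out, in) shape. A rough shape sketch, using DeepSeek-V3-like sizes as an assumption:

import math

# Sketch only: fused_a stacks q_a_proj and kv_a_proj_with_mqa rows, roughly
# (q_lora_rank + kv_lora_rank + qk_rope_head_dim, hidden_size) for DeepSeek-V3.
mn, k = 1536 + 512 + 64, 7168                                 # assumed sizes
raw_scale_shape = (math.ceil(mn / 128), math.ceil(k / 128))   # float32 blockwise: (17, 56)
packed_shape = (math.ceil(k / 512), mn)                       # SM100 int32 buffer: (14, 2112)
# transform_sf_into_required_layout(fused_a_scale, mn=mn, k=k,
#                                   recipe=(1, 128, 128), is_sfa=False)
# takes the former layout to the latter before the weight_scale copy above.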

@@ -1462,21 +1485,6 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
 for n, p in module.named_parameters():
     p.data.copy_(module_weights[n][:])
 
-if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
-) and get_sm_version() == 100 and hasattr(
-        module, "weight_scale"):
-    weight, weight_scale = resmooth_to_fp8_e8m0(
-        module.weight, module.weight_scale)
-    transfromed_scale = transform_sf_into_required_layout(
-        weight_scale,
-        mn=weight.shape[0],
-        k=weight.shape[1],
-        recipe=(1, 128, 128),
-        is_sfa=False)
-    module.weight = nn.Parameter(weight, requires_grad=False)
-    module.weight_scale = nn.Parameter(transfromed_scale,
-                                       requires_grad=False)
-
 for idx, layer in enumerate(
         self.model.layers[:self.config.num_hidden_layers]):
     if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/modules/linear.py

Lines changed: 51 additions & 13 deletions

@@ -558,11 +558,19 @@ def create_weights(self, module: Linear, in_features: int,
         module.weight = Parameter(torch.empty(weight_shape,
                                               dtype=torch.float8_e4m3fn),
                                   requires_grad=False)
-        scale_shape = (math.ceil(out_features / 128),
-                       math.ceil(in_features / 128))
-        module.weight_scale = Parameter(torch.empty(scale_shape,
-                                                    dtype=torch.float32),
-                                        requires_grad=False)
+
+        if get_sm_version() == 100:
+            scale_shape = (math.ceil(in_features / 512),
+                           math.ceil(out_features))
+            module.weight_scale = Parameter(torch.empty(scale_shape,
+                                                        dtype=torch.int32).T,
+                                            requires_grad=False)
+        else:
+            scale_shape = (math.ceil(out_features / 128),
+                           math.ceil(in_features / 128))
+            module.weight_scale = Parameter(torch.empty(scale_shape,
+                                                        dtype=torch.float32),
+                                            requires_grad=False)
         # Not really used for Gemm now.
         # Only used to quantize output of FP8 attention.
         module.input_scale = Parameter(torch.tensor(1., dtype=torch.float32),
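
Note (illustrative aside, not part of the commit): the two allocation branches above differ only in the scale buffer's shape and dtype. A minimal sketch of that shape math, with made-up layer sizes; reading the /512 divisor as "four 128-wide group scales packed per int32" is an inference, not documented API:

import math

def weight_scale_shape(out_features: int, in_features: int, sm_version: int):
    """Mirror of the allocation branches in create_weights above (sketch only)."""
    if sm_version == 100:
        # Packed layout: int32 buffer of shape (ceil(in/512), out), exposed via .T.
        return (math.ceil(in_features / 512), out_features), "int32, transposed"
    # Default blockwise layout: one float32 scale per 128x128 weight block.
    return (math.ceil(out_features / 128), math.ceil(in_features / 128)), "float32"

# Made-up example: a 2048x7168 projection.
print(weight_scale_shape(2048, 7168, 100))   # ((14, 2048), 'int32, transposed')
print(weight_scale_shape(2048, 7168, 90))    # ((16, 56), 'float32')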

@@ -592,14 +600,30 @@ def apply(self, module: Linear, input: torch.Tensor,
                     module.weight_scale)
             else:
                 from tensorrt_llm import deep_gemm
-                a, a_sf = fp8_utils.per_token_quant_and_transform(input)
-                output = torch.empty((input.shape[0], module.weight.shape[0]),
-                                     device=input.device,
-                                     dtype=torch.bfloat16)
-                deep_gemm.fp8_gemm_nt((a, a_sf),
-                                      (module.weight, module.weight_scale),
-                                      output,
-                                      disable_ue8m0_cast=True)
+                if input.shape[0] < 32:
+                    # Swap AB
+                    a, a_sf = fp8_utils.per_token_quant_and_transform(
+                        input, swap_ab=True)
+                    output_padded = torch.empty(
+                        (module.weight.shape[0], a.shape[0]),
+                        device=input.device,
+                        dtype=torch.bfloat16)
+                    deep_gemm.fp8_gemm_nt((module.weight, module.weight_scale),
+                                          (a, a_sf),
+                                          output_padded,
+                                          disable_ue8m0_cast=True)
+                    output = fp8_utils.masked_transpose(output_padded,
+                                                        input.shape[0])
+                else:
+                    a, a_sf = fp8_utils.per_token_quant_and_transform(input)
+                    output = torch.empty(
+                        (input.shape[0], module.weight.shape[0]),
+                        device=input.device,
+                        dtype=torch.bfloat16)
+                    deep_gemm.fp8_gemm_nt((a, a_sf),
+                                          (module.weight, module.weight_scale),
+                                          output,
+                                          disable_ue8m0_cast=True)
         else:
             act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
                 input)
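
Note (illustrative aside, not part of the commit): the new small-batch branch is just the transpose identity x @ W.T == (W @ x.T).T, with the token dimension padded to a multiple of 8 and trimmed back afterwards. A minimal numeric sketch in float64 (the real path runs FP8 inputs with a bf16 output through deep_gemm; the < 32 threshold and the pad-to-8 rule are taken from this diff):

import torch

m, k, n = 5, 256, 512                            # a handful of decode-time tokens
x = torch.randn(m, k, dtype=torch.float64)       # activations
w = torch.randn(n, k, dtype=torch.float64)       # weight (out_features, in_features)

direct = x @ w.T                                 # normal orientation: (m, n)

m_padded = (m + 7) // 8 * 8                      # pad tokens to a multiple of 8,
x_padded = torch.zeros(m_padded, k, dtype=torch.float64)   # as swap_ab=True does
x_padded[:m] = x
swapped = (w @ x_padded.T)[:, :m].T              # swap-AB GEMM, trim padding, transpose
                                                 # back (masked_transpose in the real code)
torch.testing.assert_close(direct, swapped)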

@@ -625,6 +649,13 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
         weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
                                          module.tp_rank,
                                          module.tp_mode).squeeze()
+        if get_sm_version() == 100:
+            weight_scale = fp8_utils.transform_sf_into_required_layout(
+                weight_scale,
+                mn=module.weight.shape[0],
+                k=module.weight.shape[1],
+                recipe=(1, 128, 128),
+                is_sfa=False)
         copy_weight(module.weight_scale, weight_scale)
         if "input_scale" in weights[0]:
             copy_weight(module.input_scale, weights[0]["input_scale"])

@@ -661,6 +692,13 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
                                        module.tp_rank, module.tp_mode)
         fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
         copy_weight(module.weight, fused_weight)
+        if get_sm_version() == 100:
+            fused_scale = fp8_utils.transform_sf_into_required_layout(
+                fused_scale,
+                mn=fused_weight.shape[0],
+                k=fused_weight.shape[1],
+                recipe=(1, 128, 128),
+                is_sfa=False)
         copy_weight(module.weight_scale, fused_scale)
 
 
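
Note (illustrative aside, not part of the commit): gate and up projections share in_features, so their blockwise scales concatenate along dim 0 exactly like the weights do, and the SM100 repack then sees the fused (out, in) shape. Sizes below are made up:

import math
import torch

in_features, out_gate, out_up = 7168, 2048, 2048     # assumed sizes
left_scale = torch.rand(math.ceil(out_gate / 128), math.ceil(in_features / 128))
right_scale = torch.rand(math.ceil(out_up / 128), math.ceil(in_features / 128))

fused_scale = torch.cat([left_scale, right_scale], dim=0)   # (32, 56)
fused_out = out_gate + out_up                               # 4096
assert fused_scale.shape == (math.ceil(fused_out / 128),
                             math.ceil(in_features / 128))
# On SM100 the fused float32 scale is then repacked with mn=fused_out,
# k=in_features so it matches the packed int32 buffer from create_weights.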

tensorrt_llm/quantization/utils/fp8_utils.py

Lines changed: 64 additions & 6 deletions

@@ -448,6 +448,7 @@ def per_token_quant_and_transform(
     input: torch.Tensor,
     quant_group_size: int = 128,
     scale_ue8m0: bool = True,
+    swap_ab=False,
 ):
     """
     input shape [g, m, k]

@@ -467,18 +468,21 @@
     fp8_min = -fp8_max
 
     m, k = input.shape
+    m_padded = m if not swap_ab else align(m, 8)
 
     # Create output
-    output = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda")
+    output = torch.empty((m_padded, k),
+                         dtype=torch.float8_e4m3fn,
+                         device="cuda")
 
     # Create output scale
     alignment = 4
     scale_k = ceil_div(k, quant_group_size)
-    m_padded = align(m, alignment)
+    m_aligned = align(m_padded, alignment)
     scale_k_padded = align(scale_k, alignment)
-    output_scale = torch.zeros((scale_k_padded // 4, m_padded),
+    output_scale = torch.empty((scale_k_padded // 4, m_aligned),
                                dtype=torch.int32,
-                               device='cuda')
+                               device="cuda")
 
     # Get block/grid/stage/warp
     BLOCK_NUM_PER_EXPERT = 64
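
Note (illustrative aside, not part of the commit): a sketch of the shape bookkeeping in per_token_quant_and_transform after this change, assuming align(x, a) rounds x up to a multiple of a:

from math import ceil

def quant_output_shapes(m: int, k: int, quant_group_size: int = 128,
                        swap_ab: bool = False):
    """Shape bookkeeping of per_token_quant_and_transform above (sketch only)."""
    align = lambda x, a: ceil(x / a) * a              # assumed behaviour of align()
    m_padded = m if not swap_ab else align(m, 8)      # extra row padding for swap-AB
    scale_k = ceil(k / quant_group_size)
    m_aligned = align(m_padded, 4)                    # alignment = 4, as in the code
    scale_k_padded = align(scale_k, 4)
    fp8_out = (m_padded, k)                           # float8_e4m3fn data
    packed_scale = (scale_k_padded // 4, m_aligned)   # int32, transposed before return
    returned_scale = (m_padded, scale_k_padded // 4)  # after transpose(0, 1)[:m_padded, :]
    return fp8_out, packed_scale, returned_scale

print(quant_output_shapes(5, 7168))                # ((5, 7168), (14, 8), (5, 14))
print(quant_output_shapes(5, 7168, swap_ab=True))  # ((8, 7168), (14, 8), (8, 14))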

@@ -508,13 +512,67 @@
         num_warps=num_warps,
         SCALE_UE8M0=scale_ue8m0,
     )
-    output_scale = output_scale.transpose(0, 1)[:m, :]
+    output_scale = output_scale.transpose(0, 1)[:m_padded, :]
     check_sf_layout(
         output_scale,
-        m,
+        m_padded,
         k,
         (1, 128),
         num_groups=None,
         tma_stride_check=True,
     )
     return output, output_scale
+
+
+@triton.jit
+def _transpose_kernel(input_ptr, output_ptr, M, N, stride_in_m, stride_in_n,
+                      stride_out_m, stride_out_n, BLOCK_SIZE: tl.constexpr):
+    row_block = tl.program_id(0)
+    col_block = tl.program_id(1)
+
+    row = row_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    col = col_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+
+    mask_row = row < M
+    mask_col = col < N
+    mask = mask_row[:, None] & mask_col[None, :]
+
+    input_idx = row[:, None] * stride_in_m + col[None, :] * stride_in_n
+    data = tl.load(input_ptr + input_idx, mask=mask, other=0)
+
+    output_idx = row[:, None] * stride_out_n + col[None, :] * stride_out_m
+    tl.store(output_ptr + output_idx, data, mask=mask)
+
+
+def masked_transpose(input: torch.Tensor, n_available: int) -> torch.Tensor:
+    """
+    Perform a masked transpose operation on a 2D tensor.
+
+    Args:
+        input: Input tensor of shape (M, N)
+        n_available: Number of columns to transpose (must be <= N)
+
+    Returns:
+        Transposed tensor of shape (n_available, M)
+    """
+    M, N = input.shape
+    assert n_available <= N, "n_available must be less than or equal to N"
+    BLOCK_SIZE = 32
+    output = torch.empty((n_available, M),
+                         dtype=input.dtype,
+                         device=input.device)
+
+    grid = ((M + BLOCK_SIZE - 1) // BLOCK_SIZE,
+            (n_available + BLOCK_SIZE - 1) // BLOCK_SIZE)
+    _transpose_kernel[grid](
+        input,
+        output,
+        M,
+        n_available,
+        input.stride(0),
+        input.stride(1),
+        output.stride(0),
+        output.stride(1),
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    return output
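
Note (illustrative aside, not part of the commit): masked_transpose(x, n) as added above should match x[:, :n].t() while skipping the padded columns of the swap-AB output. A quick reference check using the function defined in this file (needs a CUDA device with Triton):

import torch

# Shapes mimic the swap-AB GEMM output: (out_features, padded_tokens),
# trimmed back to the real token count.
x = torch.randn(512, 8, dtype=torch.bfloat16, device="cuda")
n_real = 5

got = masked_transpose(x, n_real)            # (5, 512)
ref = x[:, :n_real].t().contiguous()
assert got.shape == (n_real, x.shape[0])
torch.testing.assert_close(got, ref)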
