|
64 | 64 | FP8_DTYPE = current_platform.fp8_dtype()
|
65 | 65 | MiB = 1024 * 1024
|
66 | 66 |
|
67 |
| -# FlashInfer max sizes per world size (from collective_fusion.py) |
| 67 | +# FlashInfer max sizes per world size |
| 68 | +# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes |
| 69 | +# use --disable-oneshot to disable oneshot mode for very large input sizes |
68 | 70 | _FI_MAX_SIZES = {
|
69 | 71 | 2: 64 * MiB, # 64MB
|
70 |
| - 4: 32 * MiB, # 32MB |
71 |
| - 6: 32 * MiB, # 32MB |
72 |
| - 8: 32 * MiB, # 32MB |
| 72 | + 4: 64 * MiB, # 64MB |
| 73 | + 8: 64 * MiB, # 64MB |
73 | 74 | }
|
74 | 75 |
|
75 | 76 | # Global workspace tensor for FlashInfer
|
@@ -186,7 +187,7 @@ def flashinfer_fused_allreduce_rmsnorm(
|
186 | 187 | allreduce_out=None,
|
187 | 188 | quant_out=None,
|
188 | 189 | scale_out=None,
|
189 |
| - layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED, |
| 190 | + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4_, |
190 | 191 | scale_factor=None,
|
191 | 192 | use_oneshot=use_oneshot,
|
192 | 193 | **allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
@@ -228,7 +229,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
|
228 | 229 | allreduce_out=None,
|
229 | 230 | quant_out=quant_out,
|
230 | 231 | scale_out=None,
|
231 |
| - layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED, |
| 232 | + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, |
232 | 233 | scale_factor=scale_factor,
|
233 | 234 | use_oneshot=use_oneshot,
|
234 | 235 | **allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
@@ -271,7 +272,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
|
271 | 272 | allreduce_out=None,
|
272 | 273 | quant_out=quant_out,
|
273 | 274 | scale_out=output_scale,
|
274 |
| - layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED, |
| 275 | + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, |
275 | 276 | scale_factor=input_global_scale,
|
276 | 277 | use_oneshot=use_oneshot,
|
277 | 278 | **allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
@@ -579,6 +580,7 @@ def run_benchmarks(
|
579 | 580 | use_residual: bool,
|
580 | 581 | allreduce_params: Optional[FlashInferFusedAllReduceParams],
|
581 | 582 | quant_mode: str = "all",
|
| 583 | + disable_oneshot: bool = False, |
582 | 584 | ):
|
583 | 585 | """Run all benchmarks for given configuration.
|
584 | 586 |
|
@@ -638,17 +640,18 @@ def run_benchmarks(
|
638 | 640 | # FlashInfer Fused AllReduce + RMSNorm Oneshot
|
639 | 641 | if flashinfer_comm is not None and allreduce_params is not None:
|
640 | 642 | try:
|
641 |
| - time_ms = benchmark_operation( |
642 |
| - flashinfer_fused_allreduce_rmsnorm, |
643 |
| - input_tensor, |
644 |
| - residual=residual, |
645 |
| - norm_out=norm_out, |
646 |
| - rms_gamma=rms_gamma, |
647 |
| - rms_eps=rms_eps, |
648 |
| - allreduce_params=allreduce_params, |
649 |
| - use_oneshot=True, |
650 |
| - ) |
651 |
| - results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = time_ms |
| 643 | + if not disable_oneshot: |
| 644 | + time_ms = benchmark_operation( |
| 645 | + flashinfer_fused_allreduce_rmsnorm, |
| 646 | + input_tensor, |
| 647 | + residual=residual, |
| 648 | + norm_out=norm_out, |
| 649 | + rms_gamma=rms_gamma, |
| 650 | + rms_eps=rms_eps, |
| 651 | + allreduce_params=allreduce_params, |
| 652 | + use_oneshot=True, |
| 653 | + ) |
| 654 | + results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = time_ms |
652 | 655 | except Exception as e:
|
653 | 656 | logger.error("FlashInfer Fused AllReduce+RMSNorm Oneshot failed: %s", e)
|
654 | 657 | results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = float("inf")
|
@@ -712,21 +715,22 @@ def run_benchmarks(
|
712 | 715 | # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot
|
713 | 716 | if flashinfer_comm is not None and allreduce_params is not None:
|
714 | 717 | try:
|
715 |
| - time_ms = benchmark_operation( |
716 |
| - flashinfer_fused_allreduce_rmsnorm_fp8_quant, |
717 |
| - input_tensor, |
718 |
| - norm_out=norm_out, |
719 |
| - residual=residual, |
720 |
| - rms_gamma=rms_gamma, |
721 |
| - rms_eps=rms_eps, |
722 |
| - scale_factor=scale_fp8, |
723 |
| - quant_out=quant_out_fp8, |
724 |
| - allreduce_params=allreduce_params, |
725 |
| - use_oneshot=True, |
726 |
| - ) |
727 |
| - results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = ( |
728 |
| - time_ms |
729 |
| - ) |
| 718 | + if not disable_oneshot: |
| 719 | + time_ms = benchmark_operation( |
| 720 | + flashinfer_fused_allreduce_rmsnorm_fp8_quant, |
| 721 | + input_tensor, |
| 722 | + norm_out=norm_out, |
| 723 | + residual=residual, |
| 724 | + rms_gamma=rms_gamma, |
| 725 | + rms_eps=rms_eps, |
| 726 | + scale_factor=scale_fp8, |
| 727 | + quant_out=quant_out_fp8, |
| 728 | + allreduce_params=allreduce_params, |
| 729 | + use_oneshot=True, |
| 730 | + ) |
| 731 | + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = ( |
| 732 | + time_ms |
| 733 | + ) |
730 | 734 | except Exception as e:
|
731 | 735 | logger.error(
|
732 | 736 | "FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s",
|
@@ -802,22 +806,23 @@ def run_benchmarks(
|
802 | 806 | # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot
|
803 | 807 | if flashinfer_comm is not None and allreduce_params is not None:
|
804 | 808 | try:
|
805 |
| - time_ms = benchmark_operation( |
806 |
| - flashinfer_fused_allreduce_rmsnorm_fp4_quant, |
807 |
| - input_tensor, |
808 |
| - residual=residual, |
809 |
| - norm_out=norm_out, |
810 |
| - rms_gamma=rms_gamma, |
811 |
| - rms_eps=rms_eps, |
812 |
| - input_global_scale=scale_fp4, |
813 |
| - allreduce_params=allreduce_params, |
814 |
| - quant_out=fp4_quant_out, |
815 |
| - output_scale=fp4_output_scale, |
816 |
| - use_oneshot=True, |
817 |
| - ) |
818 |
| - results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = ( |
819 |
| - time_ms |
820 |
| - ) |
| 809 | + if not disable_oneshot: |
| 810 | + time_ms = benchmark_operation( |
| 811 | + flashinfer_fused_allreduce_rmsnorm_fp4_quant, |
| 812 | + input_tensor, |
| 813 | + residual=residual, |
| 814 | + norm_out=norm_out, |
| 815 | + rms_gamma=rms_gamma, |
| 816 | + rms_eps=rms_eps, |
| 817 | + input_global_scale=scale_fp4, |
| 818 | + allreduce_params=allreduce_params, |
| 819 | + quant_out=fp4_quant_out, |
| 820 | + output_scale=fp4_output_scale, |
| 821 | + use_oneshot=True, |
| 822 | + ) |
| 823 | + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = ( |
| 824 | + time_ms |
| 825 | + ) |
821 | 826 | except Exception as e:
|
822 | 827 | logger.error(
|
823 | 828 | "FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s",
|
@@ -1224,6 +1229,7 @@ def main():
|
1224 | 1229 | use_residual,
|
1225 | 1230 | allreduce_params,
|
1226 | 1231 | quant_mode=quant_mode,
|
| 1232 | + disable_oneshot=args.disable_oneshot, |
1227 | 1233 | )
|
1228 | 1234 |
|
1229 | 1235 | # Store results for markdown export
|
|
0 commit comments