@zasdfgbnm zasdfgbnm commented Dec 17, 2022

Warning: this PR contains #2258 and #2273. Please review this PR after I have merged these two PRs and rebased this PR.

This PR adds a few more passes that are capable of simplifying matmul indexing well. The newly added passes are: cancelDivMod, distributeDivisibleDivMod, and distributeMul. The most helpful pass for matmul is distributeDivisibleDivMod. It simplifies indices like:

(threadIdx.x + 16 * i1) % 8

into

threadIdx.x % 8

which helps remove the data dependency on i1, so that the index can be hoisted out of the i1 loop.

Example matmul kernel code

Command:

CUDA_VISIBLE_DEVICES=1 PYTORCH_NVFUSER_DUMP="cuda_kernel,expr_simplify,ptxas_verbose" ./build/bin/nvfuser_bench --benchmark_filter=Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_NT_Legacy/2048/3456/4096/manual_time

Kernel diff compare (this PR + #1900) vs #1900 alone
https://www.diffchecker.com/pcxGCQkn

Matmul perf benchmark

Compare (this PR + #1900) vs #1900 alone

Command:

CUDA_VISIBLE_DEVICES=1 ./build/bin/nvfuser_bench --benchmark_filter='.*Matmul.*Legacy/2048/3456/4096.*'

Before:

---------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                       Time             CPU   Iterations
---------------------------------------------------------------------------------------------------------------------------------
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_TT_Legacy/2048/3456/4096/manual_time       1209 us         1387 us          527
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_TN_Legacy/2048/3456/4096/manual_time       1204 us         1382 us          504
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_NT_Legacy/2048/3456/4096/manual_time       1126 us         1304 us          538
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_TT_Legacy/2048/3456/4096/manual_time       2018 us         2200 us          312
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_TN_Legacy/2048/3456/4096/manual_time       2082 us         2271 us          273
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_NT_Legacy/2048/3456/4096/manual_time       1957 us         2138 us          305
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_TT_Legacy/2048/3456/4096/manual_time       1297 us         1491 us          460
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_TN_Legacy/2048/3456/4096/manual_time       1358 us         1580 us          460
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_NT_Legacy/2048/3456/4096/manual_time       1314 us         1493 us          454
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_TT_Legacy/2048/3456/4096/manual_time       1411 us         1590 us          423
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_TN_Legacy/2048/3456/4096/manual_time       1407 us         1587 us          424
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_NT_Legacy/2048/3456/4096/manual_time       1318 us         1497 us          453
EagerModeMatmul/no_quant_eagermode_TT_Legacy/2048/3456/4096/manual_time                       878 us          942 us          806
EagerModeMatmul/no_quant_eagermode_TN_Legacy/2048/3456/4096/manual_time                       904 us          971 us          747
EagerModeMatmul/no_quant_eagermode_NT_Legacy/2048/3456/4096/manual_time                       834 us          903 us          856

After:

---------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                       Time             CPU   Iterations
---------------------------------------------------------------------------------------------------------------------------------
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_TT_Legacy/2048/3456/4096/manual_time        929 us         1109 us          678
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_TN_Legacy/2048/3456/4096/manual_time        951 us         1138 us          655
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_NT_Legacy/2048/3456/4096/manual_time        893 us         1083 us          680
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_TT_Legacy/2048/3456/4096/manual_time       1066 us         1249 us          595
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_TN_Legacy/2048/3456/4096/manual_time       1177 us         1358 us          529
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_NT_Legacy/2048/3456/4096/manual_time       1032 us         1217 us          604
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_TT_Legacy/2048/3456/4096/manual_time        936 us         1117 us          669
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_TN_Legacy/2048/3456/4096/manual_time        911 us         1182 us          670
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_NT_Legacy/2048/3456/4096/manual_time        928 us         1109 us          670
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_TT_Legacy/2048/3456/4096/manual_time        914 us         1117 us          671
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_TN_Legacy/2048/3456/4096/manual_time        936 us         1115 us          668
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_NT_Legacy/2048/3456/4096/manual_time        926 us         1108 us          669
EagerModeMatmul/no_quant_eagermode_TT_Legacy/2048/3456/4096/manual_time                       876 us          957 us          811
EagerModeMatmul/no_quant_eagermode_TN_Legacy/2048/3456/4096/manual_time                       903 us          977 us          755
EagerModeMatmul/no_quant_eagermode_NT_Legacy/2048/3456/4096/manual_time                       829 us          905 us          840

@zasdfgbnm zasdfgbnm requested a review from naoyam January 17, 2023 00:46
@zasdfgbnm
Collaborator Author

@naoyam This is ready for review

Base automatically changed from compatible-sign-check to devel January 17, 2023 23:58
@naoyam
Collaborator

naoyam commented Jan 18, 2023

> This PR adds a few more passes that are capable of simplifying matmul indexing well. The newly added passes are: cancelDivMod, distributeDivisibleDivMod, and distributeMul. The most helpful pass for matmul is distributeDivisibleDivMod. It simplifies indices like:
>
> (threadIdx.x + 16 * i1) % 8
>
> into
>
> threadIdx.x

Why is this legal? Is threadIdx.x assumed to be less than 8?

}

BinaryOp* toDivModOp(Expr* expr) {
  if (auto bop = dynamic_cast<BinaryOp*>(expr)) {
Collaborator

nit: Most of the functions follow this pattern of nested conditional branches. We could reduce the indentation levels by negating the condition and exiting early.

Collaborator Author

Changed most of them.

@zasdfgbnm
Collaborator Author

> > This PR adds a few more passes that are capable of simplifying matmul indexing well. The newly added passes are: cancelDivMod, distributeDivisibleDivMod, and distributeMul. The most helpful pass for matmul is distributeDivisibleDivMod. It simplifies indices like:
> >
> > (threadIdx.x + 16 * i1) % 8
> >
> > into
> >
> > threadIdx.x
>
> Why is this legal? Is threadIdx.x assumed to be less than 8?

Oh, sorry. I meant to say threadIdx.x % 8...

Collaborator

@naoyam naoyam left a comment

LGTM

@zasdfgbnm zasdfgbnm merged commit cead7ad into devel Jan 18, 2023
@zasdfgbnm zasdfgbnm deleted the distribute-divmod branch January 18, 2023 20:47