Improve matmul instruction scheduling with loop rotation #2488

Merged
merged 36 commits into from
Mar 2, 2023

zasdfgbnm
Collaborator

@zasdfgbnm zasdfgbnm commented Feb 17, 2023

Introduction

Loop rotation is a lowering pass that transforms

for i in range(n):
  statement1(i)
  statement2(i)
  statement3(i)
  statement4(i)

into

statement1(0)
statement2(0)
for i in range(n):
  statement3(i)
  statement4(i)
  statement1(i+1)
  statement2(i+1)
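
As a quick sanity check of the transformation, here is a minimal Python sketch that records the order of statement calls for both forms. The pseudocode above does not show what happens on the last trip (where `i+1 == n`); the `if i + 1 < n` guard below is an assumption added so the two traces match exactly, and is not necessarily how the lowering pass handles it. The names `original`, `rotated`, and `run` are made up for this example.

```python
def original(n, trace):
    # The loop before rotation: all four statements run in order per trip.
    for i in range(n):
        trace.append(("statement1", i))
        trace.append(("statement2", i))
        trace.append(("statement3", i))
        trace.append(("statement4", i))


def rotated(n, trace):
    # Prologue: the rotated statements for iteration 0 (skipped for an empty loop).
    if n > 0:
        trace.append(("statement1", 0))
        trace.append(("statement2", 0))
    for i in range(n):
        trace.append(("statement3", i))
        trace.append(("statement4", i))
        # Assumption: guard the rotated statements so iteration n is never run.
        if i + 1 < n:
            trace.append(("statement1", i + 1))
            trace.append(("statement2", i + 1))


def run(fn, n):
    # Collect the sequence of (statement, iteration) pairs produced by fn.
    trace = []
    fn(n, trace)
    return trace


if __name__ == "__main__":
    print(run(original, 3) == run(rotated, 3))
```

With the guard in place, both functions produce the same sequence for any `n`; only the position of the loop boundary relative to the statements changes.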

In the matmul kernel, both the cp.async and the ld.matrix are circular/double buffered. This PR applies loop rotation to the matmul main loop to pull the first iteration's ld.matrix out of the main loop of cp.async.

That is, to change the code from

cp.async prologue
// main loop for cp.async
for (...) {
  cp.async
  ld.matrix prologue // <-- to be rotated
  // main loop for ld.matrix
  for (...) {
    ld.matrix
    mma
  }
  mma // epilogue for ld.matrix
}

to

cp.async prologue
ld.matrix prologue  // <-- rotated
// main loop for cp.async
for (...) {
  cp.async
  // main loop for ld.matrix
  for (...) {
    ld.matrix
    mma
  }
  mma // epilogue for ld.matrix
  ld.matrix prologue  // <-- rotated
}
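
To make the effect on instruction order concrete, the following Python sketch emits the sequence of operations for both loop structures. It is a toy model, not the kernel: the function names, the string labels, and the `k + 1 < outer` guard on the trailing prologue are all assumptions made for illustration.

```python
def before(outer, inner):
    # Loop structure before rotation: the ld.matrix prologue follows cp.async.
    ops = ["cp.async prologue"]
    for k in range(outer):
        ops.append(f"cp.async {k}")
        ops.append(f"ld.matrix prologue {k}")
        for j in range(inner):
            ops.append(f"ld.matrix {k}.{j}")
            ops.append(f"mma {k}.{j}")
        ops.append(f"mma epilogue {k}")
    return ops


def after(outer, inner):
    # Loop structure after rotation: the first prologue is hoisted out of the
    # loop, and each trip ends with the prologue for the next trip.
    ops = ["cp.async prologue"]
    if outer > 0:
        ops.append("ld.matrix prologue 0")
    for k in range(outer):
        ops.append(f"cp.async {k}")
        for j in range(inner):
            ops.append(f"ld.matrix {k}.{j}")
            ops.append(f"mma {k}.{j}")
        ops.append(f"mma epilogue {k}")
        if k + 1 < outer:  # assumption: no prologue after the final trip
            ops.append(f"ld.matrix prologue {k + 1}")
    return ops


def without_cp_async(ops):
    # Drop the cp.async entries to compare only the ld.matrix/mma ordering.
    return [op for op in ops if not op.startswith("cp.async")]


if __name__ == "__main__":
    for op in after(2, 2):
        print(op)
```

In this model the relative order of the ld.matrix and mma operations is unchanged; what moves is where each ld.matrix prologue sits relative to the cp.async of the same trip.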

To do so, I need to reorder the matmul schedule from

//                               vvvvvv ld.matrix double buffer loop
[BIDx, BIDy, Serial, TIDz, TIDy, Serial]
//           ^^^^^^ cp.async circular buffer loop

to

//                   vvvvvv ld.matrix double buffer loop
[BIDx, BIDy, Serial, Serial, TIDz, TIDy]
//           ^^^^^^ cp.async circular buffer loop

This is needed because, in the first schedule, the loop structure is

for blockIdx.x:
  for blockIdx.y:
    cp.async
    for i1:  # cp.async circular buffer loop
      cp.async
      for threadIdx.z:
        for threadIdx.y:
          ld.matrix
          for i2:  # ld.matrix double buffer loop
            ld.matrix
            mma
          mma

where inside the cp.async circular buffer loop, the entire ld.matrix->mma is contained in the threadIdx trivial loop, and the ld.matrix is not separable.

In contrast, for the second schedule, we have

for blockIdx.x:
  for blockIdx.y:
    cp.async
    for i1:  # cp.async circular buffer loop
      cp.async
      ld.matrix
      for i2:  # ld.matrix double buffer loop
        for threadIdx.z:
          for threadIdx.y:
            ld.matrix
            mma
      for threadIdx.z:
        for threadIdx.y:
          mma

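The blockIdx and threadIdx loops are trivial loops, so this schedule change does not affect the generated CUDA kernel. It does, however, make the kernel IR easier to work with: the ld.matrix prologue now sits directly inside the cp.async circular buffer loop, where it can be separated and rotated.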
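
The claim that reordering around trivial loops leaves the execution order unchanged can be illustrated with a small Python sketch. It models a trivial loop as a loop with exactly one trip, which is an assumption made for this example (from a single thread's point of view), and the function names are hypothetical.

```python
def threads_outside(n1, n2, tz=1, ty=1):
    # Order [Serial(i1), TIDz, TIDy, Serial(i2)]: thread loops wrap the i2 loop.
    return [
        (i1, z, y, i2)
        for i1 in range(n1)
        for z in range(tz)
        for y in range(ty)
        for i2 in range(n2)
    ]


def threads_inside(n1, n2, tz=1, ty=1):
    # Order [Serial(i1), Serial(i2), TIDz, TIDy]: thread loops are innermost.
    return [
        (i1, z, y, i2)
        for i1 in range(n1)
        for i2 in range(n2)
        for z in range(tz)
        for y in range(ty)
    ]


if __name__ == "__main__":
    # One-trip thread loops: both orders visit the same indices in the same order.
    print(threads_outside(2, 3) == threads_inside(2, 3))
    # If the thread loops were real serial loops, the order would differ.
    print(threads_outside(1, 2, tz=2) == threads_inside(1, 2, tz=2))
```

The second comparison is there only to show that the equivalence depends on the thread loops being trivial; with a real multi-trip loop the two schedules would not be interchangeable.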

Benchmark

Using the command

CUDA_VISIBLE_DEVICES=1 ./build/bin/nvfuser_bench --benchmark_filter='.*Matmul.*Legacy/2048/3456/4096.*'

Before this PR:

---------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                       Time             CPU   Iterations
---------------------------------------------------------------------------------------------------------------------------------
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_TT_Legacy/2048/3456/4096/manual_time        990 us         2746 us          711
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_TN_Legacy/2048/3456/4096/manual_time        871 us         2629 us          720
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_NT_Legacy/2048/3456/4096/manual_time       1064 us         2821 us          579
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_TT_Legacy/2048/3456/4096/manual_time       1278 us         3034 us          499
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_TN_Legacy/2048/3456/4096/manual_time       1159 us         2914 us          607
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_NT_Legacy/2048/3456/4096/manual_time       1432 us         3188 us          447
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_TT_Legacy/2048/3456/4096/manual_time       1209 us         2966 us          526
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_TN_Legacy/2048/3456/4096/manual_time       1134 us         2892 us          619
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_NT_Legacy/2048/3456/4096/manual_time       1320 us         3076 us          532
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_TT_Legacy/2048/3456/4096/manual_time       1216 us         2973 us          578
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_TN_Legacy/2048/3456/4096/manual_time       1114 us         2872 us          550
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_NT_Legacy/2048/3456/4096/manual_time       1371 us         3130 us          512
EagerModeMatmul/no_quant_eagermode_TT_Legacy/2048/3456/4096/manual_time                       845 us          912 us          832
EagerModeMatmul/no_quant_eagermode_TN_Legacy/2048/3456/4096/manual_time                       916 us          985 us          765
EagerModeMatmul/no_quant_eagermode_NT_Legacy/2048/3456/4096/manual_time                       792 us          863 us          884

After this PR:

---------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                       Time             CPU   Iterations
---------------------------------------------------------------------------------------------------------------------------------
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_TT_Legacy/2048/3456/4096/manual_time        887 us         2643 us          793
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_TN_Legacy/2048/3456/4096/manual_time        899 us         2655 us          780
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_NT_Legacy/2048/3456/4096/manual_time        978 us         2734 us          717
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_TT_Legacy/2048/3456/4096/manual_time        891 us         2648 us          787
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_TN_Legacy/2048/3456/4096/manual_time        899 us         2655 us          782
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_NT_Legacy/2048/3456/4096/manual_time        997 us         2753 us          704
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_TT_Legacy/2048/3456/4096/manual_time        904 us         2662 us          732
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_TN_Legacy/2048/3456/4096/manual_time        903 us         2660 us          778
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_NT_Legacy/2048/3456/4096/manual_time        946 us         2703 us          742
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_TT_Legacy/2048/3456/4096/manual_time        888 us         2646 us          727
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_TN_Legacy/2048/3456/4096/manual_time        885 us         2643 us          794
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_NT_Legacy/2048/3456/4096/manual_time        945 us         2702 us          744
EagerModeMatmul/no_quant_eagermode_TT_Legacy/2048/3456/4096/manual_time                       845 us          911 us          832
EagerModeMatmul/no_quant_eagermode_TN_Legacy/2048/3456/4096/manual_time                       916 us          985 us          765
EagerModeMatmul/no_quant_eagermode_NT_Legacy/2048/3456/4096/manual_time                       792 us          863 us          884
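
For a quick read of the two tables, the before/after ratio of the Time column can be computed directly. The sketch below copies the nvfuser rows from the tables above; the dictionary layout and the helper names are only for illustration.

```python
# manual_time in microseconds, copied from the two tables: (before, after).
TIMES_US = {
    "4warp3stage_TT": (990, 887),
    "4warp3stage_TN": (871, 899),
    "4warp3stage_NT": (1064, 978),
    "4warp4stage_TT": (1278, 891),
    "4warp4stage_TN": (1159, 899),
    "4warp4stage_NT": (1432, 997),
    "8warp3stage_TT": (1209, 904),
    "8warp3stage_TN": (1134, 903),
    "8warp3stage_NT": (1320, 946),
    "8warp4stage_TT": (1216, 888),
    "8warp4stage_TN": (1114, 885),
    "8warp4stage_NT": (1371, 945),
}


def speedups(times):
    # Ratio before/after per configuration; above 1.0 means faster after the PR.
    return {name: before / after for name, (before, after) in times.items()}


def regressions(times):
    # Configurations whose time went up after the PR.
    return [name for name, (before, after) in times.items() if after > before]


if __name__ == "__main__":
    for name, ratio in speedups(TIMES_US).items():
        print(f"{name}: {ratio:.2f}x")
    print("slower after:", regressions(TIMES_US))
```

On these numbers, 4warp3stage TN is the only configuration that gets slower (871 us to 899 us). The eager-mode baselines have the same Time in both tables and are left out.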

@zasdfgbnm zasdfgbnm changed the title from "Loop rotation WIP" to "Improve matmul instruction scheduling with loop rotation" Mar 1, 2023
Comment on lines +10 to +11
assert(ind >= 0);
assert(ind <= max_ind);
Collaborator Author

@zasdfgbnm zasdfgbnm Mar 1, 2023

Unrelated to this PR, but asserting different conditions separately provides a better error message. (The line number in the error message will tell me which is violated).
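
The same point can be shown with a small Python analogue (the code under review is C++, so this is only an illustration, and the function names are hypothetical). With one combined assert, both kinds of violation report the same line; with separate asserts, the reported line identifies the failed condition.

```python
import traceback


def check_combined(ind, max_ind):
    # Both conditions share one line, so a failure does not say which one broke.
    assert ind >= 0 and ind <= max_ind


def check_separate(ind, max_ind):
    # Each condition has its own line, so the line number identifies the failure.
    assert ind >= 0
    assert ind <= max_ind


def failing_line(check, ind, max_ind):
    # Return the line number of the failing assert, or None if nothing failed.
    try:
        check(ind, max_ind)
    except AssertionError as err:
        return traceback.extract_tb(err.__traceback__)[-1].lineno
    return None


if __name__ == "__main__":
    print(failing_line(check_separate, -1, 10), failing_line(check_separate, 11, 10))
    print(failing_line(check_combined, -1, 10), failing_line(check_combined, 11, 10))
```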

@zasdfgbnm zasdfgbnm marked this pull request as ready for review March 1, 2023 21:08
Collaborator

@naoyam naoyam left a comment

LGTM

@zasdfgbnm zasdfgbnm merged commit 5913acc into devel Mar 2, 2023
@zasdfgbnm zasdfgbnm deleted the loop-rotation branch March 2, 2023 05:23
drzejan2 added a commit that referenced this pull request Mar 14, 2023
- apply improvement in matmul instruction scheduling with loop rotation