Rewrite reducePredicateRegisterUsage #2533


Merged: 7 commits, Mar 6, 2023
Conversation

@zasdfgbnm (Collaborator) commented Mar 1, 2023

The former approach does not make sense because it does a lot of reordering even when there are no register-usage savings. This reordering can be annoying because it makes the code very hard to read. I am rewriting this pass so that it only reorders terms when there is a register saving.

@zasdfgbnm (Collaborator, Author)

Marking this as ready, but I would like to wait for #2500 because I don't want this PR to conflict with the new assertCUDAKernels in the loop-rotation tests.

@zasdfgbnm zasdfgbnm marked this pull request as ready for review March 1, 2023 09:19
@zasdfgbnm zasdfgbnm requested a review from naoyam March 1, 2023 09:20
@@ -337,7 +337,7 @@ if(BUILD_TEST)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu2.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu3.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_compute_with.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_expr_simplifier.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_expr_simplifier.cpp)
@zasdfgbnm (Collaborator, Author)

I think we prefixed our tests with test_gpu_ because we were inside torchscript's test directory and needed to distinguish our tests from the other torch JIT tests. I don't think this prefix makes sense anymore.

@@ -47,7 +47,9 @@ void assertSimplifiedMod(Val* x, Val* y, Val* z) {

} // namespace

TEST_F(NVFuserTest, FusionAssociativeAndCommutativeReordering_CUDA) {
@zasdfgbnm (Collaborator, Author) Mar 1, 2023

We used this naming convention back when our tests were linked with the other JIT tests, and it helped us tell our tests apart from them. Now our tests have a standalone executable, so I don't think this convention provides any benefit anymore. I am removing it to get shorter test names.

@naoyam (Collaborator)

I remember there was some CI setting that relies on this naming convention. CC: @jjsjann123

@zasdfgbnm (Collaborator, Author)

Checked with @jjsjann123 offline: we should keep the _CUDA suffix, but we can change the rest. Updated this PR.

Comment on lines 374 to 375
// This is failing ?!
// setAssertOutOfBound(true);
@zasdfgbnm (Collaborator, Author)

It is OK for this to fail due to a limitation of this feature.

@naoyam (Collaborator) commented Mar 6, 2023

> when there is a register saving.

I haven't looked into the PR yet, but how do you know if there's a register saving?

@zasdfgbnm (Collaborator, Author) commented Mar 6, 2023

> > when there is a register saving.
>
> I haven't looked into the PR yet, but how do you know if there's a register saving?

I changed this pass to only consider register savings on unrolled loops. For example, in threadIdx.x + 3 < T0.size[0] there is no unrolled loop, so changing it to threadIdx.x - T0.size[0] < -3 does not save anything; now we will not move terms across the < boundary in such cases. This pass works by finding all terms that depend on an unrolled loop index, computing their register type, and comparing it with the register type of the remaining terms. If there is a saving (that is, the remaining terms use a general-purpose register while the unrolled terms use a uniform register or an immediate, or the remaining terms use a uniform register while the unrolled terms use an immediate), then it moves terms.

For example, if I have

#pragma unroll
for i = 0..8:
  threadIdx.x / 128 + i * 32 == 256 + blockIdx.y * i

then the register types are:

gp,no_unroll + imm,unroll == imm,no_unroll + uniform,unroll

so I will need 8 general purpose registers for the left and 8 uniform registers for the right.

If I instead do

gp,no_unroll - imm,no_unroll == uniform,unroll - imm,unroll

then I will need 1 general purpose register for the left and 8 uniform registers for the right, which saves 7 general purpose registers.

@naoyam (Collaborator) commented Mar 6, 2023


Thanks for the explanation. Yeah, I found this part (https://github.com/csarofeen/pytorch/pull/2533/files#diff-7853cbfc8ac2e2e18643fb0ba06777e2ee9f4d43065973e9ffb38ebc0fcc0f68R1697), and it makes sense.

@naoyam (Collaborator) left a review

LGTM. Please just make sure the test name change doesn't invalidate anything around CI.


} else {
redist_lhs({bop->lhs()});

auto [lhs_unroll, lhs_unroll_rtype, lhs_other, lhs_other_rtype] = classify(
@naoyam (Collaborator)

I didn't know about this syntax. I'm assuming it has the same effect as using std::tie, right?

@zasdfgbnm (Collaborator, Author)

It is similar, but not the same. This is called a structured binding. My understanding is that a structured binding is mostly used to declare and initialize new variables, while std::tie is mostly for assignment to variables that already exist.

@naoyam (Collaborator)

Oh, so this can be used for variable declarations. That sounds handy.

@zasdfgbnm zasdfgbnm merged commit 5a69c1b into devel Mar 6, 2023
@zasdfgbnm zasdfgbnm deleted the reducePredicateRegisterUsage branch March 6, 2023 22:54