-
Notifications
You must be signed in to change notification settings - Fork 7
Lower unroll cleanup, make it support IfThenElse #2496
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@@ -60,6 +60,7 @@ class TORCH_CUDA_CU_API UnrollPass : kir::ExprMutator { | |||
|
|||
private: | |||
void registerReplace(Expr* reference, Expr* new_expr, kir::Scope* scope); | |||
void registerReplace(Expr* reference, Expr* new_expr); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we still need the other version?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it is needed. Running tests to verify.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They are no longer needed. Removed.
} | ||
if (expr != expr_with_predicate) { | ||
GpuLower::current()->propagateExprInfo(expr, expr_with_predicate); | ||
} | ||
inline_ite->thenBody().push_back(expr_with_predicate); | ||
} else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) { | ||
handle(for_loop); | ||
} else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm guessing you're adding something in the original PR that introduces IfThenElse
before the unroll pass. What is it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it is loop rotation (which happens right before the unroll pass):
for i in range(n):
statement1(i)
statement2(i)
statement3(i)
statement4(i)
transform to
if 0 < n:
for i = 0:
statement1(i)
statement2(i)
for i ...:
statement3(i)
statement4(i)
if i + 1 < n:
statement1(i)
statement2(i)
I am actually not materializing these conditions (because I think existing predicates should already cover all illegal access), so I am just generating:
for i = 0:
statement1(i)
statement2(i)
for i ...:
statement3(i)
statement4(i)
if true:
statement1(i)
statement2(i)
But the if true
is still necessary because I use it as a special container to mark which part of the for loop is rotated from the next iteration.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK. Not about this PR, but another thing we should clean up is to make the pass dependencies more explicit. I suspect there's a pass that assumes there's no kir::IfThenElse
in the incoming expr list, but I don't remember that's validated.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK. Not about this PR, but another thing we should clean up is to make the pass dependencies more explicit. I suspect there's a pass that assumes there's no
kir::IfThenElse
in the incoming expr list, but I don't remember that's validated.
I think you are referring to the double buffer pass? Fortunately, there is a TORCH_INTERNAL_ASSERT
on handle(kir::IfThenElse* ite)
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, but it's not much about a specific pass. These assumptions on the lowering pass dependencies should be explicitly represented and enforced.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That sound like something like a pass manager? Passes should use a unified data structure and should have some metadata stored in that data structure so that the pass manager can parse and decide the order of passes?
@@ -68,8 +69,6 @@ class TORCH_CUDA_CU_API UnrollPass : kir::ExprMutator { | |||
return expr_replacement_map_; | |||
} | |||
|
|||
using OptOutDispatch::handle; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be using kir::ExprMutator::handle;
, but just removing it results in a build error for me (using clang):
/raid/tmp/nmaruyama/debug1/third_party/nvfuser/csrc/lower_unroll.h:74:8: error: 'nvfuser::UnrollPass::handle' hides overloaded virtual function [-Werror,-Woverloaded-virtual]
void handle(Expr* expr) final;
^
/raid/tmp/nmaruyama/debug1/third_party/nvfuser/csrc/kernel_ir_dispatch.h:36:16: note: hidden overloaded virtual function 'nvfuser::kir::IrVisitor::handle' declared here: type mismatch at 1st parameter ('nvfuser::kir::IfThenElse *' vs 'nvfuser::Expr *')
virtual void handle(IfThenElse*) override;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't make sense to me. But I will not argue with a compiler about which is correct, so I just added it back.
Reopen PR to trigger CI |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Split from #2488