-
Notifications
You must be signed in to change notification settings - Fork 7
[MatMul] loop interleaving pass to interleave double buffered unrolled loops #1975
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: index_codegen
Are you sure you want to change the base?
Conversation
@@ -627,6 +637,9 @@ class TORCH_CUDA_CU_API TensorView : public Val { | |||
//! Indicates if the prolog of the double buffer loop of double | |||
//! buffer tensor will be lifted out of the main loop. | |||
bool skew_double_buffer_loop_ = false; | |||
|
|||
// Loop where the next level of unrolled loops are interleaved. | |||
c10::optional<std::pair<int, int>> maybe_interleave_axis_and_factor_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add comments on the pair
// If we see main loop before seeing the double buffer axis, | ||
// it cannot be proven safe to interleave by double buffering | ||
// but the other two points might apply. | ||
can_interleave = false; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't a break
be added here?
continue; | ||
} | ||
|
||
// Double buffered tv doesn't need to be checked, see Point 2 above: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo: Point 1
// [Supported Interleaving Cases] | ||
// All the expressions that are inside the main loop or subloop can | ||
// only be 3 cases: | ||
// 1. It's double/circular buffered across a loop that's either at or on the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't follow this case.
What I can't yet figure out is what the underlying generic condition this transformation must satisfy. Generally speaking, it seems safe if there's no data dependency between the subloop TVs, which basically corresponds to the Point 3. In the case of Point 2, it is also safe despite the data dependency because the dependency is constrained inside the sub loop, right? I can't wrap my head around the Point 1 yet, though.
if (concrete_main_loop_ == concrete_loop_id && | ||
fl->doubleBufferLoopStage() == DoubleBufferLoopStage::Main) { | ||
handleMainLoop(fl); | ||
} else { | ||
kir::ExprMutator::handle(fl); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this mean the interleave main loop must also be a main loop of double buffering?
// Need to insert commits for multi-stage circular buffering | ||
// on the prologs, but do not need to wait for them until | ||
// the main loop. | ||
if (stage_depth > 2 && loop_type_ == DoubleBufferLoopStage::Prolog) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this a generic bug fix or is it related to the interleaving transformation?
if (need_insert_commit) { | ||
main_loop->body().insert_before( | ||
*block_sync_it, IrBuilder::create<kir::CpAsyncCommit>()); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not completely following what should be done here, but the above comment on need_insert_commit
indicates a commit should be inserted before the wait, but this seems to insert a commit after the wait inserted above. Am I missing something?
The loop interleaving optimization in this PR is needed on ampere with cp.async.
The transform this pass enables is the following:
original code:
with some simple conservative checking that expr 1-3 have no direct dependencies, the pass transforms the above into:
The particular use case is the following:
The interleaving essentially optimizes away the congestion mentioned above on the comment.