-
Notifications
You must be signed in to change notification settings - Fork 7
Contiguous indexing for View operations #1990
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
… lower_utils_cleanup
… contig_merge_split
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comments so far
bool traverse_members = false); | ||
|
||
// Same as getStmts version but filters to only return the Expr*s | ||
static std::vector<Expr*> getExprs( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this effectively the same as getAllExprsBetween
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I can rename it, good point. Also traverseFrom
that's used here should probably now be traverseBetween
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just wondering if there's any preference between DependencyCheck::getAllExprsBetween
and StmtSort::getExprsBetween
?
// domains in its dependencies, then it can't be a contiguously indexable | ||
// iterdomain. | ||
if (!(consistent_transform_info_->isConsistentlyOrdered(merge->out()) && | ||
consistent_transform_info_->exclusivelyConsumesRoots(merge->out()))) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the exclusiveness always necessary? For example:
[4, 8] -> split(0, 2) -> [4/2, 2, 8] -> merge(1) ->[4/2, 2*8]
In this case the second leaf ID (2*8
) doesn't exclusively consume the first root ID, but it should be valid to do indexing at that domain.
However, to do so, we would also need to index the first leaf ID, which is the outer output domain of the split, so some more extension would be needed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes! this is exactly the conclusion I came to as well! "Contiguous indexing" doesn't have to be on a merge output. But to do this we should start at the root domains and forward propagate on the indexing map to figure out how we want to index! We could even have multiple choices that are valid and could want something like a min-cut approach. I wonder how important this might be for multi-gpu fusions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect the closer to the leaf domains, the more efficient the index math would be.
Would be good to have this as a comment on the code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think I want to get into it in a comment, but this is something I want to keep in mind for a larger indexing refactor:
I suspect the closer to the leaf domains, the more efficient the index math would be.
I think it's a bit more complex than this, and probably what we're looking for is 2 things:
Bottlenecks in the transform graph, so series of merges -> series of splits, we'd want to be in between.
But maybe we also will want to prioritize IterDomains that are frequently used. So a concept of taking into consideration how many other TVs could reuse the indexing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Sergei-Lebedev I know it's early in your understanding of indexing, but this is a really interesting point we should review eventually.
// ID found, remove it | ||
root_ids.erase(root_id); | ||
// If the last id isn't contiguous that's fine, we can use the stride of | ||
// the last iter domain to multiply the contig index. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is more about the definition of contiguous tensors in PyTorch. Specifically, I'm unclear what a tensor layout would look like when an outer domain is contiguous and inner is non contiguous. Let's say:
T0: [I0, I1]
I0: contiguous
I1: Non contiguous with stride 2
T0->merge(0);
Then, what this says is that T0
can be indexed as i * 2
, where i
is the loop index of the merged domain. This is valid only when the stride of the I0
is I1->extent() * 2
. For example, if the stride of I0
is I1->extent() * 2 - 1
, indexing as i * 2
would be invalid.
Is this always the case when we have outer-contiguous inner-non-contiguous tensors? I.e., the stride of the outer domain is guaranteed to be the multiple of the extent of the non-contiguous inner domain and its stride?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. Per dimension contiguity is to exactly mark this, where I0
's stride is exactly I1
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This concept is pretty interesting, tagging @mruberry as it's just a small but interesting part of the stride and stride order conversation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Still working on responses. Pushing for now.
// domains in its dependencies, then it can't be a contiguously indexable | ||
// iterdomain. | ||
if (!(consistent_transform_info_->isConsistentlyOrdered(merge->out()) && | ||
consistent_transform_info_->exclusivelyConsumesRoots(merge->out()))) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes! this is exactly the conclusion I came to as well! "Contiguous indexing" doesn't have to be on a merge output. But to do this we should start at the root domains and forward propagate on the indexing map to figure out how we want to index! We could even have multiple choices that are valid and could want something like a min-cut approach. I wonder how important this might be for multi-gpu fusions.
// ID found, remove it | ||
root_ids.erase(root_id); | ||
// If the last id isn't contiguous that's fine, we can use the stride of | ||
// the last iter domain to multiply the contig index. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. Per dimension contiguity is to exactly mark this, where I0
's stride is exactly I1
.
// ID found, remove it | ||
root_ids.erase(root_id); | ||
// If the last id isn't contiguous that's fine, we can use the stride of | ||
// the last iter domain to multiply the contig index. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This concept is pretty interesting, tagging @mruberry as it's just a small but interesting part of the stride and stride order conversation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
// The resulting iter domain does exclusively consume the roots. | ||
// | ||
// Also: | ||
// [i0, i1, i2, i3] merge->(1) merge->(1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
merge(1)->merge(1)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shoot thanks.
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/ Codegen changes include: * codegen improvement: i. allow non-root trivial reductions, allow empty/no-op fusion ii. fixes vectorization checks and size calculation iii. bank conflict handle improvement iv. enables transpose scheduler * misc: i. CI tests failure fixes ii. cpp tests file clean up iii. trivial forwarding supports added in codegen runtime iv. added factory methods support in codegen Commits that's in this PR from the devel branch: ``` 7117a7e patching nvfuser conv cudnn test numerics mismatch (#2048) 65af1a4 Inserting sync for redundant parallel types is already done at the (#2023) 6ac74d1 Fix sync map (#2047) f5bca33 Bank conflict checker improvements (#2032) d2ca7e3 Minor update on cp.async code generation. (#1901) d36cf61 Test file cleanup (#2040) 0b8e83f Allow non-root trivial reductions (#2037) a2dfe40 Fix vectorize size calculation (#2035) e040676 Use withPredicate to replace setPredicate to maintain Exprs immutable (#2025) 197221b removing ci workflow (#2034) 40e2703 Reduction rand like patch (#2031) bc77266 Add utility for checking bank conflict of shared memory (#2029) ddd1cf7 Add back FusionReductionWithTrivialReduction_CUDA (#2030) fbd97e5 Revert "Cleanup trivial reduction workarounds (#2006)" (#2024) bca20c1 Cleanup trivial reduction workarounds (#2006) e4b6585 Trivial forwarding (#1995) 1a0e355 Fix contiguity analysis of predicates to match updated contiguity. (#1991) a4effa6 Enable output allocation cache (#2010) 35440b7 Patching bn inference (#2016) 0f9f0b4 Add matmul benchmark (#2007) 45045cd Enable tests previously disabled due to an aliasing bug (#2005) 967aa77 Contiguous indexing for View operations (#1990) a43cb20 Make inlining even more modular (#2004) dc45835 Test util cleanup (#2003) 3ca21eb More strict validation (#2000) a7a7d57 Fix build problem (#1999) fc235b0 Just fixes comments (#1998) 482386c cleanup (#1997) 4cbe0db Improve divisible split detection (#1970) 42ccc52 Minor build fix. (#1996) fcf8c09 Cleanup of lower_utils.cpp: Isolate out GpuLower usage (#1989) 15f2f6d Move ConcretizedBroadcastDomains to shared_ptr in GpuLower. (#1988) 8f1c7f5 Minor cleanup lower_unroll.cpp (#1994) 1d9858c Minor cleanup (#1992) f262d9c Add support for uniform RNG (#1986) eb1dad1 Remove non-const functions, remove GpuLower instance on build, pass in ca_map. (#1987) 634820c Add support for some empty fusion (#1981) eabe8d8 Segment self mapping fusions (#1954) e96aacf Enable Transpose operation (#1882) 425dce2 Add a null scheduler that helps segmenting away no-op schedules (#1835) 306d4a6 Fix canScheduleCompileTime check of transpose scheduler (#1969) b1bd32c Minor fix (#1967) bd93578 Enable transpose scheduler (#1927) b7a206e Move scheduler vectorize utilities into their own file (#1959) d9420e4 View scheduling (#1928) c668e13 Upstream push ci fixes (#1965) c40202b Fix dump effective bandwidth (#1962) 93505bc WAR on index mapping when exact and permissive maps differ (#1960) 45e95fd Allow splitting inner-most ID to create virtual innermost ID in transpose scheduler (#1930) a3ecb33 Improve the comments at the beginning of index_compute.h (#1946) f7bc341 Remove unused variables (#1955) df3393a Some cleanup (#1957) 7d1d7c8 TVDomainGuard factory (#1953) 357ba22 Fill allocation with nan on tests (#1956) 8eafc54 Fix detection of unmappable root domains (#1952) 90a51f2 Some indexing cleanups, Add eye support (#1940) ddc01e4 Exclude unsupported data types (#1951) 992e17c test the groups the same order as they are merged (#1949) 208262b Move detection of self mapping IDs to IterDomainGraph from (#1941) ac4de38 Merge pull request #1945 from csarofeen/master_merge_0828 6310948 Add full, full_like, zeros, zeros_like, ones, ones_like (#1943) aab10bc Merge remote-tracking branch 'upstream/viable/strict' into HEAD 4c254c0 Fix arange when step is negative (#1942) 89330aa Tensor factories must set the output shape as its input (#1939) ``` RUN_TORCHBENCH: nvfuser Differential Revision: [D40869846](https://our.internmc.facebook.com/intern/diff/D40869846) Pull Request resolved: pytorch#87779 Approved by: https://github.com/davidberard98
Changes contiguity analysis so it works through view operations. For example if we have:
Code will now be generated as:
Specifically we can see the consistent linearized indexing on T0 and T1. Though there's still work to do to make sure the predicates are also contiguous.
No significant performance changes detected with this PR which is expected:
