Skip to content

Fix contiguity analysis of predicates to match updated contiguity. #1991

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Sep 30, 2022

Conversation

csarofeen
Copy link
Owner

@csarofeen csarofeen commented Sep 17, 2022

This PR depends on #1990

Extends contiguity method introduced there to predicate analysis. A few questions marked in comments but should be safe as is so the questions could be pushed to a follow up.

  int w = 15, x = 31;

  auto tv0 = makeContigTensor(4);
  fusion.addInput(tv0);
  auto tv1 = sin(tv0);
  auto tv2 = view(tv1, {w, x}, {x, w});

  fusion.addOutput(tv2);

  tv2->merge(0)->split(0, 4)->split(0, 8, false);

  TransformPropagator propagator(tv2);
  MaxRootDomainInfoSpanningTree spanning_tree(tv2);
  spanning_tree.traverse(&propagator);

Will now generate:

__global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T2) {
  NVFUSER_DEFINE_MAGIC_ZERO
  float T1[((8 * (ceilDiv((ceilDiv((31 * (ceilDiv((T0.size[0] * T0.size[1]), 31))), 4)), 8))) * 4)];
  #pragma unroll
  for(nvfuser_index_t i28 = 0; i28 < 8; ++i28) {
    #pragma unroll 1
    for(nvfuser_index_t i29 = 0; i29 < (ceilDiv((ceilDiv((31 * (ceilDiv((T0.size[0] * T0.size[1]), 31))), 4)), 8)); ++i29) {
      #pragma unroll
      for(nvfuser_index_t i30 = 0; i30 < 4; ++i30) {
        int64_t i66;
        i66 = (((i28 * (ceilDiv((ceilDiv((31 * (ceilDiv((T0.size[0] * T0.size[1]), 31))), 4)), 8))) + i29) * 4) + (i30 + nvfuser_zero);
        if ((i66 < (31 * (ceilDiv((T0.size[0] * T0.size[1]), 31))))) {
          T1[((((i28 * (ceilDiv((ceilDiv((31 * (ceilDiv((T0.size[0] * T0.size[1]), 31))), 4)), 8))) + i29) * 4) + i30)]
             = sinf(T0[i66]);
        }
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO
  #pragma unroll
  for(nvfuser_index_t i31 = 0; i31 < 8; ++i31) {
    #pragma unroll 1
    for(nvfuser_index_t i32 = 0; i32 < (ceilDiv((ceilDiv((31 * (ceilDiv((T0.size[0] * T0.size[1]), 31))), 4)), 8)); ++i32) {
      #pragma unroll
      for(nvfuser_index_t i33 = 0; i33 < 4; ++i33) {
        int64_t i151;
        i151 = (((i31 * (ceilDiv((ceilDiv((31 * (ceilDiv((T0.size[0] * T0.size[1]), 31))), 4)), 8))) + i32) * 4) + (i33 + nvfuser_zero);
        if ((i151 < (31 * (ceilDiv((T0.size[0] * T0.size[1]), 31))))) {
          T2[i151]
             = T1[((((i31 * (ceilDiv((ceilDiv((31 * (ceilDiv((T0.size[0] * T0.size[1]), 31))), 4)), 8))) + i32) * 4) + i33)];
        }
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO
}

No significant perf change detected with this PR which is expected
image

@csarofeen csarofeen requested a review from naoyam September 17, 2022 22:24
@naoyam
Copy link
Collaborator

naoyam commented Sep 28, 2022

This PR should close #2001

Base automatically changed from contig_merge_split to devel September 28, 2022 22:56
@csarofeen
Copy link
Owner Author

@naoyam this should be ready for review.

Copy link
Collaborator

@naoyam naoyam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comments so far

Comment on lines 535 to 537
// stride. If we're computing predicates, then we don't want to do this,
// as when we mark something as non-contiguous we don't want it merged
// with any other domains.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't follow this. First of all, are we always assuming tensors are fully contiguous when generating predicates? That's what I remember, and if it's still the case, this part of the code should never be reachable for predicates.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just tried to split up the conditional into two stages. The predicate path gets caught below in !ignore_consistent_ordering_ sorry for the confusion. Predicate will have ignore_consistent_ordering_ == true so it won't exit/return non-contiguous based on this check.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment is confusing, let me take another attempt at it.

for (auto root_i : c10::irange(predicate_contiguity.size())) {
auto root_id = consumer_root_domain[root_i];
if (root_id->maybePartial()) {
// predicate_contiguity[root_i] = false;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this safe to mark it as contiguous? Probably no unit test exercises this case, but as it was commented, predicates should be done at root domains in this case.

Copy link
Collaborator

@naoyam naoyam Sep 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed and added a unit test (FusionContigPredicateShift). The test should fail without the fix.

gather_expr->windowShape().at(root_i) != 1)) {
// TODO: The following line commented out didn't have any failures, is it
// needed?
predicate_contiguity[root_i] = false;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is necessary. It probably just doesn't have tests. Let me add tests.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed and added a unit test (FusionContigPredicateShift). The test should fail without the fix.

Comment on lines +535 to +537
// If we're computing predicates (ignore_consistent_ordering_==true),
// then we don't have this same constraint, we can just ignore
// contiguity of the roots all together.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. Thanks for clarification.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, this isn't correct. Even for predicates, some root domains are marked as non-contiguous when they need to be predicated at root domains.

Some cleanups would be needed here. I'll work on it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A problem with the design is root_contiguity is used both as indicating contiguity AND if domains are not necessary to be predicated at root. These are two separate concepts, so I split out the latter as final_ids.

Comment on lines 539 to 540
// TODO: This didn't error when I removed "!ignore_consistent_ordering_"
// is it really needed?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you still want to keep this comment? It should actually not be necessary. If this is used for indexing, is_indexing_pass is true, so nothing should change. If used for predicates, !root_contiguity_ should be always true, so is_indexing_pass doesn't matter.

I actually like the current code as it's easier to understand the underlying logic.

Copy link
Collaborator

@naoyam naoyam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stamping for now

@naoyam
Copy link
Collaborator

naoyam commented Sep 30, 2022

@csarofeen I pushed a few commits. Please let me know if they make sense. I already approved the PR so that this can be merged if they look good.

@csarofeen
Copy link
Owner Author

LGTM Thanks!

@csarofeen csarofeen merged commit 1a0e355 into devel Sep 30, 2022
jjsjann123 added a commit that referenced this pull request Nov 9, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Codegen changes include:

* codegen improvement:
    i. allow non-root trivial reductions, allow empty/no-op fusion
    ii. fixes vectorization checks and size calculation
    iii. bank conflict handle improvement
    iv. enables transpose scheduler

* misc:
    i. CI tests failure fixes
    ii. cpp tests file clean up
    iii. trivial forwarding supports added in codegen runtime
    iv. added factory methods support in codegen

Commits that's in this PR from the devel branch:

```
7117a7e patching nvfuser conv cudnn test numerics mismatch (#2048)
65af1a4 Inserting sync for redundant parallel types is already done at the (#2023)
6ac74d1 Fix sync map (#2047)
f5bca33 Bank conflict checker improvements (#2032)
d2ca7e3 Minor update on cp.async code generation. (#1901)
d36cf61 Test file cleanup (#2040)
0b8e83f Allow non-root trivial reductions (#2037)
a2dfe40 Fix vectorize size calculation (#2035)
e040676 Use withPredicate to replace setPredicate to maintain Exprs immutable (#2025)
197221b removing ci workflow (#2034)
40e2703 Reduction rand like patch (#2031)
bc77266 Add utility for checking bank conflict of shared memory (#2029)
ddd1cf7 Add back FusionReductionWithTrivialReduction_CUDA (#2030)
fbd97e5 Revert "Cleanup trivial reduction workarounds (#2006)" (#2024)
bca20c1 Cleanup trivial reduction workarounds (#2006)
e4b6585 Trivial forwarding (#1995)
1a0e355 Fix contiguity analysis of predicates to match updated contiguity. (#1991)
a4effa6 Enable output allocation cache (#2010)
35440b7 Patching bn inference (#2016)
0f9f0b4 Add matmul benchmark (#2007)
45045cd Enable tests previously disabled due to an aliasing bug (#2005)
967aa77 Contiguous indexing for View operations (#1990)
a43cb20 Make inlining even more modular (#2004)
dc45835 Test util cleanup (#2003)
3ca21eb More strict validation (#2000)
a7a7d57 Fix build problem (#1999)
fc235b0 Just fixes comments (#1998)
482386c cleanup (#1997)
4cbe0db Improve divisible split detection (#1970)
42ccc52 Minor build fix. (#1996)
fcf8c09 Cleanup of lower_utils.cpp: Isolate out GpuLower usage (#1989)
15f2f6d Move ConcretizedBroadcastDomains to shared_ptr in GpuLower. (#1988)
8f1c7f5 Minor cleanup lower_unroll.cpp (#1994)
1d9858c Minor cleanup (#1992)
f262d9c Add support for uniform RNG (#1986)
eb1dad1 Remove non-const functions, remove GpuLower instance on build, pass in ca_map. (#1987)
634820c Add support for some empty fusion (#1981)
eabe8d8 Segment self mapping fusions (#1954)
e96aacf Enable Transpose operation (#1882)
425dce2 Add a null scheduler that helps segmenting away no-op schedules (#1835)
306d4a6 Fix canScheduleCompileTime check of transpose scheduler (#1969)
b1bd32c Minor fix (#1967)
bd93578 Enable transpose scheduler (#1927)
b7a206e Move scheduler vectorize utilities into their own file (#1959)
d9420e4 View scheduling (#1928)
c668e13 Upstream push ci fixes (#1965)
c40202b Fix dump effective bandwidth (#1962)
93505bc WAR on index mapping when exact and permissive maps differ (#1960)
45e95fd Allow splitting inner-most ID to create virtual innermost ID in transpose scheduler (#1930)
a3ecb33 Improve the comments at the beginning of index_compute.h (#1946)
f7bc341 Remove unused variables (#1955)
df3393a Some cleanup (#1957)
7d1d7c8 TVDomainGuard factory (#1953)
357ba22 Fill allocation with nan on tests (#1956)
8eafc54 Fix detection of unmappable root domains (#1952)
90a51f2 Some indexing cleanups, Add eye support (#1940)
ddc01e4 Exclude unsupported data types (#1951)
992e17c test the groups the same order as they are merged (#1949)
208262b Move detection of self mapping IDs to IterDomainGraph from (#1941)
ac4de38 Merge pull request #1945 from csarofeen/master_merge_0828
6310948 Add full, full_like, zeros, zeros_like, ones, ones_like (#1943)
aab10bc Merge remote-tracking branch 'upstream/viable/strict' into HEAD
4c254c0 Fix arange when step is negative (#1942)
89330aa Tensor factories must set the output shape as its input (#1939)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D40869846](https://our.internmc.facebook.com/intern/diff/D40869846)
Pull Request resolved: pytorch#87779
Approved by: https://github.com/davidberard98
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants