Skip to content

Improve divisible split detection #1970

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 27, 2022
Merged

Improve divisible split detection #1970

merged 4 commits into from
Sep 27, 2022

Conversation

csarofeen
Copy link
Owner

Adds propagation of divisible split information, as well as add divisible splits from view based transformations.

Tangibly NVFuserTest.FusionNonDivisibleSplitVectorize2_CUDA before:

__global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 0> T3) {
  alignas(16) extern __shared__ char array[];
  void* shared_mem = array;
  NVFUSER_DEFINE_MAGIC_ZERO
  Array<float, (8 * 4), 4> T1;
  #pragma unroll
  for(nvfuser_index_t i21 = 0; i21 < 8; ++i21) {
    if (((((i21 + nvfuser_zero) * (ceilDiv(T0.size[0], 8))) + ((((nvfuser_index_t)threadIdx.x) * 4) + 3)) < T0.size[0])) {
      loadGlobalToLocal<float, 4, false>(&T1[(i21 * 4)],  &T0[(((i21 + nvfuser_zero) * (ceilDiv(T0.size[0], 8))) + (((nvfuser_index_t)threadIdx.x) * 4))]);
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO
  // Alias Allocation - register
  auto& T2 = T1;
  #pragma unroll
  for(nvfuser_index_t i23 = 0; i23 < 8; ++i23) {
    #pragma unroll
    for(nvfuser_index_t i24 = 0; i24 < 4; ++i24) {
      if ((((i23 * (ceilDiv(T0.size[0], 8))) + ((((nvfuser_index_t)threadIdx.x) * 4) + (i24 + nvfuser_zero))) < T0.size[0])) {
        T2[((i23 * 4) + i24)]
          = T1[((i23 * 4) + i24)]
          + (float) 1.00000000000000000e+00;
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO
  T3[0] = 0.00000000000000000e+00;
  #pragma unroll
  for(nvfuser_index_t i25 = 0; i25 < 8; ++i25) {
    #pragma unroll
    for(nvfuser_index_t i26 = 0; i26 < 4; ++i26) {
      blockReduce<true, false, false>(
        T3[0],
        T2[((i25 * 4) + i26)],
        [](float &a, float b) { a = a + b; },
        threadIdx,
        blockDim,
        static_cast<float*>(shared_mem),
        (((i25 * (ceilDiv(T0.size[0], 8))) + ((((nvfuser_index_t)threadIdx.x) * 4) + (i26 + nvfuser_zero))) < T0.size[0]),
        float(0.00000000000000000e+00));
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO
}

After:

__global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 0> T3) {
  alignas(16) extern __shared__ char array[];
  void* shared_mem = array;
  NVFUSER_DEFINE_MAGIC_ZERO
  Array<float, (8 * 4), 4> T1;
  #pragma unroll
  for(nvfuser_index_t i21 = 0; i21 < 8; ++i21) {
    T1.set(0);
  }
  NVFUSER_UPDATE_MAGIC_ZERO
  #pragma unroll
  for(nvfuser_index_t i21 = 0; i21 < 8; ++i21) {
    if (((((i21 + nvfuser_zero) * (ceilDiv(T0.size[0], 8))) + ((((nvfuser_index_t)threadIdx.x) * 4) + 3)) < T0.size[0])) {
      loadGlobalToLocal<float, 4, false>(&T1[(i21 * 4)],  &T0[(((i21 + nvfuser_zero) * (ceilDiv(T0.size[0], 8))) + (((nvfuser_index_t)threadIdx.x) * 4))]);
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO
  // Alias Allocation - register
  auto& T2 = T1;
  #pragma unroll
  for(nvfuser_index_t i23 = 0; i23 < 8; ++i23) {
    #pragma unroll
    for(nvfuser_index_t i24 = 0; i24 < 4; ++i24) {
      T2[((i23 * 4) + i24)]
        = T1[((i23 * 4) + i24)]
        + (float) 1.00000000000000000e+00;
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO
  T3[0] = 0.00000000000000000e+00;
  #pragma unroll
  for(nvfuser_index_t i25 = 0; i25 < 8; ++i25) {
    #pragma unroll
    for(nvfuser_index_t i26 = 0; i26 < 4; ++i26) {
      blockReduce<true, false, false>(
        T3[0],
        T2[((i25 * 4) + i26)],
        [](float &a, float b) { a = a + b; },
        threadIdx,
        blockDim,
        static_cast<float*>(shared_mem),
        (((i25 * (ceilDiv(T0.size[0], 8))) + ((((nvfuser_index_t)threadIdx.x) * 4) + (i26 + nvfuser_zero))) < T0.size[0]),
        float(0.00000000000000000e+00));
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO
}

Specifically this predicate is successfully removed:

      if ((((i23 * (ceilDiv(T0.size[0], 8))) + ((((nvfuser_index_t)threadIdx.x) * 4) + (i24 + nvfuser_zero))) < T0.size[0])) {
        T2[((i23 * 4) + i24)]
          = T1[((i23 * 4) + i24)]
          + (float) 1.00000000000000000e+00;
      }

@csarofeen
Copy link
Owner Author

@naoyam or @zasdfgbnm this should be ready to review. There is one bug here, which is that rfactor domains that are a result of non-divisible splits (only happens with reduction) might not be identified as such, but indexed as such. Was thinking @samnordmann might be able to help here, since this would be required to fix if we wanted mult-gpu supported rfactor stages.

Copy link
Collaborator

@naoyam naoyam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added some comments. Nothing blocking

Fusion* fusion);

// Same as above but will use provided ComputeAtMap instead of building its own.
std::unordered_set<Split*> getAllDivisibleSplits(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: any specific reason not to have TORCH_CUDA_CU_API for this version.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really, I typically only add TORCH_CUDA_CU_API when I build a test on an interface. Can add it for symmetry.

@@ -0,0 +1,124 @@

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The file name should be lower_divisible_split.cpp.

Comment on lines 29 to 36
if (!tv->hasRFactor() ||
std::any_of(
rfactor_dom.begin(),
rfactor_dom.end(),
// Also not a view transform if there's a reduction dimension.
[](IterDomain* id) { return id->isReduction(); })) {
continue;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be simplified by using TensorDomain::hasViewLikeRFactor()?


auto all_tvs = ir_utils::allTvs(fusion);
// Find all tensor views with a view like rfactor
for (auto tv : all_tvs) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add a note saying splits used in view are by definition divisible?

continue;
}

// We could have a cass technically like:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type: case

auto vec_id_it = std::find_if(
tv->domain()->domain().begin(),
tv->domain()->domain().end(),
// Also not a view transform if there's a reduction dimension.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated comment

auto concrete_id = entry.first;
auto original_view_split = entry.second;

auto exact_mapped_ids =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: const auto&

auto transform_6 = transform_5->in()->definition()->as<Split>();
auto transform_7 = transform_6->in()->definition();

TORCH_CHECK(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not asserting transform_6 and transform_7 are also detected as divisible?

@naoyam
Copy link
Collaborator

naoyam commented Sep 27, 2022

@naoyam or @zasdfgbnm this should be ready to review. There is one bug here, which is that rfactor domains that are a result of non-divisible splits (only happens with reduction) might not be identified as such, but indexed as such. Was thinking @samnordmann might be able to help here, since this would be required to fix if we wanted mult-gpu supported rfactor stages.

Is the bug due to this PR? Or is it already in the code?

@csarofeen
Copy link
Owner Author

The bug already exists before this PR.

@csarofeen csarofeen requested a review from naoyam September 27, 2022 15:50
Copy link
Collaborator

@naoyam naoyam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@naoyam
Copy link
Collaborator

naoyam commented Sep 27, 2022

The bug already exists before this PR.

Can you please file an issue?

@csarofeen
Copy link
Owner Author

Yeah, I'm actually going to ask @samnordmann to file an issue on it and work on it.

@csarofeen csarofeen merged commit 4cbe0db into devel Sep 27, 2022
@zasdfgbnm zasdfgbnm deleted the divisible_split_prop branch September 27, 2022 18:01
jjsjann123 added a commit that referenced this pull request Nov 9, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Codegen changes include:

* codegen improvement:
    i. allow non-root trivial reductions, allow empty/no-op fusion
    ii. fixes vectorization checks and size calculation
    iii. bank conflict handle improvement
    iv. enables transpose scheduler

* misc:
    i. CI tests failure fixes
    ii. cpp tests file clean up
    iii. trivial forwarding supports added in codegen runtime
    iv. added factory methods support in codegen

Commits that's in this PR from the devel branch:

```
7117a7e patching nvfuser conv cudnn test numerics mismatch (#2048)
65af1a4 Inserting sync for redundant parallel types is already done at the (#2023)
6ac74d1 Fix sync map (#2047)
f5bca33 Bank conflict checker improvements (#2032)
d2ca7e3 Minor update on cp.async code generation. (#1901)
d36cf61 Test file cleanup (#2040)
0b8e83f Allow non-root trivial reductions (#2037)
a2dfe40 Fix vectorize size calculation (#2035)
e040676 Use withPredicate to replace setPredicate to maintain Exprs immutable (#2025)
197221b removing ci workflow (#2034)
40e2703 Reduction rand like patch (#2031)
bc77266 Add utility for checking bank conflict of shared memory (#2029)
ddd1cf7 Add back FusionReductionWithTrivialReduction_CUDA (#2030)
fbd97e5 Revert "Cleanup trivial reduction workarounds (#2006)" (#2024)
bca20c1 Cleanup trivial reduction workarounds (#2006)
e4b6585 Trivial forwarding (#1995)
1a0e355 Fix contiguity analysis of predicates to match updated contiguity. (#1991)
a4effa6 Enable output allocation cache (#2010)
35440b7 Patching bn inference (#2016)
0f9f0b4 Add matmul benchmark (#2007)
45045cd Enable tests previously disabled due to an aliasing bug (#2005)
967aa77 Contiguous indexing for View operations (#1990)
a43cb20 Make inlining even more modular (#2004)
dc45835 Test util cleanup (#2003)
3ca21eb More strict validation (#2000)
a7a7d57 Fix build problem (#1999)
fc235b0 Just fixes comments (#1998)
482386c cleanup (#1997)
4cbe0db Improve divisible split detection (#1970)
42ccc52 Minor build fix. (#1996)
fcf8c09 Cleanup of lower_utils.cpp: Isolate out GpuLower usage (#1989)
15f2f6d Move ConcretizedBroadcastDomains to shared_ptr in GpuLower. (#1988)
8f1c7f5 Minor cleanup lower_unroll.cpp (#1994)
1d9858c Minor cleanup (#1992)
f262d9c Add support for uniform RNG (#1986)
eb1dad1 Remove non-const functions, remove GpuLower instance on build, pass in ca_map. (#1987)
634820c Add support for some empty fusion (#1981)
eabe8d8 Segment self mapping fusions (#1954)
e96aacf Enable Transpose operation (#1882)
425dce2 Add a null scheduler that helps segmenting away no-op schedules (#1835)
306d4a6 Fix canScheduleCompileTime check of transpose scheduler (#1969)
b1bd32c Minor fix (#1967)
bd93578 Enable transpose scheduler (#1927)
b7a206e Move scheduler vectorize utilities into their own file (#1959)
d9420e4 View scheduling (#1928)
c668e13 Upstream push ci fixes (#1965)
c40202b Fix dump effective bandwidth (#1962)
93505bc WAR on index mapping when exact and permissive maps differ (#1960)
45e95fd Allow splitting inner-most ID to create virtual innermost ID in transpose scheduler (#1930)
a3ecb33 Improve the comments at the beginning of index_compute.h (#1946)
f7bc341 Remove unused variables (#1955)
df3393a Some cleanup (#1957)
7d1d7c8 TVDomainGuard factory (#1953)
357ba22 Fill allocation with nan on tests (#1956)
8eafc54 Fix detection of unmappable root domains (#1952)
90a51f2 Some indexing cleanups, Add eye support (#1940)
ddc01e4 Exclude unsupported data types (#1951)
992e17c test the groups the same order as they are merged (#1949)
208262b Move detection of self mapping IDs to IterDomainGraph from (#1941)
ac4de38 Merge pull request #1945 from csarofeen/master_merge_0828
6310948 Add full, full_like, zeros, zeros_like, ones, ones_like (#1943)
aab10bc Merge remote-tracking branch 'upstream/viable/strict' into HEAD
4c254c0 Fix arange when step is negative (#1942)
89330aa Tensor factories must set the output shape as its input (#1939)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D40869846](https://our.internmc.facebook.com/intern/diff/D40869846)
Pull Request resolved: pytorch#87779
Approved by: https://github.com/davidberard98
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants