Add utility for checking bank conflict of shared memory #2029

Merged
merged 5 commits into devel from bank-conflict-checker
Oct 4, 2022

Conversation

@zasdfgbnm zasdfgbnm commented Oct 4, 2022

Inspired by #1900, but the approach here is different: it does the checking at compile time and can be used in unit tests. Sample debug dump:

======= Bank confliction =======
Expr: T11_s[( ( ( ( ( i99 * 128 ) + threadIdx.x ) * 2 ) / 32 ) * 32 ), ( ( ( ( i99 * 128 ) + threadIdx.x ) * 2 ) % 32 )] view( T4 )
   = T10_g[( ( ( ( ( blockIdx.x % ( ceilDiv(( 16 * T0.size[0] ), 32) ) ) * 32 ) + ( ( ( ( ( i99 + nvfuser_zero ) * 128 ) + threadIdx.x ) * 2 ) / 32 ) ) % T0.size[0] ) * ( T0.size[2] * T0.size[1] ) ), ( ( ( ( ( ( ( blockIdx.x / ( ceilDiv(( 16 * T0.size[0] ), 32) ) ) * 32 ) + ( ( ( ( ( i99 + nvfuser_zero ) * 128 ) + threadIdx.x ) * 2 ) % 32 ) ) / T0.size[2] ) * 16 ) + ( ( ( ( blockIdx.x % ( ceilDiv(( 16 * T0.size[0] ), 32) ) ) * 32 ) + ( ( ( ( ( i99 + nvfuser_zero ) * 128 ) + threadIdx.x ) * 2 ) / 32 ) ) / T0.size[0] ) ) * T0.size[2] ), ( ( ( ( blockIdx.x / ( ceilDiv(( 16 * T0.size[0] ), 32) ) ) * 32 ) + ( ( ( ( ( i99 + nvfuser_zero ) * 128 ) + threadIdx.x ) * 2 ) % 32 ) ) % T0.size[2] )] view( T0 );

output conflict: 2 way
Expr: T1_s[( ( ( ( ( i99 * 128 ) + threadIdx.x ) * 2 ) / 32 ) * 32 ), ( ( ( ( i99 * 128 ) + threadIdx.x ) * 2 ) % 32 )] view( T4 )
   = T0_g[( ( ( ( ( blockIdx.x % ( ceilDiv(( 16 * T0.size[0] ), 32) ) ) * 32 ) + ( ( ( ( ( i99 + nvfuser_zero ) * 128 ) + threadIdx.x ) * 2 ) / 32 ) ) % T0.size[0] ) * ( T0.size[2] * T0.size[1] ) ), ( ( ( ( ( ( ( blockIdx.x / ( ceilDiv(( 16 * T0.size[0] ), 32) ) ) * 32 ) + ( ( ( ( ( i99 + nvfuser_zero ) * 128 ) + threadIdx.x ) * 2 ) % 32 ) ) / T0.size[2] ) * 16 ) + ( ( ( ( blockIdx.x % ( ceilDiv(( 16 * T0.size[0] ), 32) ) ) * 32 ) + ( ( ( ( ( i99 + nvfuser_zero ) * 128 ) + threadIdx.x ) * 2 ) / 32 ) ) / T0.size[0] ) ) * T0.size[2] ), ( ( ( ( blockIdx.x / ( ceilDiv(( 16 * T0.size[0] ), 32) ) ) * 32 ) + ( ( ( ( ( i99 + nvfuser_zero ) * 128 ) + threadIdx.x ) * 2 ) % 32 ) ) % T0.size[2] )] view( T0 );

output conflict: 2 way
================================
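For context, an "N way" conflict in the dump above means N lanes of one warp access distinct 32-bit words that map to the same shared memory bank (accesses to the exact same word are broadcast and do not conflict). A minimal sketch of this counting logic, with hypothetical names and not the PR's actual implementation:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <map>
#include <set>
#include <vector>

// Hypothetical sketch (not the PR's implementation): count the worst-case
// N-way conflict for one warp, given each lane's byte address into shared
// memory. NVIDIA GPUs have 32 banks, each 4 bytes wide; lanes conflict when
// they touch the same bank at different 32-bit words, while accesses to the
// exact same word are broadcast and conflict-free.
int conflictWays(const std::vector<int64_t>& lane_addresses) {
  std::map<int, std::set<int64_t>> words_per_bank;
  for (int64_t addr : lane_addresses) {
    int bank = static_cast<int>((addr / 4) % 32);
    words_per_bank[bank].insert(addr / 4 / 32); // which "row" of that bank
  }
  int ways = 1;
  for (const auto& [bank, words] : words_per_bank) {
    ways = std::max(ways, static_cast<int>(words.size()));
  }
  return ways;
}
```

For example, 32 lanes each loading 8 bytes at address lane*8 hit every even bank in two different rows, giving a 2-way conflict like the one reported in the dump; a stride-4 access (lane*4) touches each bank exactly once and is conflict-free.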

@zasdfgbnm zasdfgbnm requested review from csarofeen and naoyam October 4, 2022 09:25
if (conflict_ways.first > 1 || conflict_ways.second > 1) {
bank_conflict_info_[expr] = conflict_ways;
}
} else if (expr->isA<LoadStoreOp>()) {
Collaborator Author

Not sure if ld.matrix is supported correctly here. I will need to dig deeper to understand how it works and will update this in a later PR.

Comment on lines +63 to +68
if (fl->index()->isA<NamedScalar>() &&
fl->index()->as<NamedScalar>()->name() == "threadIdx.x") {
expr_eval.bind(fl->index(), tidx);
} else {
expr_eval.bind(fl->index(), 0);
}
Collaborator

Does this assume threadIdx.y/z don't matter for bank conflicts? What if we have, e.g., blockDim.x == 1 && blockDim.y == 32?

Collaborator Author

Yes, I am currently making this assumption, but I can add a new overload that takes launch parameters, which would lift it.
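A launch-parameter-aware overload could enumerate the lanes of a warp from the linearized thread index tid = tidx + tidy*bdimx + tidz*bdimx*bdimy, so that lanes of one warp may legitimately differ in threadIdx.y/z when blockDim.x < 32. A minimal sketch under assumed names (LaneCoords and laneToCoords are hypothetical, not the PR's API):

```cpp
#include <cassert>

// Hypothetical sketch: map a warp lane id back to (threadIdx.x, .y, .z)
// given the block dimensions. With blockDim.x == 1 && blockDim.y == 32,
// a whole warp runs along threadIdx.y, so binding only threadIdx.x would
// miss its bank conflicts.
struct LaneCoords {
  int tidx, tidy, tidz;
};

LaneCoords laneToCoords(int lane, int bdimx, int bdimy) {
  return {lane % bdimx, (lane / bdimx) % bdimy, lane / (bdimx * bdimy)};
}
```

With this, the checker could bind each of threadIdx.x/y/z per lane instead of assuming the warp varies only along threadIdx.x.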

Collaborator

I'm fine with the current state as long as the assumptions are made clear. I am not sure how important it is to make this more flexible at this point.

Collaborator Author

I am also OK with this assumption in this PR. I will write a few follow-up PRs to lift some of these assumptions.

naoyam commented Oct 4, 2022

Compile time is nice. Are there any limitations, though? Can all bank conflicts be found at compile time? This doesn't need to be perfect, but it'd be important to understand the limitations, if any. I left a related inline comment.

zasdfgbnm commented Oct 4, 2022

Compile time is nice. Are there any limitations, though? Can all bank conflicts be found at compile time?

Good question!

On the one hand, among our currently supported cases, I cannot think of any that cannot be checked at compile time. In the future, if we begin to support crazy things with data-dependent indexing (such as T1[T0[i0] + 1]), this could become an issue.

But on the other hand, the bank conflict checking utility in this PR makes several assumptions:

  1. blockDim.x is large enough to hold one phase.
  2. The address depends only on loop variables (it cannot contain things like T0.stride[0] or blockDim.x).
  3. The tensor's data is accessed as T0[index], where index is the one stored in the TensorIndex object; I am not sure whether any case deviates from this access pattern.
  4. Only the first iteration is checked, and all loop variables are assumed to start at 0 (for something like T1_s[tidx, 5], different iterations would give different results, which this utility cannot handle now).
  5. All shared memory tensors are allocated starting at a multiple of 4*32 bytes.
  6. The only source of bank conflicts is within a single tensor; there are no bank conflicts between different tensors. This would fail if we had many tensors like T1_s[tidx, 2] ca_pos(1). (Edit: just checked, T1_s[tidx, 2] ca_pos(1) is handled correctly.)

Besides that, this utility is only tested on a limited set of cases (the transpose examples I added in the unit tests). If a case breaks the above assumptions, the utility may silently produce a wrong result. So we definitely need to improve it over time and add tests for more cases, and I would not consider this tool a single source of truth.
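Assumption 5 is what makes per-tensor analysis sound: the bank of an element depends on its absolute byte address, but if every base address is a multiple of 4*32 = 128 bytes, the bank depends only on the offset within the tensor. A hypothetical illustration (helper names are mine, not the PR's):

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical illustration of assumption 5: with 32 banks of 4 bytes each,
// one full "row" of banks spans 128 bytes. If a tensor's base address is a
// multiple of 128, then bankOf(base, offset) == bankOf(0, offset), so the
// checker can reason about each tensor from its offsets alone.
constexpr int kNumBanks = 32;
constexpr int kBankWidthBytes = 4;

int bankOf(int64_t base_addr, int64_t byte_offset) {
  return static_cast<int>(
      ((base_addr + byte_offset) / kBankWidthBytes) % kNumBanks);
}
```

If a base address were not 128-byte aligned, the same offsets would land in shifted banks, and conflicts could not be computed without the absolute allocation address.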

zasdfgbnm commented Oct 4, 2022

For the above assumptions: 1 and 2 can easily be lifted by adding a new interface that lets the user feed in more information, and 3 should improve over time as we add more tests and do more code review. For 4, it would be easy to lift the "start from 0" assumption, but I don't see any value in computing the bank conflict for all iterations (I am not developing a tool to replace Nsight Compute). For 6, does our system even support this use case? I cannot imagine how the allocation would be done. (Edit: checked, this is supported well.)

@naoyam naoyam left a comment

Approving, as I think this is good enough even though it makes several assumptions. Please make sure the assumptions are documented as comments.

@zasdfgbnm zasdfgbnm merged commit bc77266 into devel Oct 4, 2022
@zasdfgbnm zasdfgbnm deleted the bank-conflict-checker branch October 4, 2022 19:50
jjsjann123 added a commit that referenced this pull request Nov 9, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Codegen changes include:

* codegen improvement:
    i. allow non-root trivial reductions, allow empty/no-op fusion
    ii. fixes vectorization checks and size calculation
    iii. bank conflict handle improvement
    iv. enables transpose scheduler

* misc:
    i. CI tests failure fixes
    ii. cpp tests file clean up
    iii. trivial forwarding supports added in codegen runtime
    iv. added factory methods support in codegen

Commits in this PR from the devel branch:

```
7117a7e patching nvfuser conv cudnn test numerics mismatch (#2048)
65af1a4 Inserting sync for redundant parallel types is already done at the (#2023)
6ac74d1 Fix sync map (#2047)
f5bca33 Bank conflict checker improvements (#2032)
d2ca7e3 Minor update on cp.async code generation. (#1901)
d36cf61 Test file cleanup (#2040)
0b8e83f Allow non-root trivial reductions (#2037)
a2dfe40 Fix vectorize size calculation (#2035)
e040676 Use withPredicate to replace setPredicate to maintain Exprs immutable (#2025)
197221b removing ci workflow (#2034)
40e2703 Reduction rand like patch (#2031)
bc77266 Add utility for checking bank conflict of shared memory (#2029)
ddd1cf7 Add back FusionReductionWithTrivialReduction_CUDA (#2030)
fbd97e5 Revert "Cleanup trivial reduction workarounds (#2006)" (#2024)
bca20c1 Cleanup trivial reduction workarounds (#2006)
e4b6585 Trivial forwarding (#1995)
1a0e355 Fix contiguity analysis of predicates to match updated contiguity. (#1991)
a4effa6 Enable output allocation cache (#2010)
35440b7 Patching bn inference (#2016)
0f9f0b4 Add matmul benchmark (#2007)
45045cd Enable tests previously disabled due to an aliasing bug (#2005)
967aa77 Contiguous indexing for View operations (#1990)
a43cb20 Make inlining even more modular (#2004)
dc45835 Test util cleanup (#2003)
3ca21eb More strict validation (#2000)
a7a7d57 Fix build problem (#1999)
fc235b0 Just fixes comments (#1998)
482386c cleanup (#1997)
4cbe0db Improve divisible split detection (#1970)
42ccc52 Minor build fix. (#1996)
fcf8c09 Cleanup of lower_utils.cpp: Isolate out GpuLower usage (#1989)
15f2f6d Move ConcretizedBroadcastDomains to shared_ptr in GpuLower. (#1988)
8f1c7f5 Minor cleanup lower_unroll.cpp (#1994)
1d9858c Minor cleanup (#1992)
f262d9c Add support for uniform RNG (#1986)
eb1dad1 Remove non-const functions, remove GpuLower instance on build, pass in ca_map. (#1987)
634820c Add support for some empty fusion (#1981)
eabe8d8 Segment self mapping fusions (#1954)
e96aacf Enable Transpose operation (#1882)
425dce2 Add a null scheduler that helps segmenting away no-op schedules (#1835)
306d4a6 Fix canScheduleCompileTime check of transpose scheduler (#1969)
b1bd32c Minor fix (#1967)
bd93578 Enable transpose scheduler (#1927)
b7a206e Move scheduler vectorize utilities into their own file (#1959)
d9420e4 View scheduling (#1928)
c668e13 Upstream push ci fixes (#1965)
c40202b Fix dump effective bandwidth (#1962)
93505bc WAR on index mapping when exact and permissive maps differ (#1960)
45e95fd Allow splitting inner-most ID to create virtual innermost ID in transpose scheduler (#1930)
a3ecb33 Improve the comments at the beginning of index_compute.h (#1946)
f7bc341 Remove unused variables (#1955)
df3393a Some cleanup (#1957)
7d1d7c8 TVDomainGuard factory (#1953)
357ba22 Fill allocation with nan on tests (#1956)
8eafc54 Fix detection of unmappable root domains (#1952)
90a51f2 Some indexing cleanups, Add eye support (#1940)
ddc01e4 Exclude unsupported data types (#1951)
992e17c test the groups the same order as they are merged (#1949)
208262b Move detection of self mapping IDs to IterDomainGraph from (#1941)
ac4de38 Merge pull request #1945 from csarofeen/master_merge_0828
6310948 Add full, full_like, zeros, zeros_like, ones, ones_like (#1943)
aab10bc Merge remote-tracking branch 'upstream/viable/strict' into HEAD
4c254c0 Fix arange when step is negative (#1942)
89330aa Tensor factories must set the output shape as its input (#1939)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D40869846](https://our.internmc.facebook.com/intern/diff/D40869846)
Pull Request resolved: pytorch#87779
Approved by: https://github.com/davidberard98