
Conversation

Collaborator

@jjsjann123 jjsjann123 commented Sep 27, 2022

Enables trivial forwarding

Background:
nvfuser codegen doesn't handle aliases at all. When we have a fusion that forwards an input to an output without any operations on it, that output is a no-op for codegen and the output tensor is never written to. However, codegen cannot "forward" an input to an output, since all outputs are allocated during integration. If we do not special-case this, we'll end up with a "fresh" tensor allocated for the forwarded input.

Approach:
There are two aspects of the support:
step 1. Codegen handles forwarding implicitly. A forwarded output doesn't have any producer in the IR, hence the output argument is never used in the generated code. But the kernel still requires an argument as a placeholder so that each argument maps correctly.
step 2. Integration handles the trivial forwarding of inputs. When we put together fusion_outputs for a given fusion and an output is just a fusion input, we directly return the input tensor.

@jjsjann123
Collaborator Author

WIP. There are still failing cpp tests with IMA. 😮‍💨

Owner

@csarofeen csarofeen left a comment

I'm missing some context. Why is there special handling here with a placeholder? Can you give me some more information about what's going on and what you're trying to do?

for (auto inp_i : c10::irange(kernel->inputs().size())) {
  if (kernel->inputs()[inp_i] == kernel->outputs()[out_i]) {
    TORCH_INTERNAL_ASSERT(
        inp_i < args.size(),
Owner

What's the assumption here? It's hard to understand what this is checking. Could you add a comment on why args.size() would be less than kernel->inputs().size()?

Collaborator Author

I think at this point their sizes should match. So we should instead do an assert before the loop on
args.size() == kernel->inputs().size()

c10::Device device(c10::DeviceType::CUDA, args.getDeviceIndex());
const auto tensor_options =
    at::TensorOptions().dtype(at::kFloat).device(device);
outputs.emplace_back(at::empty({0}, tensor_options));
Owner

Why do you need a placeholder here?

Collaborator Author

We want a placeholder so that we don't allocate anything for trivially forwarded inputs. Basically, codegen ignores every trivially forwarded output, because there is no producer for it, and we pass an empty tensor at the I/O boundary just so we match the kernel signature.

Later, when we put together the outputs of the fusion, forwarded inputs are copied directly to the outputs (metadata only, no data movement).

There are some funny bits here in the interaction between alias and trivially-forwarded outputs. I don't think we have a fool-proof API for the two at the moment; we can clean that up afterwards. I'm enabling the support here since Ivan needed it for his copy_to PR: pytorch#84545

Collaborator Author

Thanks for the early review. I'll put some documentation in the code when I get it cleaned up.

@jjsjann123
Collaborator Author

WIP. There are still failing cpp tests with IMA. 😮‍💨

This PR by itself is clean; the IMA comes from enabling cache hits for segmented fusions, which I'll track in a separate PR. I have commented out that code and will have this PR cleaned up.

@jjsjann123 jjsjann123 marked this pull request as ready for review September 27, 2022 22:48
@jjsjann123
Collaborator Author

Put some documentation in there; should be good for review now.

Comment on lines 723 to 725
if (kernel->outputs()[out_i]->isFusionInput()) {
  TORCH_INTERNAL_ASSERT(false, "trivial input forwarding NOT IMPLEMENTED");
  // for (auto inp_i : c10::irange(kernel->inputs().size())) {
  //   if (kernel->inputs()[inp_i] == kernel->outputs()[out_i]) {
  //     TORCH_INTERNAL_ASSERT(
  //         inp_i < inputs.size(),
  //         "Issue with an input showing up as output, couldn't find input.");
  //     TORCH_INTERNAL_ASSERT(
  //         inputs[inp_i].isTensor(),
  //         "Cannot register a scalar as an output in a fusion.");
  //     outputs.push_back(inputs[inp_i].toTensor());
  //     break;
  //   }
  // }
  for (auto inp_i : c10::irange(kernel->inputs().size())) {
    if (kernel->inputs()[inp_i] == kernel->outputs()[out_i]) {
Collaborator

Why do we need this for loop here? Could we simplify the code into:

TORCH_INTERNAL_ASSERT(args.size() == kernel->inputs().size(), ...);
if (kernel->outputs()[out_i]->isFusionInput()) {
  tensor_options = ...
  outputs.emplace_back(at::empty({0}, tensor_options));
}

Doing so, we won't be checking that an output is not a scalar, but isn't this already checked somewhere else? For example, does Fusion::addOutput check it?

at::Tensor t0 = at::randn({10, 4}, options);
at::Tensor t1 = at::randn({10, 4}, options);

// Note
Collaborator

What note?

Collaborator Author

Oops.

@jjsjann123
Collaborator Author

CI looks green. Tests also passed locally, except for the BN python test that has been failing for a while.

Any remaining issues @csarofeen ?

jjsjann123 added a commit that referenced this pull request Sep 30, 2022
Fixes BN inference. I'm stealing Ivan's changes from pytorch#85562

aten returns mini-batch stats during inference runs; this is not the right behavior, and we should have changed that instead. But for the time being, let's change nvfuser's behavior just to get CI green.

Also, the extra set here to avoid trivial forwarding should be removed once #1995 is merged.
Owner

@csarofeen csarofeen left a comment

Thanks for the cleanup, LGTM now.

@jjsjann123
Collaborator Author

Tests passed. Merging this one.

@jjsjann123 jjsjann123 merged commit e4b6585 into devel Sep 30, 2022
@jjsjann123 jjsjann123 deleted the trivial_forwarding branch September 30, 2022 18:45
jjsjann123 added a commit that referenced this pull request Nov 9, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Codegen changes include:

* codegen improvement:
    i. allow non-root trivial reductions, allow empty/no-op fusion
    ii. fixes vectorization checks and size calculation
    iii. bank conflict handle improvement
    iv. enables transpose scheduler

* misc:
    i. CI tests failure fixes
    ii. cpp tests file clean up
    iii. trivial forwarding supports added in codegen runtime
    iv. added factory methods support in codegen

Commits in this PR from the devel branch:

```
7117a7e patching nvfuser conv cudnn test numerics mismatch (#2048)
65af1a4 Inserting sync for redundant parallel types is already done at the (#2023)
6ac74d1 Fix sync map (#2047)
f5bca33 Bank conflict checker improvements (#2032)
d2ca7e3 Minor update on cp.async code generation. (#1901)
d36cf61 Test file cleanup (#2040)
0b8e83f Allow non-root trivial reductions (#2037)
a2dfe40 Fix vectorize size calculation (#2035)
e040676 Use withPredicate to replace setPredicate to maintain Exprs immutable (#2025)
197221b removing ci workflow (#2034)
40e2703 Reduction rand like patch (#2031)
bc77266 Add utility for checking bank conflict of shared memory (#2029)
ddd1cf7 Add back FusionReductionWithTrivialReduction_CUDA (#2030)
fbd97e5 Revert "Cleanup trivial reduction workarounds (#2006)" (#2024)
bca20c1 Cleanup trivial reduction workarounds (#2006)
e4b6585 Trivial forwarding (#1995)
1a0e355 Fix contiguity analysis of predicates to match updated contiguity. (#1991)
a4effa6 Enable output allocation cache (#2010)
35440b7 Patching bn inference (#2016)
0f9f0b4 Add matmul benchmark (#2007)
45045cd Enable tests previously disabled due to an aliasing bug (#2005)
967aa77 Contiguous indexing for View operations (#1990)
a43cb20 Make inlining even more modular (#2004)
dc45835 Test util cleanup (#2003)
3ca21eb More strict validation (#2000)
a7a7d57 Fix build problem (#1999)
fc235b0 Just fixes comments (#1998)
482386c cleanup (#1997)
4cbe0db Improve divisible split detection (#1970)
42ccc52 Minor build fix. (#1996)
fcf8c09 Cleanup of lower_utils.cpp: Isolate out GpuLower usage (#1989)
15f2f6d Move ConcretizedBroadcastDomains to shared_ptr in GpuLower. (#1988)
8f1c7f5 Minor cleanup lower_unroll.cpp (#1994)
1d9858c Minor cleanup (#1992)
f262d9c Add support for uniform RNG (#1986)
eb1dad1 Remove non-const functions, remove GpuLower instance on build, pass in ca_map. (#1987)
634820c Add support for some empty fusion (#1981)
eabe8d8 Segment self mapping fusions (#1954)
e96aacf Enable Transpose operation (#1882)
425dce2 Add a null scheduler that helps segmenting away no-op schedules (#1835)
306d4a6 Fix canScheduleCompileTime check of transpose scheduler (#1969)
b1bd32c Minor fix (#1967)
bd93578 Enable transpose scheduler (#1927)
b7a206e Move scheduler vectorize utilities into their own file (#1959)
d9420e4 View scheduling (#1928)
c668e13 Upstream push ci fixes (#1965)
c40202b Fix dump effective bandwidth (#1962)
93505bc WAR on index mapping when exact and permissive maps differ (#1960)
45e95fd Allow splitting inner-most ID to create virtual innermost ID in transpose scheduler (#1930)
a3ecb33 Improve the comments at the beginning of index_compute.h (#1946)
f7bc341 Remove unused variables (#1955)
df3393a Some cleanup (#1957)
7d1d7c8 TVDomainGuard factory (#1953)
357ba22 Fill allocation with nan on tests (#1956)
8eafc54 Fix detection of unmappable root domains (#1952)
90a51f2 Some indexing cleanups, Add eye support (#1940)
ddc01e4 Exclude unsupported data types (#1951)
992e17c test the groups the same order as they are merged (#1949)
208262b Move detection of self mapping IDs to IterDomainGraph from (#1941)
ac4de38 Merge pull request #1945 from csarofeen/master_merge_0828
6310948 Add full, full_like, zeros, zeros_like, ones, ones_like (#1943)
aab10bc Merge remote-tracking branch 'upstream/viable/strict' into HEAD
4c254c0 Fix arange when step is negative (#1942)
89330aa Tensor factories must set the output shape as its input (#1939)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D40869846](https://our.internmc.facebook.com/intern/diff/D40869846)
Pull Request resolved: pytorch#87779
Approved by: https://github.com/davidberard98