Update tls logic to work better with guarded call #73925


Closed
albanD wants to merge 2 commits.

Conversation

albanD (Collaborator) commented Mar 8, 2022

Description of the new behavior is in PythonFallbackKernel.cpp.
The updated test makes sure that we only call alias on the first Tensor.

pytorch-bot bot commented Mar 8, 2022

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/albanD/pytorch/blob/ac2b8f74902ea0cd8e99972b3e297f9f1a405585/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default
Add ciflow labels to this PR to trigger more builds:

Workflow | Labels (bold = enabled) | Status
Triggered Workflows
linux-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
linux-binary-libtorch-cxx11-abi ciflow/all, ciflow/binaries, ciflow/binaries_libtorch, ciflow/default, ciflow/trunk ✅ triggered
linux-binary-libtorch-pre-cxx11 ciflow/all, ciflow/binaries, ciflow/binaries_libtorch, ciflow/default, ciflow/trunk ✅ triggered
linux-binary-manywheel ciflow/all, ciflow/binaries, ciflow/binaries_wheel, ciflow/default, ciflow/trunk ✅ triggered
linux-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/trunk ✅ triggered
linux-bionic-rocm4.5-py3.7 ciflow/all, ciflow/default, ciflow/linux, ciflow/rocm, ciflow/trunk ✅ triggered
linux-docs ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux, ciflow/trunk ✅ triggered
linux-vulkan-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc5.4-mobile-lightweight-dispatch-build ciflow/all, ciflow/cpu, ciflow/default, ciflow/libtorch, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7-no-ops ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
macos-arm64-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
macos-arm64-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
macos-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
macos-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
macos-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
macos-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
windows-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
windows-binary-libtorch-debug ciflow/all, ciflow/binaries, ciflow/binaries_libtorch, ciflow/default, ciflow/trunk ✅ triggered
windows-binary-libtorch-release ciflow/all, ciflow/binaries, ciflow/binaries_libtorch, ciflow/default, ciflow/trunk ✅ triggered
windows-binary-wheel ciflow/all, ciflow/binaries, ciflow/binaries_wheel, ciflow/default, ciflow/trunk ✅ triggered
Skipped Workflows
caffe2-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
docker-builds ciflow/all, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-arm64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-arm64-custom-ops ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-arm64-metal ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-x86-64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda10.2-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow, ciflow/trunk 🚫 skipped
linux-docs-push ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled 🚫 skipped
linux-xenial-cuda11.3-py3.7-gcc7-no-ops ciflow/all, ciflow/cuda, ciflow/linux, ciflow/trunk 🚫 skipped
macos-10-15-py3-arm64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-10-15-py3-lite-interpreter-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-11-py3-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
parallelnative-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
periodic-libtorch-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck 🚫 skipped
periodic-linux-xenial-cuda11.3-py3.7-gcc7-debug ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.5-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build ciflow/all, ciflow/android, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
pytorch-xla-linux-bionic-py3.7-clang8 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk, ciflow/xla 🚫 skipped

facebook-github-bot (Contributor) commented Mar 8, 2022

🔗 Helpful links

💊 CI failures summary and remediations

As of commit 3477503 (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See GitHub Actions build linux-xenial-py3.7-gcc5.4 / test (backwards_compat, 1, 1, linux.2xlarge) (1/1)

Step: "Test"

2022-03-11T19:23:37.3226402Z Author: PyTorch Team
2022-03-11T19:23:37.3226648Z Author-email: [email protected]
2022-03-11T19:23:37.3226910Z License: BSD-3
2022-03-11T19:23:37.3227172Z Location: /opt/conda/lib/python3.7/site-packages
2022-03-11T19:23:37.3227422Z Requires: typing-extensions
2022-03-11T19:23:37.3227629Z Required-by: 
2022-03-11T19:23:37.3540424Z + python check_forward_backward_compatibility.py --existing-schemas nightly_schemas.txt
2022-03-11T19:23:37.9640010Z Traceback (most recent call last):
2022-03-11T19:23:37.9640525Z   File "check_forward_backward_compatibility.py", line 308, in <module>
2022-03-11T19:23:37.9640947Z     s = parse_schema(line.strip())
2022-03-11T19:23:37.9641135Z RuntimeError: 
2022-03-11T19:23:37.9641379Z Unknown custom class type profiler._RecordFunction. Please ensure it is registered.:
2022-03-11T19:23:37.9641956Z profiler::_record_function_exit._RecordFunction(__torch__.torch.classes.profiler._RecordFunction _0) -> ()
2022-03-11T19:23:37.9642344Z                                                                                  ~~~~~~~~~~~~~~~ <--- HERE
2022-03-11T19:23:37.9642476Z 
2022-03-11T19:23:38.0400911Z + cleanup
2022-03-11T19:23:38.0401143Z + retcode=1
2022-03-11T19:23:38.0401290Z + set +x
2022-03-11T19:23:38.0442969Z ##[error]Process completed with exit code 1.


soulitzer (Contributor) left a comment:

Pretty cool! Just one small comment on naming:

I think some of the meaning of the terms shifted during this PR, which confused me for a bit. Previously, TLSState referred to the current local dispatch key set. In this PR, TLSState seems to refer to tls_on_entry. Stash also means something slightly different now.

soulitzer (Contributor):

Also TLSState = Thread Local State State? :P

fwAD.make_dual(s, torch.rand_like(s))
self.assertEqual(counter[0], 1)
fwAD.make_dual(torch.rand_like(s), s)
self.assertEqual(counter[0], 2)
self.assertEqual(counter[0], 1)
soulitzer (Contributor) commented:

I think I realize what is happening here now: my test last time was completely oblivious to what torch_function does.

Before this PR torch_function was wrapping the output of torch.rand_like in the subclass, but now that we disable torch function for this subclass, that no longer happens, and the counter is only triggered once as expected.

albanD (Collaborator, Author) replied:

Yep

@@ -7939,13 +7939,16 @@ class MySubclass(torch.Tensor):
def __new__(cls, data=None):
return torch.Tensor._make_subclass(cls, data)

__torch_function__ = torch._C._disabled_torch_function_impl
soulitzer (Contributor) commented:

(although I think we can get rid of this after #73942)

albanD (Collaborator, Author) replied:

Yes depending on which one lands first.

@albanD albanD requested a review from ezyang March 11, 2022 19:10
albanD (Collaborator, Author) commented Mar 11, 2022

FYI @ezyang as well, since you're thinking about how to replace this system altogether.

albanD (Collaborator, Author) commented Mar 11, 2022

This is ready for a final review

soulitzer (Contributor) left a comment:

LGTM

facebook-github-bot (Contributor):

@albanD has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot pushed a commit that referenced this pull request Mar 14, 2022
Summary:
Description of the new behavior is in PythonFallbackKernel.cpp.
The updated test makes sure that we only call alias on the first Tensor.

Pull Request resolved: #73925

Reviewed By: samdow

Differential Revision: D34862940

Pulled By: albanD

fbshipit-source-id: 4d020e41c8bb8b10262dcafd524e84a5ad4d7af0
github-actions (Contributor):

Hey @albanD.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@albanD albanD added the topic: not user facing topic category label Mar 15, 2022
zou3519 (Contributor) commented Mar 15, 2022

@albanD @soulitzer this PR broke functorch: https://app.circleci.com/pipelines/github/pytorch/functorch/2032/workflows/2a0f2124-3870-4c6e-b95b-00951c1e9332/jobs/12098 (I confirmed by bisection to this PR that it is the one that causes the failure)

I don't know the exact reason, but I suspect it has to do with the following:

  1. functorch's DynamicLayerFrontMode is active inside of vmap.
  2. PythonTLS saves DynamicLayerFrontMode as a part of the "saved TLS".
  3. Somehow this goes wrong.

This doesn't seem easy to resolve. Could you help take a look please?

Minimal repro:

import torch
import functorch
from functorch import vmap, make_fx

def f(x):
    return torch.sin(x)
inp = torch.randn(5, 3)
f = vmap(f)
fx_f = make_fx(f)(inp)
new_inp = torch.randn(5, 3)

albanD (Collaborator, Author) commented Mar 15, 2022

Hmmm, not sure why this PR would change this behavior, as this was happening before?
But yes, functorch breaks some of the assumptions we make here about how the dispatcher is used, by jumping in at specific keys. So I am not very surprised :/

zou3519 (Contributor) commented Mar 15, 2022

It might not look this way because of our lack of test coverage for this, but specifically this PR seems to break the torch_dispatch interaction with functorch. The main consequence of this right now is that we cannot compose AOTAutograd with functorch eager-mode transforms (e.g. vmap, grad, etc), so this is a really large breakage for us.

It would be a huge favor to us if we could figure out how to resolve this soon.

zou3519 added a commit that referenced this pull request Mar 15, 2022
zou3519 added a commit that referenced this pull request Mar 15, 2022
This reverts commit dff0285.

ghstack-source-id: e3d459a
Pull Request resolved: #74268
zou3519 (Contributor) commented Mar 16, 2022

Here's my current understanding of what is happening.

The purpose of this PR is to save the TLS when going from Python -> into the dispatcher so that when __torch_dispatch__ is invoked, the TLS may be set to the saved one.

In the example,

f = vmap(torch.sin)
fx_f = make_fx(f)(inp)

we are:

  • creating a ProxyTensor object, then calling vmap(f)(ProxyTensor)
  • vmap adds DynamicLayerFrontMode to the dispatch local include key set.
  • DynamicLayerFrontMode gets removed from the local include set after the batching rule for sin gets called.
  • ProxyTensor.sin re-enables DynamicLayerFrontMode, which is a problem because we've already handled vmap.

So I think the question now is how should the TLS logic in this PR interact with mode-style dispatch keys?

  • Let's say we have a HypotheticalMode dispatch key that sits between PythonTLS and Python in the key ordering: [PythonTLS, HypotheticalMode, Python].
  • Furthermore, let's say that HypotheticalMode::sin (1) prints "sin" (2) temporarily removes HypotheticalMode from the local include set (in other words, it is "done"), and (3) dispatches down the line (either call or redispatch).

Imagine some Python code that looks like:

subclass = TensorSubclass(x)
add_to_tls_local_include_set(HypotheticalMode)
subclass.sin()

This would end up running through HypotheticalMode twice. Would that be the desired behavior?
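The double-run scenario above can be sketched as a pure-Python toy model (no PyTorch here; `run_op`, the string keys, and the snapshot/restore flow are all illustrative stand-ins for the dispatcher, not real APIs):

```python
# Toy model of mode-style dispatch with a TLS snapshot.
# Keys are strings; the "TLS" is just a set of enabled mode keys.

calls = []

def run_op(tls_include_set):
    """One call entering the dispatcher from top-level Python."""
    snapshot = frozenset(tls_include_set)  # PythonTLSSnapshot: save TLS on entry
    local = set(tls_include_set)

    if "HypotheticalMode" in local:
        calls.append("HypotheticalMode")   # mode handles the op...
        local.discard("HypotheticalMode")  # ...and marks itself "done"

    calls.append("__torch_dispatch__")     # Python key: subclass handler runs
    restored = set(snapshot)               # handler sees the *saved* TLS,
    if "HypotheticalMode" in restored:     # so the mode is enabled again
        calls.append("HypotheticalMode")   # and runs a second time

run_op({"HypotheticalMode"})
print(calls)  # ['HypotheticalMode', '__torch_dispatch__', 'HypotheticalMode']
```

The mode fires twice because the snapshot taken before it ran is restored for the handler's re-dispatch, which is exactly the behavior being asked about.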

albanD (Collaborator, Author) commented Mar 16, 2022

Would that be the desired behavior?

Yes. The high-level idea of this TLS logic is the following: if you call torch.foo(*args) from top-level Python or from within the __torch_dispatch__ of torch.foo for one of the args, you should get the exact same result.

So in this case, that looks like the expected behavior to me? (Not that we want to keep it necessarily, but at least it does follow the high-level idea.)
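That invariant can be sketched the same way (again a pure-Python toy model; `dispatch`, the key names, and the flow are illustrative, not PyTorch APIs): the re-entrant call made from inside the handler should take exactly the same path as a fresh top-level call, because the saved TLS is restored for it.

```python
# Toy check of the invariant: re-calling the op from inside
# __torch_dispatch__ under the restored snapshot produces the same
# dispatch trace as calling it fresh from top-level Python.

def dispatch(x, tls, snapshot=None, trace=None, reentrant=False):
    trace = [] if trace is None else trace
    if "LoggingMode" in tls:
        trace.append("LoggingMode")
        tls = tls - {"LoggingMode"}   # mode excludes itself and dispatches on
    if x == "subclass" and not reentrant:
        trace.append("__torch_dispatch__")
        # the handler re-calls the op under the restored snapshot
        dispatch("plain", snapshot, trace=trace, reentrant=True)
    else:
        trace.append("kernel")
    return trace

ambient = frozenset({"LoggingMode"})
top_level = dispatch("plain", ambient, reentrant=True)
via_handler = dispatch("subclass", ambient, snapshot=ambient)
print(top_level)    # ['LoggingMode', 'kernel']
print(via_handler)  # ['LoggingMode', '__torch_dispatch__', 'LoggingMode', 'kernel']
```

The tail of via_handler after the handler entry matches top_level exactly, which is the "same result from top-level Python or from within __torch_dispatch__" property (and also why the mode runs twice in the HypotheticalMode example).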

pytorchmergebot pushed a commit that referenced this pull request Mar 16, 2022
facebook-github-bot pushed a commit that referenced this pull request Mar 17, 2022
…74268)

Summary:
This reverts commit dff0285.

Pull Request resolved: #74268

Approved by: https://github.com/albanD

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/a9d9f91f317adca4ce935ec6de63119bda0f6732

Reviewed By: jbschlosser

Differential Revision: D34914754

Pulled By: zou3519

fbshipit-source-id: 600105afff87e02194a8b3445a4a38e9995b16c9
@suo suo mentioned this pull request Mar 22, 2022
zou3519 added a commit that referenced this pull request Mar 22, 2022
This PR relands #73925 which we
reverted due to a large breakage in functorch.

As a part of the reland, this PR adds a change we agreed upon in
https://docs.google.com/document/d/1i7Y9VZp9PxtgVcrQh6nGQXkXkPc1uMep0dM-OMOGJ9o/edit
The change is moving the PythonTLSSnapshot key after
DynamicLayerFrontMode.

Test Plan:
- I tested this with an updated version of functorch and all the tests
pass so I think we are out of the woods.

[ghstack-poisoned]
zou3519 added a commit that referenced this pull request Mar 22, 2022
Same reland message as above. ghstack-source-id: 9b5aa58
Pull Request resolved: #74577
zou3519 added a commit that referenced this pull request Mar 23, 2022
Same reland message as above. ghstack-source-id: ee34b1b
Pull Request resolved: #74577
zou3519 added a commit that referenced this pull request Mar 23, 2022
Same reland message as above. [ghstack-poisoned]
zou3519 added a commit that referenced this pull request Mar 23, 2022
Same reland message as above. [ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request Mar 25, 2022
Same reland message as above.
Pull Request resolved: #74577
Approved by: https://github.com/albanD
facebook-github-bot pushed a commit that referenced this pull request Mar 29, 2022
…74577)

Summary:
Same reland message as above (reland of #73925; moves the PythonTLSSnapshot key after DynamicLayerFrontMode).

Test Plan:
contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/a75c718d7c370becbf8351f5bba013cfd66cf983

Test plan from GitHub:
- I tested this with an updated version of functorch and all the tests
pass so I think we are out of the woods.

Pull Request resolved: #74577

Approved by: https://github.com/albanD

Reviewed By: malfet

Differential Revision: D35188072

Pulled By: zou3519

fbshipit-source-id: a613760e43c0c7b918e536ade1e6935dc88cd3ca