
Commit abc54f9

zou3519 authored and pytorchmergebot committed
Revert "Revert "[functorch] Refactor life handle storage (pytorch#90317)"" (pytorch#90856)
Adds the fix for -Wsign-compare. See original PR (pytorch#90317) for commit message.

Pull Request resolved: pytorch#90856
Approved by: https://github.com/samdow
1 parent 81f351a commit abc54f9
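
For context on what "refactor life handle storage" means here: before this change, functorch kept a global map from transform level to a life handle (a std::shared_ptr<bool> recording whether that transform is still live); after it, each Interpreter owns its own flag and the map is gone. A minimal standalone sketch of the two layouts, using illustrative names rather than the actual functorch classes:

#include <cstdint>
#include <memory>
#include <unordered_map>
#include <vector>

// Before: a global singleton map keyed by level (the DynmetaData that this
// commit deletes from DynamicLayer.cpp).
using DynmetaData = std::unordered_map<int64_t, std::shared_ptr<bool>>;
static DynmetaData global_life_handles;  // illustrative stand-in

// After: the flag lives on the interpreter itself (the is_alive_ member this
// commit adds to Interpreter.h).
struct InterpreterSketch {
  int64_t level;
  std::shared_ptr<bool> is_alive = std::make_shared<bool>(false);
};

// Looking up a handle no longer consults global state; it reads the
// interpreter stored on the (thread-local) stack for that level.
static const std::shared_ptr<bool>& lifeHandleForLevel(
    const std::vector<InterpreterSketch>& stack, int64_t level) {
  return stack[level - 1].is_alive;  // levels are 1-indexed, as in the diff
}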

File tree

7 files changed: +87 -57 lines changed


aten/src/ATen/functorch/ADInterpreters.cpp

Lines changed: 13 additions & 6 deletions
@@ -77,11 +77,12 @@ static void autogradBasedTransformProcess(
 static void autogradBasedTransformSendToNext(
     const c10::OperatorHandle& op,
     torch::jit::Stack* stack,
-    int64_t current_level,
+    const Interpreter& interpreter,
     TransformType transform_type,
     optional<bool> prev_grad_mode,
     optional<bool> prev_fwd_grad_mode,
     bool grad_special_case) {
+  auto current_level = interpreter.level();
   if (transform_type == TransformType::Grad) {
     TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
   }
@@ -110,7 +111,7 @@ static void autogradBasedTransformSendToNext(
     // if (c10::show_dispatch_trace_enabled()) {
     // std::cout << "wrap " << current_level << std::endl;
     // }
-    return makeTensorWrapper(tensor, current_level, is_immutable);
+    return makeTensorWrapper(tensor, interpreter, is_immutable);
   };

   // TODO: we only need to do the following (marked with !) on in-place functions
@@ -208,8 +209,11 @@ void GradInterpreterPtr::sendToNextInterpreterImpl(
     torch::jit::Stack* stack,
     bool grad_special_case) {
   autogradBasedTransformSendToNext(
-      op, stack, level(),
-      TransformType::Grad, prevGradMode(), nullopt, grad_special_case);
+      op, stack, *base_,
+      TransformType::Grad,
+      prevGradMode(),
+      nullopt,
+      grad_special_case);
 }

 void JvpInterpreterPtr::processImpl(
@@ -223,8 +227,11 @@ void JvpInterpreterPtr::sendToNextInterpreterImpl(
     torch::jit::Stack* stack,
     bool grad_special_case) {
   autogradBasedTransformSendToNext(
-      op, stack, level(),
-      TransformType::Jvp, nullopt, prevFwdGradMode(), grad_special_case);
+      op, stack, *base_,
+      TransformType::Jvp,
+      nullopt,
+      prevFwdGradMode(),
+      grad_special_case);
 }

 }} // namespace at::functorch

aten/src/ATen/functorch/DynamicLayer.cpp

Lines changed: 14 additions & 31 deletions
@@ -74,15 +74,6 @@ RandomnessType DynamicLayer::randomness() const {
   return VmapInterpreterPtr(&interpreter_).randomness();
 }

-// Maps level to life handle, see NOTE: [Life handles and lexically scoped transforms]
-// for details
-using DynmetaData = std::unordered_map<int64_t, std::shared_ptr<bool>>;
-DynmetaData kDynMetaDataSingleton;
-
-static DynmetaData& getGlobalDynmetaData() {
-  return kDynMetaDataSingleton;
-}
-
 // functorch stores some TLS. Inside the TLS is the stack of transforms.
 // Unfortunately, since functorch isn't a part of libtorch, we have
 // a level of indirection. FuncTorchTLSBase is the interface that lives in libtorch,
@@ -166,10 +157,16 @@ static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {
   return getRawFunctorchTLS()->dynamicLayerStack;
 }

-std::shared_ptr<bool> getLifeHandleForLevel(int64_t level) {
-  auto it = getGlobalDynmetaData().find(level);
-  TORCH_INTERNAL_ASSERT(it != kDynMetaDataSingleton.end(), "level should be alive");
-  return it->second;
+const std::shared_ptr<bool>& getLifeHandleForLevel(int64_t level) {
+  auto& dynamicLayerStack = dynamicLayerStackAccessor();
+  TORCH_INTERNAL_ASSERT(
+      (int64_t)dynamicLayerStack.size() >= level && level >= 1,
+      "If you're trying to construct a tensor with the current level (",
+      level,
+      ") then the interpreter for that level must be on the DynamicLayerStack ");
+
+  auto& dynamic_layer = dynamicLayerStack[level - 1];
+  return dynamic_layer.interpreter().is_alive_ptr();
 }

 optional<DynamicLayer> maybeCurrentDynamicLayer() {
@@ -209,11 +206,6 @@ void setDynamicLayerStack(const std::vector<DynamicLayer>& stack) {
   dynamicLayerStackAccessor() = stack;
 }

-bool areTransformsActive() {
-  const auto& data = getGlobalDynmetaData();
-  return !data.empty();
-}
-
 DynamicLayer popDynamicLayer() {
   auto& dynamicLayerStack = dynamicLayerStackAccessor();
   TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
@@ -262,32 +254,23 @@ int64_t initAndPushDynamicLayer(
   DynamicLayer new_layer(transform_type, layerId, batch_size, randomness, prev_grad_mode, prev_fwd_grad_mode, functionalize_add_back_views);
   pushDynamicLayer(std::move(new_layer));

-  auto& data = getGlobalDynmetaData();
+  // NB: this function should be called while holding the GIL to avoid races
+  new_layer.interpreter().set_is_alive(true);

-  TORCH_INTERNAL_ASSERT(data.find(layerId) == data.end());
   if (transform_type == TransformType::Grad) {
     TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
   }
   if (transform_type == TransformType::Jvp) {
     TORCH_INTERNAL_ASSERT(prev_fwd_grad_mode.has_value());
   }
-  data[layerId] = std::make_shared<bool>(true);
   return layerId;
 }

 DynamicLayer popDynamicLayerAndDeleteMetadata() {
   auto result = popDynamicLayer();
-  auto level = result.layerId();

-  // TODO: is this lock safe? No one else should be writing to the same bucket
-  auto& data = getGlobalDynmetaData();
-  auto it = data.find(level);
-  if (it == data.end()) {
-    return result;
-  }
-  // invalidate the thing
-  *(it->second) = false;
-  data.erase(level);
+  // NB: this function should be called while holding the GIL to avoid races
+  result.interpreter().set_is_alive(false);
   return result;
 }
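
The (int64_t) cast in the new getLifeHandleForLevel above is the -Wsign-compare fix mentioned in the commit message: std::vector::size() returns an unsigned std::size_t, and comparing it directly against a signed int64_t level warns under -Wsign-compare. A minimal illustration of the pattern (not functorch code):

#include <cstdint>
#include <vector>

// Without the cast, `stack.size() >= level` compares unsigned against signed
// and trips -Wsign-compare; casting the size first mirrors the assert above.
bool level_is_on_stack(const std::vector<int>& stack, int64_t level) {
  return (int64_t)stack.size() >= level && level >= 1;
}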

aten/src/ATen/functorch/DynamicLayer.h

Lines changed: 1 addition & 7 deletions
@@ -80,10 +80,6 @@ TORCH_API const std::vector<DynamicLayer>& getDynamicLayerStack();
 TORCH_API void setDynamicLayerStack(const std::vector<DynamicLayer>& stack);
 TORCH_API void setDynamicLayerFrontBackKeysIncluded(bool included);

-// NB: Not lock safe, you should only call this from Python where the GIL will
-// prevent race conditions.
-TORCH_API bool areTransformsActive();
-
 // NOTE: [Life handles and lexically scoped transforms]
 // functorch transforms are lexically scoped.
 // Given a level, we store a "life handle" that is a boolean that tells us if the
@@ -92,9 +88,7 @@ TORCH_API bool areTransformsActive();
 // functorch's TensorWrapper (for grad transforms) stores a life handle.
 // If a TensorWrapper escapes from the scope of the transform, then somehow
 // it must know it escaped; it can tell by querying the life handle.
-//
-// NB: not lock safe. TODO: does it need a lock?
-TORCH_API std::shared_ptr<bool> getLifeHandleForLevel(int64_t level);
+TORCH_API const std::shared_ptr<bool>& getLifeHandleForLevel(int64_t level);

 // Returns if an operator is in-place. An operator is inplace if:
 // 1. The first argument is a Tensor and it is being written to
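
The NOTE above describes the wrapper side of the mechanism: a TensorWrapper keeps a copy of the life handle and can later tell whether its transform has ended. A rough sketch of that idea with a placeholder type (the real check is TensorWrapper::is_alive in TensorWrapper.cpp):

#include <memory>

// Placeholder for a wrapped tensor produced inside a grad transform.
struct EscapedWrapperSketch {
  std::shared_ptr<bool> life_handle;  // shared with the transform's interpreter

  // If the wrapper escapes the transform's scope, the interpreter flips the
  // flag to false when it is popped, and the wrapper can detect that here.
  bool is_alive() const {
    return *life_handle;
  }
};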

aten/src/ATen/functorch/Interpreter.h

Lines changed: 15 additions & 1 deletion
@@ -155,17 +155,31 @@ struct Interpreter {
     return *savedLocalDispatchKeySet_;
   }

+  // An Interpreter is alive if we are currently inside the ongoing transform
+  // for the interpreter. For example, vmap(f)(x); inside of f, the vmap's
+  // corresponding Interpreter is alive, even when it is not on the DynamicLayerStack.
+  bool is_alive() const {
+    return *is_alive_;
+  }
+  const std::shared_ptr<bool>& is_alive_ptr() const {
+    return is_alive_;
+  }
+  void set_is_alive(bool alive) {
+    *is_alive_ = alive;
+  }
+
   // Please don't use this
   explicit Interpreter() = default;

 private:
   explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta):
-    type_(type), level_(level), meta_(meta) {}
+    type_(type), level_(level), is_alive_(std::make_shared<bool>(false)), meta_(meta) {}

   // fields
   TransformType type_;
   int64_t level_;
   optional<c10::impl::LocalDispatchKeySet> savedLocalDispatchKeySet_;
+  std::shared_ptr<bool> is_alive_;
   InterpreterMeta meta_;
 };
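
To make the lifecycle of the new is_alive_ flag concrete: it is created false by the Interpreter constructor, set to true when the layer is pushed, set back to false when the layer is popped, and every holder of is_alive_ptr() observes those changes through the shared pointer. A hedged sketch of that flow with a simplified type (the real bookkeeping lives in initAndPushDynamicLayer and popDynamicLayerAndDeleteMetadata):

#include <memory>

struct InterpreterSketch {
  // Starts out dead, mirroring is_alive_(std::make_shared<bool>(false))
  // in the constructor above.
  std::shared_ptr<bool> is_alive = std::make_shared<bool>(false);

  void set_is_alive(bool alive) { *is_alive = alive; }
  const std::shared_ptr<bool>& is_alive_ptr() const { return is_alive; }
};

int main() {
  InterpreterSketch interp;
  interp.set_is_alive(true);            // push: the transform begins
  auto handle = interp.is_alive_ptr();  // a wrapper grabs the handle
  interp.set_is_alive(false);           // pop: the transform ends
  // Everyone holding the handle sees the change through the shared flag.
  return *handle ? 1 : 0;               // returns 0
}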

aten/src/ATen/functorch/TensorWrapper.cpp

Lines changed: 32 additions & 11 deletions
@@ -58,19 +58,22 @@ void dumpTensorCout(const Tensor& tensor) {
   std::cout << std::endl;
 }

-c10::intrusive_ptr<TensorWrapper> makeTensorWrapperPtr(const Tensor& tensor, int64_t level, bool should_be_alive) {
+c10::intrusive_ptr<TensorWrapper> makeTensorWrapperPtr(const Tensor& tensor, int64_t level, const std::shared_ptr<bool>& life_handle) {
   auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet({
     DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA});
   auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate);
   key_set = key_set.add(DispatchKey::FuncTorchGradWrapper);
-  if (should_be_alive) {
-    return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, getLifeHandleForLevel(level));
-  } else {
-    return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, std::make_shared<bool>(false));
-  }
+  return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, life_handle);
 }

-Tensor makeTensorWrapper(const Tensor& tensor, int64_t level, bool is_immutable) {
+// use makeTensorWrapper instead to avoid potential footguns:
+// unsafeMakeTensorWrapper doesn't check that level and life_handle
+// refer to the same interpreter
+static Tensor unsafeMakeTensorWrapper(
+    const Tensor& tensor,
+    int64_t level,
+    bool is_immutable,
+    const std::shared_ptr<bool>& life_handle) {
   auto wrapped = maybeGetTensorWrapper(tensor);
   if (wrapped) {
     TORCH_INTERNAL_ASSERT(wrapped->level() < level);
@@ -80,20 +83,38 @@ Tensor makeTensorWrapper(const Tensor& tensor, int64_t level, bool is_immutable)
     DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA});
   auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate);
   key_set = key_set.add(DispatchKey::FuncTorchGradWrapper);
-  auto life_handle = getLifeHandleForLevel(level);
-  auto result = at::detail::make_tensor<TensorWrapper>(key_set, tensor, level, std::move(life_handle), is_immutable);
+  auto result = at::detail::make_tensor<TensorWrapper>(
+      key_set, tensor, level, life_handle, is_immutable);
   TORCH_INTERNAL_ASSERT(result.key_set().has(DispatchKey::FuncTorchGradWrapper));
   return result;
 }

+Tensor makeTensorWrapper(const Tensor& tensor, int64_t level, bool is_immutable) {
+  auto life_handle = getLifeHandleForLevel(level);
+  return unsafeMakeTensorWrapper(
+      tensor,
+      level,
+      is_immutable,
+      getLifeHandleForLevel(level));
+}
+
+Tensor makeTensorWrapper(const Tensor& tensor, const Interpreter& interpreter, bool is_immutable) {
+  return unsafeMakeTensorWrapper(
+      tensor,
+      interpreter.level(),
+      is_immutable,
+      interpreter.is_alive_ptr());
+}
+
+
 bool TensorWrapper::is_alive() const {
   return *is_alive_;
 }

 c10::intrusive_ptr<TensorImpl> TensorWrapper::shallow_copy_and_detach(
     const c10::VariableVersion& version_counter,
     bool allow_tensor_metadata_change) const {
-  auto dest_impl = makeTensorWrapperPtr(value(), level_, is_alive());
+  auto dest_impl = makeTensorWrapperPtr(value(), level_, is_alive_);
   dest_impl->set_version_counter(version_counter);

   // TODO: is this even right?
@@ -104,7 +125,7 @@ c10::intrusive_ptr<TensorImpl> TensorWrapper::shallow_copy_and_detach(
 c10::intrusive_ptr<TensorImpl> TensorWrapper::shallow_copy_and_detach(
     c10::VariableVersion&& version_counter,
     bool allow_tensor_metadata_change) const {
-  auto dest_impl = makeTensorWrapperPtr(value(), level_, is_alive());
+  auto dest_impl = makeTensorWrapperPtr(value(), level_, is_alive_);
   dest_impl->set_version_counter(version_counter);

   // TODO: is this even right?
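
One detail in the shallow_copy_and_detach changes above: the copy used to pass the boolean is_alive(), and makeTensorWrapperPtr would then either look the handle up again by level or mint a fresh dead flag; now the copy hands over is_alive_ itself, so no lookup is needed and a detached copy observes the same liveness flag as the original. A small illustration of that shared-flag behavior, with an illustrative type rather than the real TensorWrapper:

#include <memory>

struct WrapperCopySketch {
  std::shared_ptr<bool> life_handle;  // shared with the original, not re-derived
};

int main() {
  auto original_handle = std::make_shared<bool>(true);
  WrapperCopySketch detached_copy{original_handle};  // share the same flag

  *original_handle = false;  // the transform that created the original ends

  // The detached copy sees the update immediately through the shared pointer.
  return *detached_copy.life_handle ? 1 : 0;  // returns 0
}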

aten/src/ATen/functorch/TensorWrapper.h

Lines changed: 12 additions & 0 deletions
@@ -8,6 +8,7 @@

 #include <ATen/functorch/Macros.h>
 #include <ATen/Tensor.h>
+#include <ATen/functorch/Interpreter.h>

 namespace at {
 namespace functorch {
@@ -89,7 +90,18 @@ struct TORCH_API TensorWrapper : public c10::TensorImpl {
   std::shared_ptr<bool> is_alive_;
 };

+// There are two variants of makeTensorWrapper: one that accepts a level
+// and one that accepts an Interpreter.
+//
+// The one that accepts a level tries to automatically get the life handle from the
+// interpreter on the DynamicLayerStack.
+// It needs to be used with caution: if the interpreter is not on the
+// DynamicLayerStack, then we won't be able to find the life handle.
+//
+// In practice this isn't a problem: when we're constructing TensorWrapper in
+// Python, the corresponding interpreter is on the stack.
 TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, int64_t level, bool is_immutable=false);
+TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, const Interpreter& interpreter, bool is_immutable=false);
 TORCH_API TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor);
 TORCH_API void dumpTensor(std::ostream & ss, const Tensor& tensor);
 TORCH_API void dumpTensorCout(const Tensor& tensor);

torch/csrc/functorch/init.cpp

Lines changed: 0 additions & 1 deletion
@@ -458,7 +458,6 @@ void initFuncTorchBindings(PyObject* module) {
   m.def("dump_tensor", &dump_tensor, "dump_tensor");
   m.def("reshape_dim_into", &at::functorch::reshape_dim_into);
   m.def("reshape_dim_outof", &at::functorch::reshape_dim_outof);
-  m.def("are_transforms_active", &at::functorch::areTransformsActive);
   // various debugging things. Maybe we should offer these as first-class APIs
   // on Tensors?
   m.def("is_batchedtensor", &is_batchedtensor);
