
Commit 2fb688c

NicolasHug authored and facebook-github-bot committed
[fbsync] Simplify _NO_WRAPPING_EXCEPTIONS (#7806)
Reviewed By: matteobettini
Differential Revision: D48642278
fbshipit-source-id: b74f5744ca32672d70f89dab2e8ef01b073c3be0
1 parent: ed885fe

2 files changed: 11 additions, 18 deletions

test/test_datapoints.py (0 additions, 1 deletion)

@@ -209,4 +209,3 @@ def test_deepcopy(datapoint, requires_grad):
 
     assert type(datapoint_deepcopied) is type(datapoint)
    assert datapoint_deepcopied.requires_grad is requires_grad
-    assert datapoint_deepcopied.is_leaf
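
For context, a minimal sketch (not part of the commit) of what the surviving assertions exercise. It assumes the prototype datapoints API of torchvision 0.15/0.16, where datapoints such as `datapoints.Image` are `torch.Tensor` subclasses whose constructors accept `requires_grad`:

    import copy

    import torch
    from torchvision import datapoints

    # Hypothetical inputs standing in for the `datapoint` / `requires_grad` test fixtures.
    img = datapoints.Image(torch.rand(3, 4, 4), requires_grad=True)
    img_copy = copy.deepcopy(img)

    # Both the subclass and the autograd flag are expected to survive a deepcopy.
    assert type(img_copy) is datapoints.Image
    assert img_copy.requires_grad is True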

torchvision/datapoints/_datapoint.py (11 additions, 17 deletions)

@@ -33,14 +33,9 @@ def _to_tensor(
     def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
         return tensor.as_subclass(cls)
 
-    _NO_WRAPPING_EXCEPTIONS = {
-        torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
-        torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
-        torch.Tensor.detach: lambda cls, input, output: cls.wrap_like(input, output),
-        # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
-        # retains the type automatically
-        torch.Tensor.requires_grad_: lambda cls, input, output: output,
-    }
+    # The ops in this set are those that should *preserve* the Datapoint type,
+    # i.e. they are exceptions to the "no wrapping" rule.
+    _NO_WRAPPING_EXCEPTIONS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}
 
     @classmethod
     def __torch_function__(
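
A minimal sketch, not part of the diff, of the behavior this set controls, under the same assumptions as above (prototype datapoints API, datapoints subclass `torch.Tensor`):

    import torch
    from torchvision import datapoints

    img = datapoints.Image(torch.rand(3, 4, 4))

    # Ops in _NO_WRAPPING_EXCEPTIONS keep the Datapoint subclass ...
    assert type(img.clone()) is datapoints.Image
    assert type(img.detach()) is datapoints.Image
    assert type(img.to(torch.float64)) is datapoints.Image

    # ... while every other op falls back to a plain tensor.
    assert type(img + 1) is torch.Tensor

The per-op lambdas could collapse into a plain set because all four exceptions tolerate the same `wrap_like` path: the three out-of-place ops need the re-wrap, and for the inplace `requires_grad_` the output already is the input Datapoint, so re-wrapping it changes nothing type-wise.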
@@ -76,22 +71,21 @@ def __torch_function__(
         with DisableTorchFunctionSubclass():
             output = func(*args, **kwargs or dict())
 
-            wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
-            # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
+        if func in cls._NO_WRAPPING_EXCEPTIONS and isinstance(args[0], cls):
+            # We also require the primary operand, i.e. `args[0]`, to be
             # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
             # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
             # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
             # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
             # be wrapped into a `datapoints.Image`.
-            if wrapper and isinstance(args[0], cls):
-                return wrapper(cls, args[0], output)
+            return cls.wrap_like(args[0], output)
 
-            # Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
-            # will retain the input type. Thus, we need to unwrap here.
-            if isinstance(output, cls):
-                return output.as_subclass(torch.Tensor)
+        if isinstance(output, cls):
+            # DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`,
+            # so for those, the output is still a Datapoint. Thus, we need to manually unwrap.
+            return output.as_subclass(torch.Tensor)
 
-            return output
+        return output
 
     def _make_repr(self, **kwargs: Any) -> str:
         # This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
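
And a matching sketch, again not from the commit, of the two branches above:

    import torch
    from torchvision import datapoints

    img = datapoints.Image(torch.rand(3, 4, 4))
    plain = torch.rand(3, 4, 4)

    # `Tensor.to` is an exception, but the `isinstance(args[0], cls)` guard keeps
    # a plain tensor from being wrapped when the Datapoint is only the secondary operand.
    assert type(plain.to(img)) is torch.Tensor

    # An inplace op like `.add_(...)` mutates the Datapoint, but its *returned*
    # handle goes through the second branch and is unwrapped to a plain tensor.
    out = img.add_(1)
    assert type(out) is torch.Tensor
    assert type(img) is datapoints.Image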
