
Commit f3665dd

Chillee authored and pytorchmergebot committed
Reland #2 of "Added {logical_not, trace} refs, moved logical ops to use method overloads"
Pull Request resolved: pytorch#79819
Approved by: https://github.com/mruberry
1 parent 26b5129 commit f3665dd

5 files changed (+61, −27 lines)


test/test_meta.py

Lines changed: 1 addition & 5 deletions
@@ -400,8 +400,6 @@ def run_meta_crossref(
     torch.mode: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8},  # aten::mode
     torch.multinomial: {bf16, f32, f64},  # aten::multinomial, aten::multinomial.out
     torch.mvlgamma: {bf16, f32, f64, i16, i32, i64, i8, u8},  # aten::_local_scalar_dense, aten::mvlgamma.out
-    torch.nanmean: {bf16, f16, f32, f64},
-    torch.nanquantile: {f32, f64},
     torch.nn.functional.conv1d: {bf16, f32, f64, i64},
     torch.nn.functional.conv2d: {bf16, f32, f64, i64},
     torch.nn.functional.conv_transpose1d: {f32, f64, i64},
@@ -465,9 +463,9 @@ def run_meta_crossref(
     torch.functional.cdist: {f32, f64},
     torch.functional.tensordot: {bf16, f32, f64, i16, i32, i64, i8, u8},
     torch.inner: {bf16, f32, f64, i16, i32, i64, i8, u8},
-    torch.logical_not: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8},
     torch.nn.functional.cross_entropy: {bf16, f32, f64},
     torch.nn.functional.interpolate: {bf16, f32, f64, u8},
+    torch.nanmean: {bf16, f16, f32, f64},  # TODO(chilli): Doesn't seem to work for some reason?
     torch.nn.functional.nll_loss: {bf16, f32, f64},  # TODO
     torch.linalg.pinv: {f32, f64},
     torch.empty: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8},
@@ -627,8 +625,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
     aten.log_sigmoid_forward.output: {bf16, f64, f32},
     aten.logcumsumexp.default: {bf16, f64, f32},
     aten.logcumsumexp.out: {bf16, f64, f32},
-    aten.logical_not.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
-    aten.logical_not_.default: {bf16, f16, f64, f32},
     aten.masked_select.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
     aten.masked_select.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
     aten.max_pool3d_with_indices.default: {f64, f32},
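
The dropped skip entries mean these ops are now expected to trace through meta tensors via the new refs. A minimal sketch of what the meta crossref exercises (illustrative only, not part of the diff; assumes a build containing this commit):

import torch

# Meta tensors carry only shape/dtype, no data; the crossref test checks
# that each op can compute output metadata this way.
t = torch.empty(3, device="meta", dtype=torch.int64)
out = torch.logical_not(t)
print(out.shape, out.dtype)  # torch.Size([3]) torch.bool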

torch/_decomp/decompositions.py

Lines changed: 0 additions & 10 deletions
@@ -1005,11 +1005,6 @@ def _fused_dropout_decomposition(input, p, generator=None):
     return (res, mask)


-@register_decomposition(aten.logical_not)
-def logical_not(self: Tensor) -> Tensor:
-    return ~self.to(dtype=torch.bool)
-
-
 @register_decomposition(aten.xlogy.Tensor)
 @pw_cast_for_int_to_real
 def xlogy(self: Tensor, other: Tensor) -> Tensor:
@@ -1166,11 +1161,6 @@ def logsumexp(self: Tensor, dim: List[int], keepdim: bool = False) -> Tensor:
     return result.log().add(maxes_squeezed)


-@register_decomposition(aten.trace.default)
-def trace(self: Tensor) -> Tensor:
-    return torch.sum(torch.diag(self))
-
-
 # nb: Should use acc_t, not op_math
 @register_decomposition(aten.log_sigmoid_forward)
 @out_wrapper_multi('output', 'buffer')
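
Neither decomposition is lost: both reappear as refs in torch/_refs/__init__.py below. As a sanity check, the removed bodies are behaviorally equivalent to the eager ops (illustrative snippet, not part of the diff):

import torch

x = torch.tensor([0, 1, 2, -3])
assert torch.equal(~x.to(torch.bool), torch.logical_not(x))  # old logical_not body

m = torch.arange(9.0).reshape(3, 3)
assert torch.equal(torch.sum(torch.diag(m)), torch.trace(m))  # old trace body; 0 + 4 + 8 == 12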

torch/_prims/context.py

Lines changed: 7 additions & 1 deletion
@@ -27,7 +27,13 @@ def torch_to_refs_map():
         (torch.nn.functional, torch._refs.nn.functional),
         (torch.special, torch._refs.special),
     ]
-    r = {}
+    r: Dict[Any, Any] = {
+        torch.Tensor.__invert__: torch._refs.bitwise_not,
+        torch.Tensor.__xor__: torch._refs.bitwise_xor,
+        torch.Tensor.__and__: torch._refs.bitwise_and,
+        torch.Tensor.__or__: torch._refs.bitwise_or,
+        torch.Tensor.__eq__: torch._refs.eq,
+    }
     for mod_torch, mod_refs in modules:
         for s in mod_refs.__all__:  # type: ignore[attr-defined]
             r[mod_torch.__dict__.get(s)] = mod_refs.__dict__.get(s)
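
Seeding the map with the Tensor dunder methods is what allows the logical refs below to be written with operators (~, &, |, ^, ==) rather than explicit bitwise_*/ne calls: under TorchRefsMode the method overloads are intercepted and rerouted to their _refs counterparts. A quick check of the mapping (illustrative, assumes a build containing this commit):

import torch
import torch._refs
from torch._prims.context import torch_to_refs_map

mapping = torch_to_refs_map()
assert mapping[torch.Tensor.__invert__] is torch._refs.bitwise_not
assert mapping[torch.Tensor.__eq__] is torch._refs.eq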

torch/_refs/__init__.py

Lines changed: 28 additions & 10 deletions
@@ -88,6 +88,7 @@
     "square",
     "tan",
     "tanh",
+    "trace",
     #
     # Elementwise Binary References
     #
@@ -119,6 +120,7 @@
     # 'ldexp',
     "le",
     "logical_and",
+    "logical_not",
     "logical_or",
     "logical_xor",
     "lt",
@@ -996,10 +998,10 @@ def _lcm(a: TensorLikeType, b: TensorLikeType):

 def _logical_and(a: TensorLikeType, b: TensorLikeType):
     if not utils.is_boolean_dtype(a.dtype):
-        a = ne(a, 0)
+        a = a != 0
     if not utils.is_boolean_dtype(b.dtype):
-        b = ne(b, 0)
-    return bitwise_and(a, b)
+        b = b != 0
+    return a & b


 logical_and = _make_elementwise_binary_reference(
@@ -1009,12 +1011,21 @@ def _logical_and(a: TensorLikeType, b: TensorLikeType):
 )


+@_make_elementwise_unary_reference(
+    ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, aten_op=torch.ops.aten.logical_not
+)
+def logical_not(a: TensorLikeType):
+    if not utils.is_boolean_dtype(a.dtype):
+        return a == 0
+    return ~a
+
+
 def _logical_or(a: TensorLikeType, b: TensorLikeType):
     if not utils.is_boolean_dtype(a.dtype):
-        a = ne(a, 0)
+        a = a != 0
     if not utils.is_boolean_dtype(b.dtype):
-        b = ne(b, 0)
-    return bitwise_or(a, b)
+        b = b != 0
+    return a | b


 logical_or = _make_elementwise_binary_reference(
@@ -1026,10 +1037,10 @@ def _logical_or(a: TensorLikeType, b: TensorLikeType):

 def _logical_xor(a: TensorLikeType, b: TensorLikeType):
     if not utils.is_boolean_dtype(a.dtype):
-        a = ne(a, 0)
+        a = a != 0
     if not utils.is_boolean_dtype(b.dtype):
-        b = ne(b, 0)
-    return bitwise_xor(a, b)
+        b = b != 0
+    return a ^ b


 # TODO: skip unnecessary conversion of long to float
@@ -2614,6 +2625,13 @@ def equal(a: TensorLikeType, b: TensorLikeType) -> bool:
     return item(all(eq(a, b)))  # type: ignore[return-value]


-# populate the decomp table
+@register_decomposition(torch.ops.aten.trace)
+def trace(self: TensorLikeType) -> TensorLikeType:
+    utils.check(
+        self.ndim == 2, lambda: f"expected a matrix, but got tensor with dim {self.ndim}"
+    )
+    return torch.sum(torch.diag(self, 0))
+
+
 import torch._refs.nn.functional
 import torch._refs.special
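
Behaviorally, the two new refs match their eager counterparts; a short sketch (illustrative, assumes a build containing this commit):

import torch
import torch._refs as refs

x = torch.tensor([0, 1, 2])
print(refs.logical_not(x))  # tensor([ True, False, False])

m = torch.arange(9.0).reshape(3, 3)
print(refs.trace(m))  # tensor(12.) -- the sum of the main diagonal
# A non-2D input trips the utils.check above:
# refs.trace(torch.randn(3, 4, 5)) raises "expected a matrix, but got tensor with dim 3"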

torch/testing/_internal/common_methods_invocations.py

Lines changed: 25 additions & 1 deletion
@@ -3662,6 +3662,10 @@ def sample_inputs_trace(self, device, dtype, requires_grad, **kwargs):
                               requires_grad=requires_grad))),)


+def error_inputs_trace(op, device):
+    yield ErrorInput(SampleInput(make_tensor((3, 4, 5), dtype=torch.float32, device=device)), error_regex="expected a matrix")
+
+
 def sample_inputs_renorm(self, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
     cases = (((S, S, S), (2, 1, 0.5)),
@@ -4330,7 +4334,6 @@ def error_inputs_embedding(op_info, device, **kwargs):
 def error_inputs_t(op_info, device, **kwargs):
     yield ErrorInput(
         SampleInput(torch.randn(2, 3, 4, 5, device=device)),
-        error_type=RuntimeError,
         error_regex="expects a tensor with <= 2",
     )

@@ -17634,6 +17637,7 @@ def error_inputs_mean(op_info, device, **kwargs):
           dtypes=all_types_and_complex(),
           dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
           backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+          error_inputs_func=error_inputs_trace,
           supports_inplace_autograd=False,
           supports_out=False,
           supports_forward_ad=True,
@@ -20620,6 +20624,16 @@ def __init__(
             ),
         )
     ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.logical_not",
+        torch_opinfo_name="logical_not",
+        skips=(
+            DecorateInfo(
+                # NotImplementedError: argument of type: <class 'complex'>
+                unittest.skip("Fails aten complex and nvfuser doesn't support eq(a, 0)"), 'TestCommon', 'test_python_ref_executor'
+            ),
+        )
+    ),
     ElementwiseBinaryPythonRefInfo(
         "_refs.logical_or",
         torch_opinfo_name="logical_or",
@@ -21193,6 +21207,16 @@ def __init__(
             DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
         ),
     ),
+    PythonRefInfo(
+        "_refs.trace",
+        torch_opinfo_name="trace",
+        decorators=(
+            # TODO: torch.diag is currently not supported by either refs, meta funcs, or NVFuser
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
+            DecorateInfo(unittest.skip("diag is not supported by meta"), 'TestCommon', 'test_python_ref_meta'),
+            DecorateInfo(unittest.skip("diag is not supported by nvfuser"), 'TestCommon', 'test_python_ref_executor'),
+        ),
+    ),
     #
     # Tensor Creation Reference OpInfos
     #
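
In plain terms, error_inputs_trace asserts the same failure the eager op already raises on a non-matrix input, which the new _refs.trace reproduces via its utils.check (illustrative snippet, not part of the diff):

import torch

try:
    torch.trace(torch.randn(3, 4, 5))
except RuntimeError as e:
    assert "expected a matrix" in str(e)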
