
Commit cd23fa3

laithsakka authored and can-gaa-hou committed
Remove guard_size_oblivious from default contiguity python check, and add aten.sym_is_contiguous. (pytorch#159197)
This might cause new DDEs at call sites that do not use is_contiguous_or_false() or sym_is_contiguous(), but we want to surface those call sites so they can be handled properly, by explicitly calling is_contiguous_or_false() instead of is_contiguous() where appropriate. I had to fix one issue after removing the implicit size-oblivious reasoning.

Context: in pytorch#157472 we defined sym_is_contiguous as the function that computes contiguity for dynamic shapes in C++. It returns a symbolic expression representing contiguity and is guaranteed not to throw a DDE:

- when people call is_contiguous, we do sym_is_contiguous().guard_bool()
- when people call is_contiguous_or_false, we do sym_is_contiguous().guard_or_false()

One path that was not handled well was this:

```
c10::SymBool TensorImpl::sym_is_contiguous_custom(
    at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
        this, memory_format);
  }

  return sym_is_contiguous_default(memory_format);
}
```

Namely, when sym_is_contiguous_custom is called and matches_python_custom(SizesStridesPolicy::CustomStrides) returns true, we used to call is_contiguous(this, memory_format). That went through load_pyobj_interpreter and ended up calling the Python is_contiguous, which relied on the implicit size-oblivious reasoning. Once that implicit size-oblivious reasoning is removed, the right thing is to call

```
return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(this, memory_format);
```

otherwise we would hit a DDE even when the caller is using sym_is_contiguous. So I had to define sym_is_contiguous for the PyInterpreter, and then override it for nested tensors.

Pull Request resolved: pytorch#159197
Approved by: https://github.com/ezyang
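As a quick illustration of the two call paths described above, here is a minimal Python-style sketch. It only mirrors the C++ semantics (guard_bool / guard_or_false on the SymBool returned by sym_is_contiguous) and is not itself PyTorch API.

```python
# Sketch only: mirrors the C++ semantics described in the message above.
def is_contiguous(tensor):
    # Guards on the symbolic answer; can raise a data-dependent error (DDE)
    # when contiguity depends on unbacked symbols.
    return tensor.sym_is_contiguous().guard_bool()


def is_contiguous_or_false(tensor):
    # Never raises a DDE; an unknown answer is reported as False.
    return tensor.sym_is_contiguous().guard_or_false()
```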
1 parent 008e0b5 commit cd23fa3

File tree

20 files changed (+141, -34 lines)


aten/src/ATen/native/TensorProperties.cpp

Lines changed: 7 additions & 0 deletions
```diff
@@ -18,6 +18,7 @@
 #include <ATen/ops/is_set_to_native.h>
 #include <ATen/ops/size_native.h>
 #include <ATen/ops/stride_native.h>
+#include <ATen/ops/sym_is_contiguous_native.h>
 #include <ATen/ops/sym_numel_native.h>
 #include <ATen/ops/sym_size_native.h>
 #include <ATen/ops/sym_storage_offset_native.h>
@@ -57,6 +58,12 @@ c10::SymInt sym_size(const Tensor& self, int64_t dim) {
   return self.sym_size(dim);
 }
 
+c10::SymBool sym_is_contiguous(
+    const Tensor& self,
+    c10::MemoryFormat memory_format) {
+  return self.sym_is_contiguous(memory_format);
+}
+
 c10::SymInt sym_stride(const Tensor& self, int64_t dim) {
   return self.sym_stride(dim);
 }
```

aten/src/ATen/native/native_functions.yaml

Lines changed: 7 additions & 0 deletions
```diff
@@ -5509,6 +5509,13 @@
   tags: core
   manual_cpp_binding: True
 
+- func: sym_is_contiguous(Tensor self, MemoryFormat memory_format=contiguous_format) -> SymBool
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+  tags: core
+  manual_cpp_binding: True
+
 - func: sym_numel(Tensor self) -> SymInt
   variants: function
   device_check: NoCheck
```
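The new schema is not given a generated Python binding (see the gen_python_functions.py change below), but it is still reachable from Python through the dispatcher, as the updated test_python_dispatch.py test relies on. A minimal sketch, assuming a build that includes this change:

```python
import torch

t = torch.randn(2, 3)
# Dispatcher call; for a plain dense tensor the SymBool answer is a concrete bool.
print(torch.ops.aten.sym_is_contiguous(t))                  # True
print(torch.ops.aten.sym_is_contiguous(t.transpose(0, 1)))  # False
```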

c10/core/TensorImpl.cpp

Lines changed: 9 additions & 2 deletions
```diff
@@ -313,8 +313,15 @@ void TensorImpl::throw_data_ptr_access_error() const {
 c10::SymBool TensorImpl::sym_is_contiguous_custom(
     at::MemoryFormat memory_format) const {
   if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
-    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
-        this, memory_format);
+    // To reduce BC breakage, call is_contiguous when the tensor does not
+    // have symbolic sizes/strides; otherwise call sym_is_contiguous.
+    if (C10_UNLIKELY(has_symbolic_sizes_strides_)) {
+      return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(
+          this, memory_format);
+    } else {
+      return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
+          this, memory_format);
+    }
   }
 
   return sym_is_contiguous_default(memory_format);
```

c10/core/impl/PyInterpreter.cpp

Lines changed: 4 additions & 0 deletions
```diff
@@ -60,6 +60,10 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
   bool is_contiguous(const TensorImpl* self, at::MemoryFormat) const override {
     PANIC(is_contiguous);
   }
+  c10::SymBool sym_is_contiguous(const TensorImpl* self, at::MemoryFormat)
+      const override {
+    PANIC(sym_is_contiguous);
+  }
   bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
       const override {
     PANIC(is_strides_like);
```

c10/core/impl/PyInterpreter.h

Lines changed: 3 additions & 0 deletions
```diff
@@ -168,6 +168,9 @@ struct C10_API PyInterpreterVTable {
 
   virtual bool is_contiguous(const TensorImpl* self, at::MemoryFormat)
       const = 0;
+  virtual c10::SymBool sym_is_contiguous(
+      const TensorImpl* self,
+      at::MemoryFormat) const = 0;
   virtual bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
       const = 0;
   virtual bool is_non_overlapping_and_dense(const TensorImpl* self) const = 0;
```

test/functorch/test_vmap_registrations.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -208,6 +208,7 @@
     "aten::subtract_.Scalar",
     "aten::subtract_.Tensor",
     "aten::svd.U",
+    "aten::sym_is_contiguous",
     "aten::sym_size.int",
     "aten::sym_stride.int",
     "aten::sym_numel",
```

test/test_python_dispatch.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -1958,6 +1958,8 @@ def __new__(cls, data, wrapper):
             def __torch_dispatch__(cls, func, types, args, kwargs):
                 if func.overloadpacket == torch.ops.aten.is_contiguous:
                     return contiguous_data.is_contiguous()
+                if func.overloadpacket == torch.ops.aten.sym_is_contiguous:
+                    return torch.ops.aten.sym_is_contiguous(contiguous_data)
                 return NotImplemented
 
         class ExampleTensor3(torch.Tensor):
@@ -1971,6 +1973,8 @@ def __new__(cls, data, wrapper):
             def __torch_dispatch__(cls, func, types, args, kwargs):
                 if func.overloadpacket == torch.ops.aten.is_contiguous:
                     return not_contiguous_data.is_contiguous()
+                if func.overloadpacket == torch.ops.aten.sym_is_contiguous:
+                    return torch.ops.aten.sym_is_contiguous(not_contiguous_data)
                 return NotImplemented
 
         err_msg = "Multiple dispatch failed for 'torch.ops.aten.is_contiguous'"
@@ -2003,6 +2007,7 @@ def __new__(cls, data):
             @classmethod
             def __torch_dispatch__(cls, func, types, args, kwargs):
                 if func in [
+                    torch.ops.aten.sym_is_contiguous.default,
                     torch.ops.aten.is_contiguous.default,
                     torch.ops.aten.is_contiguous.memory_format,
                     torch.ops.aten.is_strides_like_format.default,
```
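The same hook can be exercised outside the test suite with a standalone wrapper subclass. The sketch below is modeled on the ExampleTensor classes above; the class name is illustrative and not part of the PR, and it assumes a build that includes this change:

```python
import torch


class ContigQueryTensor(torch.Tensor):
    """Illustrative subclass that answers contiguity queries via __torch_dispatch__."""

    @staticmethod
    def __new__(cls, data):
        # The "strides" policy routes stride/contiguity queries through __torch_dispatch__.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.size(),
            dtype=data.dtype,
            dispatch_sizes_strides_policy="strides",
        )

    def __init__(self, data):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if func.overloadpacket in (
            torch.ops.aten.is_contiguous,
            torch.ops.aten.sym_is_contiguous,
        ):
            # Answer both query ops from the wrapped tensor's layout.
            return args[0]._data.is_contiguous()
        return NotImplemented


t = ContigQueryTensor(torch.randn(3, 4))
print(t.is_contiguous())  # answered by __torch_dispatch__
```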

tools/autograd/gen_python_functions.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -97,6 +97,7 @@
     "is_sparse_csr",
     "size",
     "stride",
+    "sym_is_contiguous",
     "sym_size",
     "sym_stride",
     "sym_storage_offset",
```

torch/_dynamo/convert_frame.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -1560,7 +1560,6 @@ def __call__(
         frame_state: dict[str, Union[int, FrameStateSizeEntry]],
     ) -> ConvertFrameReturn:
         assert frame_state is not None
-
         input_codes.add(frame.f_code)
 
         is_skipfile = trace_rules.check(frame.f_code)
```

torch/_prims_common/__init__.py

Lines changed: 24 additions & 21 deletions
```diff
@@ -265,12 +265,14 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
     from torch.fx.experimental.symbolic_shapes import (
         guard_or_false,
         guard_or_true,
-        guard_size_oblivious,
         is_nested_int,
     )
 
-    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
-    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
+    def eval_eager(x):
+        return bool(x)
+
+    maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager
+    maybe_guard_or_true = guard_or_true if false_if_dde else eval_eager
 
     if maybe_guard_or_false(a.numel() < 2):
         return True
@@ -305,14 +307,13 @@ def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool:
     if a.ndim != 4:
         return False
 
-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-        guard_or_true,
-        guard_size_oblivious,
-    )
+    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
+
+    def eval_eager(x):
+        return bool(x)
 
-    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
-    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
+    maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager
+    maybe_guard_or_true = guard_or_true if false_if_dde else eval_eager
 
     expected_stride = 1
     for idx in (1, 3, 2, 0):
@@ -334,14 +335,13 @@ def is_channels_last_contiguous_3d(a: Tensor, false_if_dde=False) -> bool:
     if a.ndim != 5:
         return False
 
-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-        guard_or_true,
-        guard_size_oblivious,
-    )
+    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
+
+    def eval_eager(x):
+        return bool(x)
 
-    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
-    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
+    maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager
+    maybe_guard_or_true = guard_or_true if false_if_dde else eval_eager
 
     expected_stride = 1
     for idx in (1, 4, 3, 2, 0):
@@ -406,7 +406,7 @@ def is_channels_last_contiguous_or_false_3d(a: Tensor) -> bool:
 
 
 # similar to is_contiguous_for_memory_format but return false on data dependency.
-def contiguous_for_memory_format_or_false(  # type: ignore[return]
+def is_contiguous_for_memory_format_or_false(  # type: ignore[return]
     a: Tensor, *, memory_format: torch.memory_format
 ) -> bool:
     return is_contiguous_for_memory_format(
@@ -550,11 +550,14 @@ def compute_elementwise_output_logical_to_physical_perm(
     is_contiguous = True
     is_channels_last = True
     for t in tensors:
-        is_contiguous = is_contiguous and contiguous_for_memory_format_or_false(
+        is_contiguous = is_contiguous and is_contiguous_for_memory_format_or_false(
             t, memory_format=torch.contiguous_format
         )
-        is_channels_last = is_channels_last and contiguous_for_memory_format_or_false(
-            t, memory_format=torch.channels_last
+        is_channels_last = (
+            is_channels_last
+            and is_contiguous_for_memory_format_or_false(
+                t, memory_format=torch.channels_last
+            )
         )
 
     if is_contiguous and not is_channels_last:
```
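To make the behavior change above concrete: with false_if_dde=False the checks are now evaluated eagerly via bool() (which guards, and can raise a data-dependent error on unbacked symbols) instead of going through guard_size_oblivious, while false_if_dde=True keeps the DDE-safe guard_or_false/guard_or_true path. A minimal standalone sketch of the pattern (the helper name numel_is_trivial is illustrative, not part of the PR):

```python
import torch
from torch.fx.experimental.symbolic_shapes import guard_or_false


def numel_is_trivial(a: torch.Tensor, false_if_dde: bool = False) -> bool:
    def eval_eager(x):
        # Plain bool() guards on symbolic expressions and may raise a DDE.
        return bool(x)

    # DDE-safe path: report False when the answer cannot be decided.
    maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager
    return maybe_guard_or_false(a.numel() < 2)


print(numel_is_trivial(torch.randn(1)))     # True
print(numel_is_trivial(torch.randn(3, 4)))  # False
```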
