
Commit 965f4ea

SherlockNoMad authored and pytorchmergebot committed
[Reland] Add sym_size/stride/numel/storage_offset to native_function.yaml (ROCm#91… (pytorch#92402)
Pull Request resolved: pytorch#91919
Approved by: https://github.com/ezyang

Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#92402
Approved by: https://github.com/ezyang
1 parent 79db5bc commit 965f4ea

9 files changed, +64 -25 lines changed


aten/src/ATen/core/function_schema.cpp

Lines changed: 5 additions & 1 deletion
@@ -19,6 +19,9 @@ const std::vector<Argument>& FunctionSchema::getCorrectList(SchemaArgType type)
 }
 
 FunctionSchema FunctionSchema::cloneWithRealTypes(bool with_symint) const {
+  auto alwaysCloneWithRealTypes = [&](const Argument& a) {
+    return a.cloneWithType(a.real_type());
+  };
   auto cloneWithRealTypes = [&](const Argument& a) {
     if (with_symint) {
       return a.cloneWithType(a.real_type());
@@ -39,7 +42,8 @@ FunctionSchema FunctionSchema::cloneWithRealTypes(bool with_symint) const {
   };
   std::vector<Argument> new_arguments, new_returns;
   std::transform(arguments().begin(), arguments().end(), std::back_inserter(new_arguments), cloneWithRealTypes);
-  std::transform(returns().begin(), returns().end(), std::back_inserter(new_returns), cloneWithRealTypes);
+  // NB: SymInt returns are always SymInt
+  std::transform(returns().begin(), returns().end(), std::back_inserter(new_returns), alwaysCloneWithRealTypes);
   return FunctionSchema(
       name(),
       overload_name(),
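The new alwaysCloneWithRealTypes lambda exists for the asymmetry the NB comment calls out: SymInt-carrying arguments are upgraded to their real (SymInt) types only when the symint variant of the schema is requested, while SymInt returns always use their real types. A minimal Python sketch of that rule for SymInt-typed values (illustrative only; clone_argument and clone_return are hypothetical stand-ins, not the C++ API):

# Illustrative sketch, not the actual C++ in function_schema.cpp.
def clone_argument(fake_type, real_type, with_symint):
    # arguments respect with_symint (BC: SymInt still presents as int by default)
    return real_type if with_symint else fake_type

def clone_return(fake_type, real_type):
    # NB: SymInt returns are always SymInt
    return real_type

assert clone_argument("int", "SymInt", with_symint=False) == "int"
assert clone_argument("int", "SymInt", with_symint=True) == "SymInt"
assert clone_return("int", "SymInt") == "SymInt"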

aten/src/ATen/native/TensorProperties.cpp

Lines changed: 16 additions & 0 deletions
@@ -49,6 +49,22 @@ int64_t stride(const Tensor& self, int64_t dim) {
   return self.stride(dim);
 }
 
+c10::SymInt sym_size(const Tensor& self, int64_t dim) {
+  return self.sym_size(dim);
+}
+
+c10::SymInt sym_stride(const Tensor& self, int64_t dim) {
+  return self.sym_stride(dim);
+}
+
+c10::SymInt sym_numel(const Tensor& self) {
+  return self.sym_numel();
+}
+
+c10::SymInt sym_storage_offset(const Tensor& self) {
+  return self.sym_storage_offset();
+}
+
 int64_t size(const Tensor& self, Dimname dim) {
   size_t pos_dim = dimname_to_position(self, dim);
   return self.sizes()[pos_dim];

aten/src/ATen/native/native_functions.yaml

Lines changed: 28 additions & 0 deletions
@@ -5044,6 +5044,27 @@
   device_check: NoCheck
   device_guard: False
 
+- func: sym_size.int(Tensor self, int dim) -> SymInt
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+  tags: core
+  manual_cpp_binding: True
+
+- func: sym_numel(Tensor self) -> SymInt
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+  tags: core
+  manual_cpp_binding: True
+
+- func: sym_storage_offset(Tensor self) -> SymInt
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+  tags: core
+  manual_cpp_binding: True
+
 - func: slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
   variants: function, method
   device_check: NoCheck
@@ -5318,6 +5339,13 @@
   device_check: NoCheck
   device_guard: False
 
+- func: sym_stride.int(Tensor self, int dim) -> SymInt
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+  tags: core
+  manual_cpp_binding: True
+
 - func: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor
   device_check: NoCheck # TensorIterator
   variants: function, method
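These four entries make sym_size.int, sym_numel, sym_storage_offset, and sym_stride.int ordinary aten operators; manual_cpp_binding: True keeps codegen from emitting the C++ bindings (they already exist by hand), and tags: core marks them as part of the core ATen opset. A hedged usage sketch, assuming the ops are reachable through torch.ops.aten under the names above; on an ordinary tensor the symbolic accessors should agree with the plain int ones:

import torch

t = torch.randn(2, 3)

# Assumption: the new schemas are exposed through torch.ops.aten as the entries
# above suggest. On a plain (non-symbolic) tensor they should match the int accessors.
assert torch.ops.aten.sym_numel(t) == t.numel() == 6
assert torch.ops.aten.sym_size.int(t, 1) == t.size(1) == 3
assert torch.ops.aten.sym_stride.int(t, 0) == t.stride(0) == 3
assert torch.ops.aten.sym_storage_offset(t) == t.storage_offset() == 0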

test/functorch/test_vmap_registrations.py

Lines changed: 4 additions & 0 deletions
@@ -286,6 +286,10 @@
     "aten::subtract_.Scalar",
     "aten::subtract_.Tensor",
     "aten::svd.U",
+    "aten::sym_size.int",
+    "aten::sym_stride.int",
+    "aten::sym_numel",
+    "aten::sym_storage_offset",
     "aten::tensor_split.indices",
     "aten::tensor_split.sections",
     "aten::tensor_split.tensor_indices_or_sections",

tools/autograd/gen_python_functions.py

Lines changed: 4 additions & 0 deletions
@@ -88,6 +88,10 @@
     "is_sparse_csr",
     "size",
     "stride",
+    "sym_size",
+    "sym_stride",
+    "sym_storage_offset",
+    "sym_numel",
     ".*_backward",
     ".*_backward_(out|input|weight|bias)",
     ".*_forward",

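Adding the four names here keeps torchgen from generating torch.* Python bindings for them, in line with the manual_cpp_binding: True entries in native_functions.yaml. A rough sketch of how a skip list of this shape is typically consulted (an assumption about the mechanics; the real check, should_generate_py_binding in this file, treats entries as regular expressions and may differ in detail):

import re

# Hypothetical mini version of the skip check, for illustration only.
SKIP_PYTHON_BINDINGS = [
    "size",
    "stride",
    "sym_size",
    "sym_stride",
    "sym_storage_offset",
    "sym_numel",
    r".*_backward",
]

def skipped(base_name: str) -> bool:
    return any(re.fullmatch(pattern, base_name) for pattern in SKIP_PYTHON_BINDINGS)

assert skipped("sym_numel")             # newly skipped by this change
assert skipped("convolution_backward")  # caught by the .*_backward pattern
assert not skipped("sum")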
torch/csrc/jit/runtime/register_prim_ops.cpp

Lines changed: 0 additions & 19 deletions
@@ -415,32 +415,13 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
         TORCH_SELECTIVE_SCHEMA("aten::sym_size(Tensor self) -> SymInt[]"),
         sym_size,
         aliasAnalysisFromSchema()),
-    OperatorGeneratorArgs(
-        TORCH_SELECTIVE_SCHEMA(
-            "aten::sym_size.int(Tensor self, int dim) -> SymInt"),
-        sym_size_int,
-        aliasAnalysisFromSchema()),
-    OperatorGeneratorArgs(
-        TORCH_SELECTIVE_SCHEMA(
-            "aten::sym_stride.int(Tensor self, int dim) -> SymInt"),
-        sym_stride_int,
-        aliasAnalysisFromSchema()),
     OperatorGeneratorArgs(
         TORCH_SELECTIVE_SCHEMA("aten::stride(Tensor self) -> int[]"),
         [](Stack& stack) {
           at::Tensor arg = pop(stack).toTensor();
           push(stack, arg.strides());
         },
         aliasAnalysisFromSchema()),
-    OperatorGeneratorArgs(
-        TORCH_SELECTIVE_SCHEMA("aten::sym_numel(Tensor self) -> SymInt"),
-        sym_numel,
-        aliasAnalysisFromSchema()),
-    OperatorGeneratorArgs(
-        TORCH_SELECTIVE_SCHEMA(
-            "aten::sym_storage_offset(Tensor self) -> SymInt"),
-        sym_storage_offset,
-        aliasAnalysisFromSchema()),
     OperatorGeneratorArgs(
         TORCH_SELECTIVE_SCHEMA("aten::sym_stride(Tensor self) -> SymInt[]"),
         sym_stride,
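The hand-written JIT registrations for sym_size.int, sym_stride.int, sym_numel, and sym_storage_offset are dropped because the same schemas are now registered via native_functions.yaml; only the SymInt[]-returning aten::sym_size and aten::sym_stride list variants remain here. A hedged sanity check that the schemas stay discoverable after the move (this leans on a private binding, torch._C._jit_get_schemas_for_operator, which is an assumption and may change between releases):

import torch

for name in ("aten::sym_size", "aten::sym_stride", "aten::sym_numel", "aten::sym_storage_offset"):
    # Expect the .int / scalar SymInt overloads to show up alongside the list variants.
    schemas = torch._C._jit_get_schemas_for_operator(name)
    print(name, [str(s) for s in schemas])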

torchgen/api/cpp.py

Lines changed: 4 additions & 2 deletions
@@ -226,7 +226,9 @@ def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> Named
 # and a function with a return type of 'std::tuple' has >1 return name.
 def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
     # placeholder is ignored
-    r = valuetype_type(t, binds="__placeholder__", symint=symint)
+    # NB: symint is ALWAYS respected for return types. So symint argument
+    # here is IGNORED
+    r = valuetype_type(t, binds="__placeholder__", symint=True)
     if r is not None:
         return r.type
 
@@ -249,7 +251,7 @@ def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
     assert (
         not mutable
     ), "Native functions should never return a mutable tensor list. They should return void."
-    elem = returntype_type(t.elem, mutable=False, symint=symint)
+    elem = returntype_type(t.elem, mutable=False)
     assert t.size is None, f"fixed size list returns not supported: {t}"
     return VectorCType(elem)
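This is the codegen-side counterpart of the function_schema.cpp change: the symint flag now only affects argument types, and return types are always rendered with their real (SymInt) C++ types. A toy sketch of the resulting mapping (illustrative only, not the actual valuetype_type/returntype_type logic):

# Illustrative only: the effective return-type mapping once `symint` is ignored.
CPP_RETURN_TYPES = {"SymInt": "c10::SymInt", "int": "int64_t", "Tensor": "at::Tensor"}

def cpp_return_type(schema_type: str, symint: bool = False) -> str:
    # `symint` is kept for signature compatibility but deliberately unused for returns.
    return CPP_RETURN_TYPES[schema_type]

assert cpp_return_type("SymInt", symint=False) == "c10::SymInt"  # e.g. aten::sym_numel
assert cpp_return_type("int") == "int64_t"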

torchgen/api/types/signatures.py

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,8 @@ class CppSignature:
     # Is this a symint C++ signature. For BC reasons, functions that take
     # SymInts still present as int64_t in C++, and the SymInt variant is
     # offered at a different overload name
+    #
+    # NB: If a function RETURNS a SymInt, this is ALWAYS false
     symint: bool
 
     # The set of C++ arguments which should not have defaults applied to them

torchgen/model.py

Lines changed: 1 addition & 3 deletions
@@ -1628,9 +1628,7 @@ def modifies_arguments(self) -> bool:
         return self.kind() in [SchemaKind.inplace, SchemaKind.out, SchemaKind.mutable]
 
     def has_symint(self) -> bool:
-        return self.arguments.has_symint_arg() or any(
-            r.type.is_symint_like() for r in self.returns
-        )
+        return self.arguments.has_symint_arg()
 
     def __str__(self) -> str:
         all_arguments_str = str(self.arguments)
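With returns excluded, has_symint() is true only when a SymInt appears among the arguments, so ops whose only SymInt is the return value (such as the four added in this PR) do not get a separate *_symint C++ signature; given the cpp.py change above, their single signature can already return c10::SymInt. A toy illustration of the narrowed predicate (not the actual torchgen code):

# Illustrative sketch; the real logic is has_symint in torchgen/model.py.
def has_symint(argument_types, return_types):
    # return types are no longer consulted
    return any("SymInt" in t for t in argument_types)

assert has_symint(["Tensor", "int"], ["SymInt"]) is False      # sym_size.int-style schema
assert has_symint(["Tensor", "SymInt[]"], ["Tensor"]) is True  # view-style schema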
