Skip to content

Commit 7cd7caa

Browse files
committed
Strengthen preconditions of linalg.cross
This makes `linalg.cross` array API compliant (data-apis/array-api#415) and fixes a few bugs. Fixes #77629 Fixes #83756 [ghstack-poisoned]
1 parent f0ee21f commit 7cd7caa

File tree

5 files changed

+32
-39
lines changed

5 files changed

+32
-39
lines changed

aten/src/ATen/TensorMeta.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ namespace impl {
7171
struct TORCH_API MetaBase {
7272
virtual const Tensor& maybe_get_output(int64_t output_idx) = 0;
7373

74+
// Note: [set_output_*]
7475
// See: https://github.com/pytorch/pytorch/issues/69813
7576
// Whenever defining the output properties in the META function of a
7677
// structured kernel (what was usually done with `set_output`), use one of

aten/src/ATen/native/Cross.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,22 @@
99
namespace at {
1010
namespace meta {
1111

12-
TORCH_PRECOMPUTE_META_FUNC(linalg_cross)
13-
(const Tensor & input, const Tensor & other, const int64_t dimension) {
14-
auto out_size = infer_size(input.sizes(), other.sizes());
15-
Tensor input_broadcasted = input.expand(out_size);
16-
Tensor other_broadcasted = other.expand(out_size);
12+
TORCH_META_FUNC(linalg_cross)
13+
(const Tensor & input, const Tensor & other, int64_t dim) {
14+
auto x_d = input.dim();
15+
auto y_d = other.dim();
16+
// This is to avoid things like:
17+
// linalg.cross(torch.randn(2, 3), torch.randn(5, 2, 3), dim=2)
18+
// This would be very odd. This will still be possible by doing
19+
// linalg.cross(torch.randn(2, 3).unsqueeze(0), torch.randn(5, 2, 3), dim=2)
20+
TORCH_CHECK(x_d == y_d, "linalg.cross: inputs must have the same number of dimensions. Got ");
21+
TORCH_CHECK(input.size(dim) == 3 && other.size(dim) == 3, "linalg.cross: inputs dimension ", dim, " must have length 3. Got ", input.size(dim), " and ", other.size(dim));
1722

18-
int64_t dim = maybe_wrap_dim(dimension, input.dim()); // default dim = -1
19-
TORCH_CHECK(input_broadcasted.size(dim) == 3, "dimension ", dimension, " does not have size 3");
23+
// Broadcast the batch dimension of input and other.
24+
// Since the non-batch dimensions agree, this is the same as broadcast all the inputs
25+
auto out_size = infer_size(input.sizes(), other.sizes());
2026

2127
set_output_raw_strided(0, out_size, {}, input.options());
22-
return TORCH_PRECOMPUTE_STRUCT(linalg_cross)().set_dim(dim);
2328
}
2429

2530
}
@@ -56,8 +61,9 @@ Tensor & cross_out(const Tensor & input, const Tensor & other, const c10::option
5661

5762

5863
TORCH_IMPL_FUNC(linalg_cross_out)
59-
(const Tensor & input, const Tensor & other, const int64_t dim, const Tensor & out) {
60-
auto out_size = infer_size(input.sizes(), other.sizes());
64+
(const Tensor & input, const Tensor & other, int64_t dim, const Tensor & out) {
65+
dim = maybe_wrap_dim(dim, input.dim());
66+
auto out_size = out.sizes();
6167
Tensor input_broadcasted = input.expand(out_size);
6268
Tensor other_broadcasted = other.expand(out_size);
6369

aten/src/ATen/native/native_functions.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12177,8 +12177,6 @@
1217712177
- func: linalg_cross.out(Tensor self, Tensor other, *, int dim=-1, Tensor(a!) out) -> Tensor(a!)
1217812178
python_module: linalg
1217912179
structured: True
12180-
precomputed:
12181-
- dim -> int dim
1218212180
dispatch:
1218312181
CPU, CUDA: linalg_cross_out
1218412182

torch/testing/_internal/common_methods_invocations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@
109109
sample_inputs_linalg_cholesky,
110110
sample_inputs_linalg_cholesky_inverse,
111111
sample_inputs_cross,
112+
error_inputs_cross,
112113
sample_inputs_linalg_qr_geqrf,
113114
sample_inputs_linalg_invertible,
114115
sample_inputs_lu_solve,
@@ -4925,7 +4926,6 @@ def _clamp_numpy(a, min=None, max=None):
49254926

49264927
return np.minimum(max, np.maximum(a, min))
49274928

4928-
49294929
def sample_inputs_cumprod(op_info, device, dtype, requires_grad, **kwargs):
49304930
def make_arg(shape):
49314931
# shrink values to be in the interval [-1, +1] for better precision in gradgradcheck

torch/testing/_internal/opinfo/definitions/linalg.py

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -118,33 +118,21 @@ def fn_UVh(usv):
118118

119119

120120
def sample_inputs_cross(op_info, device, dtype, requires_grad, **kwargs):
121-
yield SampleInput(
122-
make_tensor((S, 3), device=device, dtype=dtype, requires_grad=requires_grad),
123-
args=(
124-
make_tensor(
125-
(S, 3), device=device, dtype=dtype, requires_grad=requires_grad
126-
),
127-
),
128-
)
129-
yield SampleInput(
130-
make_tensor((S, 3, S), device=device, dtype=dtype, requires_grad=requires_grad),
131-
args=(
132-
make_tensor(
133-
(S, 3, S), device=device, dtype=dtype, requires_grad=requires_grad
134-
),
135-
),
136-
kwargs={"dim": 1},
137-
)
138-
yield SampleInput(
139-
make_tensor((S, 3), device=device, dtype=dtype, requires_grad=requires_grad),
140-
args=(
141-
make_tensor(
142-
(S, 3), device=device, dtype=dtype, requires_grad=requires_grad
143-
),
144-
),
145-
kwargs={"dim": -1},
146-
)
121+
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
122+
yield SampleInput(make_arg((S, 3)), args=(make_arg((S, 3)),))
123+
yield SampleInput(make_arg((S, 3, S)), args=(make_arg((S, 3, S)),), kwargs=dict(dim=1))
124+
yield SampleInput(make_arg((1, 3)), args=(make_arg((S, 3)),), kwargs=dict(dim=-1))
125+
126+
def error_inputs_cross(op_info, device, dtype, requires_grad, **kwargs):
127+
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
128+
129+
sample = SampleInput(input=make_arg((S, 3)), args=(make_arg((S, 1)),))
130+
err = "inputs dimension -1 must have length 3"
131+
yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
147132

133+
sample = SampleInput(input=make_arg((5, S, 3)), args=(make_arg((S, 3)),))
134+
err = "inputs must have the same number of dimensions"
135+
yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
148136

149137
def sample_inputs_householder_product(op_info, device, dtype, requires_grad, **kwargs):
150138
"""

0 commit comments

Comments
 (0)