
Commit 9676061

Natalia Gimelshein authored and pytorchmergebot committed
port torch cov tests to error inputs (#73977)
Summary: Per title

Pull Request resolved: #73977
Reviewed By: malfet
Differential Revision: D34779552
Pulled By: ngimel
fbshipit-source-id: b4191101a029981eb27c75e1b56d739db046f819
(cherry picked from commit 2c2af72)
1 parent 122f864 commit 9676061

File tree

2 files changed: 58 additions & 41 deletions


test/test_torch.py

Lines changed: 0 additions & 17 deletions
@@ -1746,23 +1746,6 @@ def check(t, correction=1, fweights=None, aweights=None):
         for correction, fw, aw in product([0, 1, 2], [None, fweights], [None, aweights]):
             check(x, correction, fweights, aweights)

-    # FIXME: port to ErrorInputs
-    def test_cov_error(self, device):
-        def check(msg, *args, **kwargs):
-            with self.assertRaisesRegex(RuntimeError, r'cov\(\):.*' + msg + r'.*'):
-                torch.cov(*args, **kwargs)
-
-        a = torch.rand(2)
-        check(r'expected input to have two or fewer dimensions', torch.rand(2, 2, 2))
-        check(r'expected fweights to have one or fewer dimensions', a, fweights=torch.rand(2, 2))
-        check(r'expected aweights to have one or fewer dimensions', a, aweights=torch.rand(2, 2))
-        check(r'expected fweights to have integral dtype', a, fweights=torch.rand(2))
-        check(r'expected aweights to have floating point dtype', a, aweights=torch.tensor([1, 1]))
-        check(r'expected fweights to have the same numel', a, fweights=torch.tensor([1]))
-        check(r'expected aweights to have the same numel', a, aweights=torch.rand(1))
-        check(r'fweights cannot be negative', a, fweights=torch.tensor([-1, -2]))
-        check(r'aweights cannot be negative', a, aweights=torch.tensor([-1., -2.]))
-
     @skipIfNoSciPy
     @dtypes(*get_all_fp_dtypes())
     def test_uniform_kstest(self, device, dtype):
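
One detail of the port: the deleted test anchored its regex on the `cov():` prefix that torch.cov puts in front of every message, while the new ErrorInput entries below match only the message body. Both styles accept the same error text. A minimal check of that equivalence, with the message string written out here purely for illustration:

import re

# Illustrative message text; the "cov(): " prefix is implied by the old pattern r'cov\(\):.*'.
msg = "cov(): expected fweights to have integral dtype"
assert re.search(r'cov\(\):.*' + 'expected fweights to have integral dtype' + r'.*', msg)  # old style
assert re.search("expected fweights to have integral dtype", msg)                          # new style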

torch/testing/_internal/common_methods_invocations.py

Lines changed: 58 additions & 24 deletions
@@ -208,7 +208,7 @@ class ErrorInput(object):

     __slots__ = ['sample_input', 'error_type', 'error_regex']

-    def __init__(self, sample_input, *, error_type, error_regex):
+    def __init__(self, sample_input, *, error_type=RuntimeError, error_regex):
         self.sample_input = sample_input
         self.error_type = error_type
         self.error_regex = error_regex
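
With error_type now defaulting to RuntimeError, call sites only need the regex, which is what lets the hunks that follow drop the explicit error_type=RuntimeError argument. A quick check of the equivalence (assumes a PyTorch build that includes this change; the import path is the internal module being edited here):

import torch
from torch.testing._internal.common_methods_invocations import ErrorInput, SampleInput

si = SampleInput(torch.rand(5))
# The two spellings construct the same thing under the new default:
a = ErrorInput(si, error_type=RuntimeError, error_regex="expected fweights to have integral dtype")
b = ErrorInput(si, error_regex="expected fweights to have integral dtype")
assert a.error_type is b.error_type is RuntimeError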
@@ -1474,8 +1474,8 @@ def error_inputs_hsplit(op_info, device, **kwargs):
                                   dtype=torch.float32,
                                   device=device),
                       args=(0,),)
-    return (ErrorInput(si1, error_type=RuntimeError, error_regex=err_msg1),
-            ErrorInput(si2, error_type=RuntimeError, error_regex=err_msg2),)
+    return (ErrorInput(si1, error_regex=err_msg1),
+            ErrorInput(si2, error_regex=err_msg2),)

 def error_inputs_vsplit(op_info, device, **kwargs):
     err_msg1 = ("torch.vsplit requires a tensor with at least 2 dimension, "
@@ -1491,8 +1491,8 @@ def error_inputs_vsplit(op_info, device, **kwargs):
                                   dtype=torch.float32,
                                   device=device),
                       args=(0,),)
-    return (ErrorInput(si1, error_type=RuntimeError, error_regex=err_msg1),
-            ErrorInput(si2, error_type=RuntimeError, error_regex=err_msg2),)
+    return (ErrorInput(si1, error_regex=err_msg1),
+            ErrorInput(si2, error_regex=err_msg2),)

 def error_inputs_dsplit(op_info, device, **kwargs):
     err_msg1 = ("torch.dsplit requires a tensor with at least 3 dimension, "
@@ -1508,8 +1508,8 @@ def error_inputs_dsplit(op_info, device, **kwargs):
                                   dtype=torch.float32,
                                   device=device),
                       args=(0,),)
-    return (ErrorInput(si1, error_type=RuntimeError, error_regex=err_msg1),
-            ErrorInput(si2, error_type=RuntimeError, error_regex=err_msg2),)
+    return (ErrorInput(si1, error_regex=err_msg1),
+            ErrorInput(si2, error_regex=err_msg2),)

 def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad, **kwargs):
     # Each test case consists of the sizes in the chain of multiplications
@@ -3060,12 +3060,12 @@ def error_inputs_gather(op_info, device, **kwargs):

     # Index should be smaller than self except on dimesion 1
     bad_src = make_tensor((1, 1), device=device, dtype=torch.float32)
-    yield ErrorInput(SampleInput(bad_src, args=(1, idx,)), error_type=RuntimeError,
+    yield ErrorInput(SampleInput(bad_src, args=(1, idx,)),
                      error_regex="Size does not match at dimension 0")

     # Index must have long dtype
     bad_idx = idx.to(torch.int32)
-    yield ErrorInput(SampleInput(src, args=(1, bad_idx)), error_type=RuntimeError,
+    yield ErrorInput(SampleInput(src, args=(1, bad_idx)),
                      error_regex="Expected dtype int64 for index")

     # TODO: FIXME
@@ -3074,28 +3074,28 @@ def error_inputs_gather(op_info, device, **kwargs):
     src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
     idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
     out = torch.empty((2, 2), device=device, dtype=torch.float64)
-    yield ErrorInput(SampleInput(src, args=(1, idx), kwargs={'out': out}), error_type=RuntimeError,
+    yield ErrorInput(SampleInput(src, args=(1, idx), kwargs={'out': out}),
                      error_regex="Expected out tensor to have dtype")

     # src and index tensors must have the same # of dimensions
     # idx too few dimensions
     src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
     idx = torch.tensor((0, 0), device=device, dtype=torch.long)
-    yield ErrorInput(SampleInput(src, args=(1, idx)), error_type=RuntimeError,
+    yield ErrorInput(SampleInput(src, args=(1, idx)),
                      error_regex="Index tensor must have the same number of dimensions")

     # src too few dimensions
     src = torch.tensor((1, 2), device=device, dtype=torch.float32)
     idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
-    yield ErrorInput(SampleInput(src, args=(0, idx)), error_type=RuntimeError,
+    yield ErrorInput(SampleInput(src, args=(0, idx)),
                      error_regex="Index tensor must have the same number of dimensions")

     # index out of bounds
     # NOTE: this ErrorInput is guarded because bounds checking does not occur on CUDA devices
     if torch.device(device).type == 'cpu':
         src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
         idx = torch.tensor(((0, 23), (1, 0)), device=device, dtype=torch.long)
-        yield ErrorInput(SampleInput(src, args=(1, idx,)), error_type=RuntimeError,
+        yield ErrorInput(SampleInput(src, args=(1, idx,)),
                          error_regex="index 23 is out of bounds for dimension")

 # Error inputs for scatter
@@ -3104,28 +3104,28 @@ def error_inputs_scatter_and_scatter_add(op_info, device, **kwargs):
     src = make_tensor((2, 5), device=device, dtype=torch.float32)
     idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long)
     dst = torch.zeros((3, 5), device=device, dtype=torch.double)
-    yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_type=RuntimeError,
+    yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
                      error_regex="Expected self.dtype to be equal to src.dtype")

     # Index dtype must be long
     src = make_tensor((2, 5), device=device, dtype=torch.float32)
     idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.int32)
     dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
-    yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_type=RuntimeError,
+    yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
                      error_regex="Expected dtype int64 for index")

     # Index and destination must have the same number of dimensions
     src = make_tensor((2, 5), device=device, dtype=torch.float32)
     idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long)
     dst = torch.zeros((3, 5, 3), device=device, dtype=torch.float32)
-    yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_type=RuntimeError,
+    yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
                      error_regex="Index tensor must have the same number of dimensions as self tensor")

     # Index and src must have the same number of dimensions when src is not a scalar
     src = make_tensor((2, 5, 2), device=device, dtype=torch.float32)
     idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long)
     dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
-    yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_type=RuntimeError,
+    yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
                      error_regex="Index tensor must have the same number of dimensions as src tensor")

     # Index out of bounds
@@ -3134,7 +3134,7 @@ def error_inputs_scatter_and_scatter_add(op_info, device, **kwargs):
         src = make_tensor((2, 5), device=device, dtype=torch.float32)
         idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long)
         dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
-        yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_type=RuntimeError,
+        yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
                          error_regex="index 34 is out of bounds for dimension 0 with size 3")

 def sample_inputs_take_along_dim(op_info, device, dtype, requires_grad, **kwargs):
@@ -5532,6 +5532,39 @@ def sample_inputs_cov(op_info, device, dtype, requires_grad, **kwargs):
     return inputs


+def error_inputs_cov(op_info, device, **kwargs):
+    a = torch.rand(S, device=device)
+    error_inputs = []
+    error_inputs.append(ErrorInput(
+        SampleInput(torch.rand(S, S, S, device=device)),
+        error_regex="expected input to have two or fewer dimensions"))
+    error_inputs.append(ErrorInput(
+        SampleInput(a, kwargs={'fweights': torch.rand(S, S, device=device)}),
+        error_regex="expected fweights to have one or fewer dimensions"))
+    error_inputs.append(ErrorInput(
+        SampleInput(a, kwargs={'aweights': torch.rand(S, S, device=device)}),
+        error_regex="expected aweights to have one or fewer dimensions"))
+    error_inputs.append(ErrorInput(
+        SampleInput(a, kwargs={'fweights': torch.rand(S, device=device)}),
+        error_regex="expected fweights to have integral dtype"))
+    error_inputs.append(ErrorInput(
+        SampleInput(a, kwargs={'aweights': torch.tensor([1, 1], device=device)}),
+        error_regex="expected aweights to have floating point dtype"))
+    error_inputs.append(ErrorInput(
+        SampleInput(a, kwargs={'fweights': torch.tensor([1], device=device)}),
+        error_regex="expected fweights to have the same numel"))
+    error_inputs.append(ErrorInput(
+        SampleInput(a, kwargs={'aweights': torch.rand(1, device=device)}),
+        error_regex="expected aweights to have the same numel"))
+    error_inputs.append(ErrorInput(
+        SampleInput(a, kwargs={'fweights': torch.tensor([-1, -2, -3, -4 , -5], device=device)}),
+        error_regex="fweights cannot be negative"))
+    error_inputs.append(ErrorInput(
+        SampleInput(a, kwargs={'aweights': torch.tensor([-1., -2., -3., -4., -5.], device=device)}),
+        error_regex="aweights cannot be negative"))
+    return error_inputs
+
+
 def sample_inputs_svd(op_info, device, dtype, requires_grad=False, **kwargs):
     make_fullrank = make_fullrank_matrices_with_distinct_singular_values
     make_arg = partial(make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad)
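
Each entry in error_inputs_cov above corresponds to a torch.cov call that must raise a RuntimeError whose message matches the given regex. For example, the fweights dtype case can be reproduced directly with plain torch and no test harness (a 5-element vector stands in for the S-sized input, S being the module's small test size):

import torch

a = torch.rand(5)
try:
    torch.cov(a, fweights=torch.rand(5))  # fweights must have an integral dtype
except RuntimeError as e:
    assert "expected fweights to have integral dtype" in str(e)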
@@ -5967,7 +6000,7 @@ def error_inputs_neg(op_info, device, **kwargs):
     msg = ("Negation, the `\\-` operator, on a bool tensor is not supported."
            " If you are trying to invert a mask, use the `\\~` or"
            " `logical_not\\(\\)` operator instead.")
-    return (ErrorInput(si, error_type=RuntimeError, error_regex=msg),)
+    return (ErrorInput(si, error_regex=msg),)

 def sample_inputs_nextafter(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
@@ -7104,7 +7137,7 @@ def error_inputs_where(op_info, device, **kwargs):
         si = SampleInput(make_tensor(shape, device=devices[0], dtype=torch.float32),
                          args=(make_tensor(shape, dtype=torch.bool, device=devices[1]),
                                make_tensor(shape, device=devices[2], dtype=torch.float32)))
-        yield ErrorInput(si, error_type=RuntimeError, error_regex=err_msg)
+        yield ErrorInput(si, error_regex=err_msg)

 def sample_inputs_nonzero(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
@@ -7164,13 +7197,13 @@ def error_inputs_kthvalue(op_info, device, **kwargs):
     si = SampleInput(t, args=(5,), kwargs={'out': (t, indices)})

     k_out_of_range_err = "selected number k out of range for dimension"
-    return (ErrorInput(si, error_type=RuntimeError, error_regex="unsupported operation"),
+    return (ErrorInput(si, error_regex="unsupported operation"),
             ErrorInput(SampleInput(torch.randn(2, 2, device=device), args=(3, 0)),
-                       error_type=RuntimeError, error_regex=k_out_of_range_err),
+                       error_regex=k_out_of_range_err),
             ErrorInput(SampleInput(torch.randn(2, 2, device=device), args=(3,)),
-                       error_type=RuntimeError, error_regex=k_out_of_range_err),
+                       error_regex=k_out_of_range_err),
             ErrorInput(SampleInput(torch.tensor(2, device=device), args=(3,)),
-                       error_type=RuntimeError, error_regex=k_out_of_range_err),)
+                       error_regex=k_out_of_range_err),)

 def sample_inputs_dropout(op_info, device, dtype, requires_grad, *,
                           train=None, valid_input_dim=None, **kwargs):
@@ -9087,6 +9120,7 @@ def ref_pairwise_distance(input1, input2):
            backward_dtypesIfCUDA=all_types_and_complex_and(torch.half, *[torch.bfloat16]
                                                            if (CUDA11OrLater or TEST_WITH_ROCM) else []),
            sample_inputs_func=sample_inputs_cov,
+           error_inputs_func=error_inputs_cov,
            supports_out=False,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
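
Registering error_inputs_func on the 'cov' OpInfo is what makes the new cases run: the generic OpInfo error test iterates the op's error inputs on each device and asserts every declared call fails with the declared message, so cov no longer needs its own hand-written error test. A rough sketch of the consuming side (the real harness lives in test_ops.py; check_error_inputs is a hypothetical name used here only to show the shape of the loop):

import re

def check_error_inputs(op, device):
    # op is an OpInfo; error_inputs() runs the registered error_inputs_func, e.g. error_inputs_cov
    for ei in op.error_inputs(device):
        si = ei.sample_input
        try:
            op(si.input, *si.args, **si.kwargs)  # invoke the op on the declared bad input
        except ei.error_type as e:
            assert re.search(ei.error_regex, str(e)), str(e)
        else:
            raise AssertionError(f"expected {ei.error_type.__name__} from {op.name}")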
