@@ -208,7 +208,7 @@ class ErrorInput(object):
208
208
209
209
__slots__ = ['sample_input', 'error_type', 'error_regex']
210
210
211
- def __init__(self, sample_input, *, error_type, error_regex):
211
+ def __init__(self, sample_input, *, error_type=RuntimeError , error_regex):
212
212
self.sample_input = sample_input
213
213
self.error_type = error_type
214
214
self.error_regex = error_regex
@@ -1474,8 +1474,8 @@ def error_inputs_hsplit(op_info, device, **kwargs):
1474
1474
dtype=torch.float32,
1475
1475
device=device),
1476
1476
args=(0,),)
1477
- return (ErrorInput(si1, error_type=RuntimeError, error_regex=err_msg1),
1478
- ErrorInput(si2, error_type=RuntimeError, error_regex=err_msg2),)
1477
+ return (ErrorInput(si1, error_regex=err_msg1),
1478
+ ErrorInput(si2, error_regex=err_msg2),)
1479
1479
1480
1480
def error_inputs_vsplit(op_info, device, **kwargs):
1481
1481
err_msg1 = ("torch.vsplit requires a tensor with at least 2 dimension, "
@@ -1491,8 +1491,8 @@ def error_inputs_vsplit(op_info, device, **kwargs):
1491
1491
dtype=torch.float32,
1492
1492
device=device),
1493
1493
args=(0,),)
1494
- return (ErrorInput(si1, error_type=RuntimeError, error_regex=err_msg1),
1495
- ErrorInput(si2, error_type=RuntimeError, error_regex=err_msg2),)
1494
+ return (ErrorInput(si1, error_regex=err_msg1),
1495
+ ErrorInput(si2, error_regex=err_msg2),)
1496
1496
1497
1497
def error_inputs_dsplit(op_info, device, **kwargs):
1498
1498
err_msg1 = ("torch.dsplit requires a tensor with at least 3 dimension, "
@@ -1508,8 +1508,8 @@ def error_inputs_dsplit(op_info, device, **kwargs):
1508
1508
dtype=torch.float32,
1509
1509
device=device),
1510
1510
args=(0,),)
1511
- return (ErrorInput(si1, error_type=RuntimeError, error_regex=err_msg1),
1512
- ErrorInput(si2, error_type=RuntimeError, error_regex=err_msg2),)
1511
+ return (ErrorInput(si1, error_regex=err_msg1),
1512
+ ErrorInput(si2, error_regex=err_msg2),)
1513
1513
1514
1514
def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad, **kwargs):
1515
1515
# Each test case consists of the sizes in the chain of multiplications
@@ -3060,12 +3060,12 @@ def error_inputs_gather(op_info, device, **kwargs):
3060
3060
3061
3061
# Index should be smaller than self except on dimesion 1
3062
3062
bad_src = make_tensor((1, 1), device=device, dtype=torch.float32)
3063
- yield ErrorInput(SampleInput(bad_src, args=(1, idx,)), error_type=RuntimeError,
3063
+ yield ErrorInput(SampleInput(bad_src, args=(1, idx,)),
3064
3064
error_regex="Size does not match at dimension 0")
3065
3065
3066
3066
# Index must have long dtype
3067
3067
bad_idx = idx.to(torch.int32)
3068
- yield ErrorInput(SampleInput(src, args=(1, bad_idx)), error_type=RuntimeError,
3068
+ yield ErrorInput(SampleInput(src, args=(1, bad_idx)),
3069
3069
error_regex="Expected dtype int64 for index")
3070
3070
3071
3071
# TODO: FIXME
@@ -3074,28 +3074,28 @@ def error_inputs_gather(op_info, device, **kwargs):
3074
3074
src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
3075
3075
idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
3076
3076
out = torch.empty((2, 2), device=device, dtype=torch.float64)
3077
- yield ErrorInput(SampleInput(src, args=(1, idx), kwargs={'out': out}), error_type=RuntimeError,
3077
+ yield ErrorInput(SampleInput(src, args=(1, idx), kwargs={'out': out}),
3078
3078
error_regex="Expected out tensor to have dtype")
3079
3079
3080
3080
# src and index tensors must have the same # of dimensions
3081
3081
# idx too few dimensions
3082
3082
src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
3083
3083
idx = torch.tensor((0, 0), device=device, dtype=torch.long)
3084
- yield ErrorInput(SampleInput(src, args=(1, idx)), error_type=RuntimeError,
3084
+ yield ErrorInput(SampleInput(src, args=(1, idx)),
3085
3085
error_regex="Index tensor must have the same number of dimensions")
3086
3086
3087
3087
# src too few dimensions
3088
3088
src = torch.tensor((1, 2), device=device, dtype=torch.float32)
3089
3089
idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
3090
- yield ErrorInput(SampleInput(src, args=(0, idx)), error_type=RuntimeError,
3090
+ yield ErrorInput(SampleInput(src, args=(0, idx)),
3091
3091
error_regex="Index tensor must have the same number of dimensions")
3092
3092
3093
3093
# index out of bounds
3094
3094
# NOTE: this ErrorInput is guarded because bounds checking does not occur on CUDA devices
3095
3095
if torch.device(device).type == 'cpu':
3096
3096
src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
3097
3097
idx = torch.tensor(((0, 23), (1, 0)), device=device, dtype=torch.long)
3098
- yield ErrorInput(SampleInput(src, args=(1, idx,)), error_type=RuntimeError,
3098
+ yield ErrorInput(SampleInput(src, args=(1, idx,)),
3099
3099
error_regex="index 23 is out of bounds for dimension")
3100
3100
3101
3101
# Error inputs for scatter
@@ -3104,28 +3104,28 @@ def error_inputs_scatter_and_scatter_add(op_info, device, **kwargs):
3104
3104
src = make_tensor((2, 5), device=device, dtype=torch.float32)
3105
3105
idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long)
3106
3106
dst = torch.zeros((3, 5), device=device, dtype=torch.double)
3107
- yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_type=RuntimeError,
3107
+ yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
3108
3108
error_regex="Expected self.dtype to be equal to src.dtype")
3109
3109
3110
3110
# Index dtype must be long
3111
3111
src = make_tensor((2, 5), device=device, dtype=torch.float32)
3112
3112
idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.int32)
3113
3113
dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
3114
- yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_type=RuntimeError,
3114
+ yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
3115
3115
error_regex="Expected dtype int64 for index")
3116
3116
3117
3117
# Index and destination must have the same number of dimensions
3118
3118
src = make_tensor((2, 5), device=device, dtype=torch.float32)
3119
3119
idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long)
3120
3120
dst = torch.zeros((3, 5, 3), device=device, dtype=torch.float32)
3121
- yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_type=RuntimeError,
3121
+ yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
3122
3122
error_regex="Index tensor must have the same number of dimensions as self tensor")
3123
3123
3124
3124
# Index and src must have the same number of dimensions when src is not a scalar
3125
3125
src = make_tensor((2, 5, 2), device=device, dtype=torch.float32)
3126
3126
idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long)
3127
3127
dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
3128
- yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_type=RuntimeError,
3128
+ yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
3129
3129
error_regex="Index tensor must have the same number of dimensions as src tensor")
3130
3130
3131
3131
# Index out of bounds
@@ -3134,7 +3134,7 @@ def error_inputs_scatter_and_scatter_add(op_info, device, **kwargs):
3134
3134
src = make_tensor((2, 5), device=device, dtype=torch.float32)
3135
3135
idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long)
3136
3136
dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
3137
- yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_type=RuntimeError,
3137
+ yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
3138
3138
error_regex="index 34 is out of bounds for dimension 0 with size 3")
3139
3139
3140
3140
def sample_inputs_take_along_dim(op_info, device, dtype, requires_grad, **kwargs):
@@ -5532,6 +5532,39 @@ def sample_inputs_cov(op_info, device, dtype, requires_grad, **kwargs):
5532
5532
return inputs
5533
5533
5534
5534
5535
+ def error_inputs_cov(op_info, device, **kwargs):
5536
+ a = torch.rand(S, device=device)
5537
+ error_inputs = []
5538
+ error_inputs.append(ErrorInput(
5539
+ SampleInput(torch.rand(S, S, S, device=device)),
5540
+ error_regex="expected input to have two or fewer dimensions"))
5541
+ error_inputs.append(ErrorInput(
5542
+ SampleInput(a, kwargs={'fweights': torch.rand(S, S, device=device)}),
5543
+ error_regex="expected fweights to have one or fewer dimensions"))
5544
+ error_inputs.append(ErrorInput(
5545
+ SampleInput(a, kwargs={'aweights': torch.rand(S, S, device=device)}),
5546
+ error_regex="expected aweights to have one or fewer dimensions"))
5547
+ error_inputs.append(ErrorInput(
5548
+ SampleInput(a, kwargs={'fweights': torch.rand(S, device=device)}),
5549
+ error_regex="expected fweights to have integral dtype"))
5550
+ error_inputs.append(ErrorInput(
5551
+ SampleInput(a, kwargs={'aweights': torch.tensor([1, 1], device=device)}),
5552
+ error_regex="expected aweights to have floating point dtype"))
5553
+ error_inputs.append(ErrorInput(
5554
+ SampleInput(a, kwargs={'fweights': torch.tensor([1], device=device)}),
5555
+ error_regex="expected fweights to have the same numel"))
5556
+ error_inputs.append(ErrorInput(
5557
+ SampleInput(a, kwargs={'aweights': torch.rand(1, device=device)}),
5558
+ error_regex="expected aweights to have the same numel"))
5559
+ error_inputs.append(ErrorInput(
5560
+ SampleInput(a, kwargs={'fweights': torch.tensor([-1, -2, -3, -4 , -5], device=device)}),
5561
+ error_regex="fweights cannot be negative"))
5562
+ error_inputs.append(ErrorInput(
5563
+ SampleInput(a, kwargs={'aweights': torch.tensor([-1., -2., -3., -4., -5.], device=device)}),
5564
+ error_regex="aweights cannot be negative"))
5565
+ return error_inputs
5566
+
5567
+
5535
5568
def sample_inputs_svd(op_info, device, dtype, requires_grad=False, **kwargs):
5536
5569
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
5537
5570
make_arg = partial(make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad)
@@ -5967,7 +6000,7 @@ def error_inputs_neg(op_info, device, **kwargs):
5967
6000
msg = ("Negation, the `\\-` operator, on a bool tensor is not supported."
5968
6001
" If you are trying to invert a mask, use the `\\~` or"
5969
6002
" `logical_not\\(\\)` operator instead.")
5970
- return (ErrorInput(si, error_type=RuntimeError, error_regex=msg),)
6003
+ return (ErrorInput(si, error_regex=msg),)
5971
6004
5972
6005
def sample_inputs_nextafter(op_info, device, dtype, requires_grad, **kwargs):
5973
6006
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
@@ -7104,7 +7137,7 @@ def error_inputs_where(op_info, device, **kwargs):
7104
7137
si = SampleInput(make_tensor(shape, device=devices[0], dtype=torch.float32),
7105
7138
args=(make_tensor(shape, dtype=torch.bool, device=devices[1]),
7106
7139
make_tensor(shape, device=devices[2], dtype=torch.float32)))
7107
- yield ErrorInput(si, error_type=RuntimeError, error_regex=err_msg)
7140
+ yield ErrorInput(si, error_regex=err_msg)
7108
7141
7109
7142
def sample_inputs_nonzero(op_info, device, dtype, requires_grad, **kwargs):
7110
7143
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
@@ -7164,13 +7197,13 @@ def error_inputs_kthvalue(op_info, device, **kwargs):
7164
7197
si = SampleInput(t, args=(5,), kwargs={'out': (t, indices)})
7165
7198
7166
7199
k_out_of_range_err = "selected number k out of range for dimension"
7167
- return (ErrorInput(si, error_type=RuntimeError, error_regex="unsupported operation"),
7200
+ return (ErrorInput(si, error_regex="unsupported operation"),
7168
7201
ErrorInput(SampleInput(torch.randn(2, 2, device=device), args=(3, 0)),
7169
- error_type=RuntimeError, error_regex=k_out_of_range_err),
7202
+ error_regex=k_out_of_range_err),
7170
7203
ErrorInput(SampleInput(torch.randn(2, 2, device=device), args=(3,)),
7171
- error_type=RuntimeError, error_regex=k_out_of_range_err),
7204
+ error_regex=k_out_of_range_err),
7172
7205
ErrorInput(SampleInput(torch.tensor(2, device=device), args=(3,)),
7173
- error_type=RuntimeError, error_regex=k_out_of_range_err),)
7206
+ error_regex=k_out_of_range_err),)
7174
7207
7175
7208
def sample_inputs_dropout(op_info, device, dtype, requires_grad, *,
7176
7209
train=None, valid_input_dim=None, **kwargs):
@@ -9087,6 +9120,7 @@ def ref_pairwise_distance(input1, input2):
9087
9120
backward_dtypesIfCUDA=all_types_and_complex_and(torch.half, *[torch.bfloat16]
9088
9121
if (CUDA11OrLater or TEST_WITH_ROCM) else []),
9089
9122
sample_inputs_func=sample_inputs_cov,
9123
+ error_inputs_func=error_inputs_cov,
9090
9124
supports_out=False,
9091
9125
supports_forward_ad=True,
9092
9126
supports_fwgrad_bwgrad=True,
0 commit comments