@@ -2012,6 +2012,15 @@ def sample_inputs_isclose(
         yield SampleInput(lhs, args=(rhs,),
                           kwargs=dict(op_kwargs, rtol=rtol, atol=atol, equal_nan=equal_nan))
 
+def sample_inputs_linalg_vecdot(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_binary_pwise(op_info, device, dtype, requires_grad)
+
+    # Also add samples with dim != -1
+    for s in sample_inputs_binary_pwise(op_info, device, dtype, requires_grad):
+        if s.input.ndim > 1:
+            s.kwargs["dim"] = 0
+            yield s
+
 def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     return (SampleInput(make_arg((1, 2))),
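
Note (not part of the diff): a minimal sketch of what the extra dim=0 samples exercise, using only the public torch.linalg.vecdot API. The op reduces over the requested dimension, so varying dim covers the non-default reduction path.

    import torch

    x = torch.randn(2, 3)
    y = torch.randn(2, 3)
    torch.linalg.vecdot(x, y)         # default dim=-1 -> shape (2,)
    torch.linalg.vecdot(x, y, dim=0)  # the added dim=0 samples -> shape (3,)
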
@@ -9778,6 +9787,17 @@ def ref_pairwise_distance(input1, input2):
            gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
            decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
            ),
+    BinaryUfuncInfo('linalg.vecdot',
+                    aten_name='linalg_vecdot',
+                    ref=lambda x, y, *, dim=-1: (x.conj() * y).sum(dim),
+                    dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
+                    sample_inputs_func=sample_inputs_linalg_vecdot,
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    skips=(
+                        # torch.sum(out=) behaves incorrectly
+                        DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+                    ),),
     OpInfo('linalg.cond',
            aten_name='linalg_cond',
            dtypes=floating_and_complex_types(),
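
Note (illustrative only, using public torch APIs): linalg.vecdot conjugates its first argument, which is why the ref lambda above uses x.conj() — without it the reference would disagree with the op for the complex dtypes in the dtypes list.

    import torch

    x = torch.randn(3, 4, dtype=torch.complex64)
    y = torch.randn(3, 4, dtype=torch.complex64)
    out = torch.linalg.vecdot(x, y)        # reduces dim=-1 -> shape (3,)
    expected = (x.conj() * y).sum(dim=-1)  # matches the OpInfo ref
    assert torch.allclose(out, expected)
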