Commit 5cbd57d

[Array API] Add linalg.vecdot
This PR adds the function `linalg.vecdot` specified by the [Array API](https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html#function-vecdot). For the complex case, it chooses to implement \sum x_i y_i, i.e., without conjugating either input. See the discussion in data-apis/array-api#356.

[ghstack-poisoned]
1 parent ce86881 commit 5cbd57d
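As a quick illustration of the complex semantics chosen here, the following is a hypothetical session assuming the no-conjugation behaviour stated in the commit message; `torch.vdot`, by contrast, conjugates its first argument:

>>> import torch
>>> x = torch.tensor([1 + 1j, 2 - 1j])
>>> y = torch.tensor([3 + 0j, 1 + 2j])
>>> torch.linalg.vecdot(x, y)  # sum(x_i * y_i), no conjugation
tensor(7.+6.j)
>>> torch.vdot(x, y)           # sum(conj(x_i) * y_i), for comparison
tensor(3.+2.j)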

File tree

9 files changed: +82 -1 lines changed

aten/src/ATen/core/interned_strings.h

+1
@@ -218,6 +218,7 @@ namespace c10 {
   _(aten, mH) \
   _(aten, linalg_matrix_power) \
   _(aten, chain_matmul) \
+  _(aten, linalg_vecdot) \
   _(aten, linalg_multi_dot) \
   _(aten, linalg_norm) \
   _(aten, linalg_vector_norm) \

aten/src/ATen/native/BatchLinearAlgebra.cpp

+8
@@ -3836,6 +3836,14 @@ TransposeType to_transpose_type(const bool contig, const bool conj) {
 }
 } // end of anonymous namespace
 
+Tensor& linalg_vecdot_out(const Tensor& x, const Tensor& y, int64_t dim, Tensor& out) {
+  return at::sum_out(out, x * y, /*dim=*/dim);
+}
+
+Tensor linalg_vecdot(const Tensor& x, const Tensor& y, int64_t dim) {
+  return (x * y).sum(/*dim=*/dim);
+}
+
 /*
 Solves the matrix equation AX = B for A triangular.
 'left' If true solves AX = B, if false solves XA = B
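The implementation is a composite of existing primitives: elementwise multiply, then reduce over `dim`. A minimal Python sketch of the same computation, as a mental model (not part of the commit):

import torch

def vecdot_ref(x, y, dim=-1):
    # Mirrors the C++ composite above: elementwise multiply, then sum over `dim`.
    # Broadcasting comes for free from `*`, which is why the native function
    # supports broadcasting without any extra code.
    return (x * y).sum(dim=dim)

Note that the out= variant forwards to at::sum_out, so its out= handling inherits torch.sum's behaviour; that is what the OpInfo skip near the end of this commit refers to.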

aten/src/ATen/native/native_functions.yaml

+7
@@ -10844,6 +10844,13 @@
 - func: linalg_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
   python_module: linalg
 
+- func: linalg_vecdot(Tensor x, Tensor y, *, int dim=-1) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_vecdot.out(Tensor x, Tensor y, *, int dim=-1, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+
 - func: linalg_matrix_exp(Tensor self) -> Tensor
   python_module: linalg
   variants: function
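In both schemas `dim` sits after the `*`, so it is keyword-only from Python. A hypothetical session:

>>> import torch
>>> x, y = torch.randn(3, 2), torch.randn(3, 2)
>>> torch.linalg.vecdot(x, y, dim=0).shape  # reduce over dim 0 instead of the default -1
torch.Size([2])
>>> # torch.linalg.vecdot(x, y, 0) would raise a TypeError: `dim` is keyword-only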

docs/source/linalg.rst

+1
@@ -80,6 +80,7 @@ Matrix Products
 
     cross
     matmul
+    vecdot
     multi_dot
     householder_product
 

test/test_linalg.py

+1 -1

@@ -26,7 +26,7 @@
 from torch.testing import make_tensor
 from torch.testing._internal.common_dtype import (
     all_types, floating_types, floating_and_complex_types, get_all_dtypes, get_all_int_dtypes, get_all_complex_dtypes,
-    get_all_fp_dtypes,
+    get_all_fp_dtypes, all_types_and_complex_and
 )
 from torch.testing._internal.common_cuda import SM53OrLater, tf32_on_and_off, CUDA11OrLater, CUDA9
 from torch.distributions.binomial import Binomial

torch/_torch_docs.py

+4
@@ -3412,6 +3412,10 @@ def merge_dicts(*dicts):
 Unlike NumPy's dot, torch.dot intentionally only supports computing the dot product
 of two 1D tensors with the same number of elements.
 
+.. seealso::
+
+        :func:`torch.linalg.vecdot` computes the dot product of two batches of vectors along a dimension.
+
 Args:
     input (Tensor): first tensor in the dot product, must be 1D.
     other (Tensor): second tensor in the dot product, must be 1D.

torch/linalg/__init__.py

+39
@@ -2251,3 +2251,42 @@
 >>> torch.dist(Q.mT @ Q, torch.eye(4))
 tensor(6.2158e-07)
 """)
+
+vecdot = _add_docstr(_linalg.linalg_vecdot, r"""
+linalg.vecdot(x, y, *, dim=-1, out=None) -> Tensor
+
+Computes the dot product of two batches of vectors along a dimension.
+
+In symbols, this function computes
+
+.. math::
+
+    \sum_{i=1}^n x_i y_i
+
+over the dimension :attr:`dim`.
+
+Supports input of half, bfloat16, float, double, cfloat, cdouble and integral dtypes.
+It also supports broadcasting.
+
+.. seealso::
+
+        :func:`torch.matmul` computes a general matrix-matrix multiplication for batches
+        of matrices.
+
+Args:
+    x (Tensor): first batch of vectors.
+    y (Tensor): second batch of vectors.
+
+Keyword args:
+    dim (int): Dimension along which to compute the dot product. Default: `-1`.
+    out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.
+
+Examples::
+
+    >>> v1 = torch.randn(3, 2)
+    >>> v2 = torch.randn(3, 2)
+    >>> torch.linalg.vecdot(v1, v2)
+    tensor([ 0.3223,  0.2815, -0.1944])
+    >>> torch.dot(v1[0], v2[0])
+    tensor(0.3223)
+""")

torch/overrides.py

+1
@@ -869,6 +869,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.ravel: lambda input: -1,
         torch.real: lambda input, out=None: -1,
         torch.vdot: lambda input, other, out=None: -1,
+        torch.linalg.vecdot: lambda input, other, dim=-1, out=None: -1,
         torch.view_as_real: lambda input: -1,
         torch.view_as_complex: lambda input: -1,
         torch.reciprocal: lambda input, out=None: -1,
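This registration makes `torch.linalg.vecdot` participate in the `__torch_function__` protocol. A minimal sketch of intercepting it with a hypothetical subclass (not part of the commit):

import torch

class LoggingTensor(torch.Tensor):
    # Logs which torch-level API was called, then defers to the default behaviour.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        print(f"intercepted: {func}")
        return super().__torch_function__(func, types, args, kwargs or {})

x = torch.randn(3, 2).as_subclass(LoggingTensor)
y = torch.randn(3, 2).as_subclass(LoggingTensor)
out = torch.linalg.vecdot(x, y)  # should print a line identifying linalg_vecdot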

torch/testing/_internal/common_methods_invocations.py

+20
@@ -2012,6 +2012,15 @@ def sample_inputs_isclose(
         yield SampleInput(lhs, args=(rhs,),
                           kwargs=dict(op_kwargs, rtol=rtol, atol=atol, equal_nan=equal_nan))
 
+def sample_inputs_linalg_vecdot(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_binary_pwise(op_info, device, dtype, requires_grad)
+
+    # Also add samples with dim != -1
+    for s in sample_inputs_binary_pwise(op_info, device, dtype, requires_grad):
+        if s.input.ndim > 1:
+            s.kwargs["dim"] = 0
+            yield s
+
 def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     return (SampleInput(make_arg((1, 2))),
@@ -9778,6 +9787,17 @@ def ref_pairwise_distance(input1, input2):
                     gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
                     decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
                     ),
+    BinaryUfuncInfo('linalg.vecdot',
+                    aten_name='linalg_vecdot',
+                    ref=lambda x, y, *, dim=-1: (x * y).sum(dim),
+                    dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
+                    sample_inputs_func=sample_inputs_linalg_vecdot,
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    skips=(
+                        # torch.sum(out=) has incorrect behaviour
+                        DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+                    ),),
     OpInfo('linalg.cond',
            aten_name='linalg_cond',
            dtypes=floating_and_complex_types(),
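The `ref` lambda is the ground truth that the OpInfo machinery compares the operator against. A standalone sketch of what that check amounts to (simplified; the real harness iterates over the sample inputs generated above):

import torch

ref = lambda x, y, *, dim=-1: (x * y).sum(dim)  # the reference from the entry above

x, y = torch.randn(2, 5), torch.randn(2, 5)
for dim in (-1, 0):
    assert torch.allclose(torch.linalg.vecdot(x, y, dim=dim), ref(x, y, dim=dim))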
