Should tensordot broadcast the contracted dimensions? #294


Closed
asmeurer opened this issue Oct 29, 2021 · 3 comments · Fixed by #324
Labels: topic: Linear Algebra
Milestone: v2021

Comments

@asmeurer
Member

Should tensordot broadcast the contracted dimensions? For example, say we contract the first dimensions here:

tensordot(ones((3, 3)), ones((1, 3)), axes=((0,), (0,)))

The dimensions 3 and 1 do not match, but if we broadcast the arrays against each other first, they both become shape (3, 3), after which they do match.
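
For example, broadcasting the two inputs against each other (with np.broadcast_arrays, shown here just for illustration) makes the contracted sizes agree:

>>> [a.shape for a in np.broadcast_arrays(np.ones((3, 3)), np.ones((1, 3)))]
[(3, 3), (3, 3)]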

The spec is a little unclear about this (https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html#tensordot-x1-x2-axes-2). It says x2 must be compatible with x1 by broadcasting, which seems to imply unconditional broadcasting, but it also says "Each axis (dimension) x1_axes[i] for x1 must have the same size as the respective axis (dimension) x2_axes[i] for x2."

NumPy disallows broadcasting in contracted dimensions (it does broadcast non-contracted dimensions):

>>> np.tensordot(np.ones((3, 3)), np.ones((1, 3)), axes=((0,), (0,)))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<__array_function__ internals>", line 181, in tensordot
  File "./numpy/core/numeric.py", line 1110, in tensordot
    raise ValueError("shape-mismatch for sum")
ValueError: shape-mismatch for sum
>>> np.tensordot(np.ones((3, 3)), np.ones((1, 3)), axes=((1,), (1,)))
array([[3.],
       [3.],
       [3.]])

PyTorch broadcasts all dimensions, including contracted ones (note that PyTorch still calls its axes argument dims):

>>> torch.tensordot(torch.ones((3, 3)), torch.ones((1, 3)), dims=((0,), (0,)))
tensor([[3., 3., 3.],
        [3., 3., 3.],
        [3., 3., 3.]])
>>> torch.tensordot(torch.ones((3, 3)), torch.ones((1, 3)), dims=((1,), (1,)))
tensor([[3.],
        [3.],
        [3.]])

Note that in either case, the resulting array shape is based on the non-broadcasted input shapes, so it's not as simple as wrapping the call with broadcast_arrays.

>>> np.tensordot(np.ones((3, 3)), np.ones((2, 3, 3)), axes=((-1,), (2,))).shape
(3, 2, 3)
>>> np.tensordot(np.ones((2, 3, 3)), np.ones((2, 3, 3)), axes=((-1,), (2,))).shape
(2, 3, 2, 3)
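
As a rough sketch, the PyTorch behavior could be emulated on top of np.tensordot by broadcasting only the contracted dimensions first (tensordot_bcast below is a hypothetical helper for illustration, not part of NumPy, PyTorch, or the spec):

import numpy as np

def tensordot_bcast(x1, x2, axes):
    # Broadcast only the contracted dimensions of each input to their common
    # size, then defer to np.tensordot, which leaves the non-contracted
    # dimensions untouched (mismatched sizes other than 1 still raise).
    ax1, ax2 = axes
    shape1, shape2 = list(x1.shape), list(x2.shape)
    for a1, a2 in zip(ax1, ax2):
        n = max(shape1[a1], shape2[a2])
        shape1[a1] = shape2[a2] = n
    return np.tensordot(np.broadcast_to(x1, tuple(shape1)),
                        np.broadcast_to(x2, tuple(shape2)),
                        axes=axes)

tensordot_bcast(np.ones((3, 3)), np.ones((1, 3)), axes=((0,), (0,)))
# array([[3., 3., 3.],
#        [3., 3., 3.],
#        [3., 3., 3.]])  # matches the torch.tensordot result above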
@kgryte kgryte added the topic: Linear Algebra Linear algebra. label Oct 30, 2021
@kgryte kgryte added this to the v2021 milestone Oct 30, 2021
@asmeurer
Member Author

asmeurer commented Nov 1, 2021

CC @lezcano, do you have any thoughts on this?

@kgryte
Contributor

kgryte commented Nov 4, 2021

In today's call, we decided to align with NumPy's behavior, as advocated by @leofang, @oleksandr-pavlyk, and @rgommers. Given PyTorch's relatively recent addition of tensordot and NumPy's alignment with matmul (which also does not broadcast the innermost two dimensions), it seems reasonable to adopt NumPy's behavior in this instance.
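
For reference, a small NumPy illustration of the matmul analogy: the leading (batch) dimensions broadcast, but the innermost two (matrix) dimensions must match exactly.

import numpy as np

# The batch dimensions broadcast: 2 and 1 become 2.
np.matmul(np.ones((2, 3, 4)), np.ones((1, 4, 5))).shape   # (2, 3, 5)

# The innermost two dimensions do not broadcast: this raises a ValueError
# because the core sizes 3 and 1 are simply mismatched.
# np.matmul(np.ones((3, 3)), np.ones((1, 3)))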

@IvanYashchuk mentioned elsewhere that PyTorch's behavior can be matched with einsum as follows: if we express the tensordot operation with dims=((1,), (1,)) using einsum ("ab,dc->ad"), then NumPy computes the same result as torch.tensordot.

import torch
import numpy as np

a1 = np.random.normal(size=(3, 3))
a2 = np.random.normal(size=(3, 1))

np.einsum("ab,dc->ad", a1, a2)
# array([[ 1.44946877, -1.0152814 , -1.39638556],
#        [ 1.48299164, -1.03876252, -1.42868074],
#        [-0.37928214,  0.26566844,  0.36539187]])

torch.tensordot(*map(torch.from_numpy, (a1, a2)), dims=((1,), (1,)))
# tensor([[ 1.4495, -1.0153, -1.3964],
#         [ 1.4830, -1.0388, -1.4287],
#         [-0.3793,  0.2657,  0.3654]], dtype=torch.float64)

# A slower but equivalent way of computing the same result as np.einsum with
# "ik,jn->ij" (the same contraction as "ab,dc->ad" above, just with different
# index letters) or torch.tensordot with dims=((1,), (1,))
result = np.zeros((3, 3))
for i in range(result.shape[0]):
  for j in range(result.shape[1]):
    for k in range(a1.shape[1]): # reducing dim/axis #1
      for n in range(a2.shape[1]): # reducing dim/axis #1
        result[i, j] += a1[i, k] * a2[j, n]
# array([[ 1.44946877, -1.0152814 , -1.39638556],
#        [ 1.48299164, -1.03876252, -1.42868074],
#        [-0.37928214,  0.26566844,  0.36539187]])

@rgommers
Member

rgommers commented Nov 5, 2021

Given PyTorch's relatively recent addition of tensordot and NumPy's alignment with matmul (which also does not broadcast the innermost two dimensions), it seems reasonable to adopt NumPy's behavior in this instance.

A summary of additional considerations for why the NumPy behavior is preferred:

  • In the case of ambiguity, it's better to raise an exception than to pick a behavior
  • In the future it's possible to add more behavior if desired; taking it away is not
