[ready] Introduce chain_matmul #12380


Closed
wants to merge 15 commits

Conversation

@vishwakftw (Contributor) commented Oct 5, 2018

  • This was one of the few functions from NumPy's linalg module still missing in PyTorch
  • multi_mm is particularly useful for DL research, e.g. for quick analysis of
    deep linear networks
  • Added tests and docstring

@vishwakftw (Contributor, Author) commented:

One benchmark on the CPU (taken from an exercise in CLRS):

In [14]: a1 = torch.randn(30, 35)

In [15]: a2 = torch.randn(35, 15)

In [16]: a3 = torch.randn(15, 5)

In [17]: a4 = torch.randn(5, 10)

In [18]: a5 = torch.randn(10, 20)

In [19]: a6 = torch.randn(20, 25)

In [20]: %%timeit
    ...: torch.einsum('pq,qr,rs,st,tu,uv->pv', [a1, a2, a3, a4, a5, a6])
    ...: 
262 µs ± 4.29 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [21]: %%timeit
    ...: torch.multi_mm(a1, a2, a3, a4, a5, a6)
    ...: 
23.6 µs ± 225 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

@zou3519 (Contributor) commented Oct 5, 2018

I'm not sure einsum is the best thing to compare this to. Could you do a direct comparison with m1 @ m2 @ m3 @ ... @ m6?

@vishwakftw (Contributor, Author) commented:

In [13]: a1 = torch.randn(300, 350).double()

In [14]: a2 = torch.randn(350, 150).double()

In [15]: a3 = torch.randn(150, 50).double()

In [16]: a4 = torch.randn(50, 10).double()

In [17]: a5 = torch.randn(10, 200).double()

In [18]: a6 = torch.randn(200, 25).double()

In [19]: %%timeit
    ...: torch.multi_mm(a1, a2, a3, a4, a5, a6)
    ...: 
178 µs ± 6.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [20]: %%timeit
    ...: a1 @ a2 @ a3 @ a4 @ a5 @ a6
    ...: 
767 µs ± 4.25 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

@vishwakftw vishwakftw changed the title [WIP] Introduce multi_mm [ready] Introduce multi_mm Oct 6, 2018
@vishwakftw (Contributor, Author) commented:

@zou3519 This is ready for review.

@apaszke (Contributor) commented Oct 7, 2018

Eh, is there any chance to revert the reshuffling? It makes the diff unnecessarily large. Those changes are completely meaningless unless they are enforced by the CI, because I'm sure the order will get messed up again in a week or two.



Args:
    matrices (list of Tensors): list of 2-D tensors whose product is to be determined.

r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms
of arithmetic operations (`[CLRS]`_). Note that :math:`N` needs to be greater than or equal to 2; if equal to 2
then a trivial matrix-matrix product is returned.
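
To see why the multiplication order matters, consider the standard CLRS example (illustrative; not part of the PR): for :math:`A` of size :math:`10 \times 100`, :math:`B` of size :math:`100 \times 5` and :math:`C` of size :math:`5 \times 50`, computing :math:`(AB)C` takes :math:`10 \cdot 100 \cdot 5 + 10 \cdot 5 \cdot 50 = 7500` scalar multiplications, whereas :math:`A(BC)` takes :math:`100 \cdot 5 \cdot 50 + 10 \cdot 100 \cdot 50 = 75000`. The chain order algorithm picks the cheaper parenthesization.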

@fmassa (Member) commented Oct 7, 2018

I think this is nice! Question: can't those cost function optimizations be used in einsum as well?

@vishwakftw (Contributor, Author) commented:

@fmassa I went through the einsum implementation, and it seems to use a lot of permutations for preprocessing the operands. With the current implementation, I am not sure if these optimizations can be leveraged.

@apaszke (Contributor) commented Oct 8, 2018

BTW, is multi_mm really what NumPy calls this? The term I've always heard is "matrix chain multiplication", so mm_chain seems like a better one.

@vishwakftw (Contributor, Author) commented:

NumPy calls it multi_dot.

@apaszke (Contributor) commented Oct 8, 2018

Well, since we're not calling it that anyway, why not drop the multi_ prefix, which doesn't fit all that well? Tbh, when I first read the title of this PR I was expecting something like a batched mm, but possibly for matrices of mismatched sizes (e.g. passed in as lists). chain_mm or mm_chain seem nice. Finally, it might be best to make it chain_matmul instead of limiting it to 2D (although I guess that might complicate the implementation a bit).

@vishwakftw (Contributor, Author) commented:

I am sorry to disappoint with the name.

Regarding the name of the function, I'll name it chain_mm (mm_chain looks like an intrinsic function sans the prefix _).

An extension to matmul should be feasible, which I'll look at soon.

@vishwakftw vishwakftw changed the title [ready] Introduce multi_mm [ready] Introduce chain_mm Oct 8, 2018
@apaszke (Contributor) commented Oct 8, 2018

Can we just make it chain_matmul, and assert that all elements are matrices? We can relax the constraint in the future.

Also, don't stress out about the name! That's what NumPy calls it, so it was a very reasonable choice too.

@vishwakftw vishwakftw changed the title [ready] Introduce chain_mm [ready] Introduce chain_matmul Oct 8, 2018

Tensor chain_matmul(TensorList matrices) {
AT_CHECK(matrices.size() >= 2, "Expecting at least 2 matrices");
checkAllSameDim(matrices, 2);

def chain_matmul(*matrices):
r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms
of arithmetic operations (`[CLRS]`_). Note that since is a function to compute the product, :math:`N`

r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms
of arithmetic operations (`[CLRS]`_). Note that since is a function to compute the product, :math:`N`
needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned.


.. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition
"""
if len(matrices) == 1 and isinstance(matrices[0], (list, tuple)):
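
For reference, a short usage sketch of the function as merged (the shapes here are arbitrary, as long as the inner dimensions of adjacent matrices match):

import torch

a = torch.randn(3, 4)
b = torch.randn(4, 5)
c = torch.randn(5, 6)

# Equivalent to a @ b @ c, but evaluated in the cheapest parenthesization
out = torch.chain_matmul(a, b, c)
print(out.shape)  # torch.Size([3, 6])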

@ssnl (Collaborator) left a comment:

Algorithm looks correct to me. :)

} else {

// Following the algorithm in Chapter 15.2 : Introduction to Algorithms, Cormen et al.
// Minor modifications have been made to accommodate zero-indexing

}

// Cost matrix
std::vector<std::vector<double>> m(n, std::vector<double>(n, 0));

@@ -586,5 +586,81 @@ Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) {
return at::sum_out(result, std::get<1>(at::svd(self)), 0, keepdim);
}

Tensor _chain_matmul_general(TensorList matrices, std::vector<std::vector<int64_t>>& order, int64_t i, int64_t j) {
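
In spirit, this recursion does the following (a minimal Python sketch rather than the PR's C++ code, assuming a split table s where s[i][j] records the optimal split point k for the subchain A_i ... A_j, as produced by the standard matrix chain order DP; see the sketch further below):

def chain_product(matrices, s, i, j):
    # Multiply A_i ... A_j, recursing on the optimal split point
    if i == j:
        return matrices[i]
    k = s[i][j]
    return chain_product(matrices, s, i, k) @ chain_product(matrices, s, k + 1, j)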

@ssnl (Collaborator) commented Oct 10, 2018

@vishwakftw Our CircleCI build has a problem that can be fixed by `git pull --rebase upstream master`. Could you do that and then force-push? Thanks!

@vishwakftw (Contributor, Author) commented:

@ssnl is this good to go?

@ssnl (Collaborator) left a comment:

LGTM! Thanks!

// parenthesizing matrices A_{i} to A_{j}. By this definition m[i, i] = 0 for all i
// m[i, j] is filled using the substructure property of the algorithm, meaning:
// m[i, j] = min_{i <= k < j} m[i, k] + m[k + 1, j] + p_{i-1}p_{k}p_{j}
std::vector<std::vector<int64_t>> m(n, std::vector<int64_t>(n, 0));

@vishwakftw (Contributor, Author) commented:

@pytorchbot retest this please

@ssnl (Collaborator) commented Oct 10, 2018

Sorry, pytorchbot doesn't work with CircleCI, unfortunately. Could you rebase and push again? Thanks!

- This was the only function still left out from NumPy's linalg module
- `multi_mm` is particularly useful for DL research, e.g. for quick analysis of
  deep linear networks

To do:
- Add tests
N.B.: I took the opportunity to shuffle some of the functions into alphabetical order
@zou3519 (Contributor) commented Oct 11, 2018

I think our Windows builds are broken, so you can ignore those. The ASAN failure seems to be real, though.

@vishwakftw (Contributor, Author) commented:

@zou3519 I have fixed it.

@facebook-github-bot (Contributor) left a comment:

SsnL is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@vishwakftw vishwakftw deleted the multi_dot branch October 12, 2018 11:03
zdevito pushed a commit to zdevito/ATen that referenced this pull request Oct 12, 2018
Summary:
- This was one of the few functions from NumPy's `linalg` module still missing in PyTorch
- `multi_mm` is particularly useful for DL research, e.g. for quick analysis of
  deep linear networks
- Added tests and docstring
Pull Request resolved: pytorch/pytorch#12380

Differential Revision: D10357136

Pulled By: SsnL

fbshipit-source-id: 52b44fa18d6409bdeb76cbbb164fe4e88224458e