[ready] Introduce chain_matmul #12380
Conversation
One benchmark on the CPU (taken out of an exercise in CLRS):
I'm not sure einsum is the best thing to compare this to. Could you do a direct comparison with `m1 @ m2 @ m3 @ ... @ m6`?
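A minimal sketch of such a head-to-head comparison (the shapes and iteration count below are placeholders, not the benchmark actually run in this PR):

```python
import timeit

import torch

# Hypothetical shapes: alternating tall/wide matrices so the multiplication
# order matters a lot; not the sizes used in the PR's CLRS-based benchmark.
shapes = [(1000, 10), (10, 1000), (1000, 10), (10, 1000), (1000, 10), (10, 1000)]
m1, m2, m3, m4, m5, m6 = (torch.randn(r, c) for r, c in shapes)

def left_to_right():
    # Plain @ chaining evaluates strictly left to right.
    return m1 @ m2 @ m3 @ m4 @ m5 @ m6

def chained():
    # chain_matmul first picks a cost-minimising parenthesization.
    return torch.chain_matmul(m1, m2, m3, m4, m5, m6)

print("m1 @ ... @ m6:", timeit.timeit(left_to_right, number=50))
print("chain_matmul :", timeit.timeit(chained, number=50))
```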
@zou3519 This is ready for review.
Eh, is there any chance to revert the reshuffling? It makes the diff unnecessarily large. Those changes are completely meaningless unless they are enforced by the CI, because I'm sure that the order will be messed up in a week or two.
torch/functional.py (Outdated)

    Args:
        matrices (list of Tensors): list of 2-D tensors whose product is to be determined.
torch/functional.py (Outdated)

    r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
    using the matrix chain order algorithm which selects the order that incurs the lowest cost in terms
    of arithmetic operations (`[CLRS]`_). Note that :math:`N` needs to be greater than or equal to 2; if equal to 2
    then a trivial matrix-matrix product is returned.
I think this is nice! Question: can't those cost function optimizations be used in einsum as well?
@fmassa I went through the einsum implementation, and it seems to use a lot of permutations for preprocessing the operands. With the current implementation, I am not sure if these optimizations can be leveraged.
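To make the comparison concrete, here is a small hypothetical example of the same three-matrix chain written both ways; the exact contraction strategy `einsum` uses is version-dependent, so this is only an illustration of the equivalence, not of their relative cost:

```python
import torch

# Hypothetical shapes for a three-matrix chain A @ B @ C.
a = torch.randn(30, 40)
b = torch.randn(40, 50)
c = torch.randn(50, 5)

# Same product, expressed via einsum and via chain_matmul.
out_einsum = torch.einsum('ij,jk,kl->il', a, b, c)
out_chain = torch.chain_matmul(a, b, c)

print(torch.allclose(out_einsum, out_chain, atol=1e-5))
```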
BTW is
NumPy calls it `multi_dot`.
Well, since we're not calling it that anyway, why not clean up the
I am sorry to disappoint with the name. Regarding the name of the function, I'll name it An extension to a
Can we just make it Also, don't stress out about the name! That's what NumPy calls it, so it was a very reasonable choice too.
Tensor chain_matmul(TensorList matrices) {
  AT_CHECK(matrices.size() >= 2, "Expecting at least 2 matrices");
  checkAllSameDim(matrices, 2);
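At the Python level, the two checks quoted above roughly amount to the following (a hypothetical usage sketch, not code from the PR):

```python
import torch

a, b = torch.randn(3, 4), torch.randn(4, 5)

torch.chain_matmul(a, b)                        # OK: at least two 2-D tensors
# torch.chain_matmul(a)                         # rejected: fewer than 2 matrices
# torch.chain_matmul(a, torch.randn(4, 5, 6))   # rejected: operands must all be 2-D
```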
torch/functional.py (Outdated)

def chain_matmul(*matrices):
    r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
    using the matrix chain order algorithm which selects the order that incurs the lowest cost in terms
    of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the product, :math:`N`
r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed | ||
using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms | ||
of arithmetic operations (`[CLRS]`_). Note that since is a function to compute the product, :math:`N` | ||
needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned. |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/functional.py (Outdated)

    .. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition
    """
    if len(matrices) == 1 and isinstance(matrices[0], (list, tuple)):
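The `isinstance` check quoted above suggests the wrapper also accepts a single list or tuple of tensors in addition to varargs; a sketch of that normalisation pattern (the helper name is illustrative, not the PR's code):

```python
import torch

def _chain_matmul_compat(*matrices):
    # Accept either f(a, b, c) or f([a, b, c]) / f((a, b, c)) and
    # forward the unpacked tensors to the native op.
    if len(matrices) == 1 and isinstance(matrices[0], (list, tuple)):
        matrices = tuple(matrices[0])
    return torch.chain_matmul(*matrices)
```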
Algorithm looks correct to me. :)
} else {

  // Following the algorithm in Chapter 15.2 : Introduction to Algorithms, Cormen et al.
  // Minor modifications have been made to accommodate zero-indexing
}

  // Cost matrix
  std::vector<std::vector<double>> m(n, std::vector<double>(n, 0));
@@ -586,5 +586,81 @@ Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) {
   return at::sum_out(result, std::get<1>(at::svd(self)), 0, keepdim);
 }

Tensor _chain_matmul_general(TensorList matrices, std::vector<std::vector<int64_t>>& order, int64_t i, int64_t j) {
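A Python sketch of what a recursive helper with this signature presumably does: multiply A_i..A_j following the split table `order` produced by the chain-order DP (sketched further below). The names and details here are assumptions for illustration, not the ATen implementation:

```python
def _chain_matmul_general_sketch(matrices, order, i, j):
    # order[i][j] holds the split point k of the optimal parenthesization
    # of A_i ... A_j; recurse on both halves and multiply the results.
    if i == j:
        return matrices[i]
    k = order[i][j]
    left = _chain_matmul_general_sketch(matrices, order, i, k)
    right = _chain_matmul_general_sketch(matrices, order, k + 1, j)
    return left.mm(right)
```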
@vishwakftw Our CircleCI build has a problem that can be fixed by
@ssnl is this good to go?
LGTM! Thanks!
  // parenthesizing matrices A_{i} to A_{j}. By this definition m[i, i] = 0 for all i
  // m[i, j] is filled using the substructure property of the algorithm, meaning:
  // m[i, j] = min_{i <= k < j} m[i, k] + m[k + 1, j] + p_{i-1}p_{k}p_{j}
  std::vector<std::vector<int64_t>> m(n, std::vector<int64_t>(n, 0));
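For reference, a zero-indexed Python sketch of the CLRS 15.2 dynamic program that fills such a cost table `m` along with the companion split table `s`; this mirrors the recurrence above rather than the exact C++ code in the PR:

```python
def matrix_chain_order(p):
    # p holds the chain dimensions: matrix A_i has shape p[i] x p[i+1],
    # so there are n = len(p) - 1 matrices in the chain.
    n = len(p) - 1
    m = [[0] * n for _ in range(n)]  # m[i][j]: min scalar multiplications for A_i..A_j
    s = [[0] * n for _ in range(n)]  # s[i][j]: split point k achieving that minimum
    for length in range(2, n + 1):            # subchain length
        for i in range(n - length + 1):
            j = i + length - 1
            m[i][j] = float('inf')
            for k in range(i, j):
                cost = m[i][k] + m[k + 1][j] + p[i] * p[k + 1] * p[j + 1]
                if cost < m[i][j]:
                    m[i][j] = cost
                    s[i][j] = k
    return m, s
```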
@pytorchbot retest this please
Sorry, the PyTorch bot doesn't work with CircleCI, unfortunately. Could you rebase and push again? Thanks!
- This was the only function left out from the list of functions in NumPy's `linalg` module
- `multi_mm` is particularly useful for DL research, for quick analysis of deep linear networks

To do:
- Add tests
N.B.: I took the opportunity to shuffle some of the functions into alphabetical order.
- This seems to have been missed in the rebase
I think our Windows builds are broken, so you can ignore those. The ASAN failure seems to be real, though.
@zou3519 I have fixed it.
SsnL is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
- This was one of the few functions left out from the list of functions in NumPy's `linalg` module
- `multi_mm` is particularly useful for DL research, for quick analysis of deep linear networks
- Added tests and doc string

Pull Request resolved: pytorch/pytorch#12380
Differential Revision: D10357136
Pulled By: SsnL
fbshipit-source-id: 52b44fa18d6409bdeb76cbbb164fe4e88224458e