Add `linalg.block_diag` and sparse equivalent #576
Conversation
Codecov Report

Attention: …

Additional details and impacted files:

```
@@            Coverage Diff             @@
##              main     #576     +/-   ##
==========================================
  Coverage    80.92%   80.92%
==========================================
  Files          162      162
  Lines        46524    46644     +120
  Branches     11375    11401      +26
==========================================
+ Hits         37648    37746      +98
- Misses        6653     6668      +15
- Partials      2223     2230       +7
```
The floatX functions are actually pretty crappy and should be phased out completely. For instance, they fail when you pass a Python list of TensorVariables. PyTensor has tools (including config flags) to handle type promotions that should be relied upon instead, or refactored if they don't satisfy our current needs: https://pytensor.readthedocs.io/en/latest/library/config.html#config.cast_policy
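For illustration, a minimal sketch of the PyTensor-native promotion machinery this comment points to (printed values depend on your configuration; this is not code from the PR):

```python
import pytensor
import pytensor.tensor as pt

# Config flags that govern casting/promotion behaviour
print(pytensor.config.floatX)       # e.g. "float64"
print(pytensor.config.cast_policy)  # "custom" or "numpy+floatX"

# A Python list of TensorVariables is handled fine by symbolic constructors,
# which resolve the promoted dtype for you
x = pt.scalar("x", dtype="float32")
y = pt.scalar("y", dtype="int64")
z = pt.stack([x, y])
print(z.dtype)  # promoted dtype chosen by PyTensor's upcast rules
```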
Looks good, just some small tweaks
Also, are the JAX/Numba dispatches trivial enough that we could support them already?
One thing I don't know is whether mixing the sparse and dense cases in a single Op is the best internal API. Maybe that's fine. I guess it mostly depends on whether rewrites can ignore this info or whether they will always have to reason differently depending on it. However, I think the helper function should follow the scipy API, so there would be a …
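For reference, the scipy API this comment refers to splits the dense and sparse cases into two functions; presumably the PyTensor helpers would mirror that split (the exact PyTensor names are up to the PR):

```python
import numpy as np
import scipy.linalg
import scipy.sparse

A = np.eye(2)
B = np.ones((3, 3))

# Dense: matrices are passed as separate positional arguments
dense = scipy.linalg.block_diag(A, B)            # ndarray of shape (5, 5)

# Sparse: a sequence of matrices is passed, and a sparse matrix is returned
sparse = scipy.sparse.block_diag([A, B], format="csr")
```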
Co-authored-by: Ricardo Vieira <[email protected]>
Co-authored-by: Ricardo Vieira <[email protected]>
I think Numba/JAX should be easy; I'll work on it. The other thing would be to make it Blockwise. With respect to the API, I totally agree about splitting sparse and dense into …
Can you open an issue in PyMC to deprecate the code there?
Closely follow scipy function signature for `block_diag`
I'm going to need some handholding with turning it into Blockwise. I think the gfunc_sig should just be …
…allow sparse matrix inputs to `pytensor.sparse.block_diag`
Force-pushed from 4e04ac9 to 8ac8f50
Force-pushed from f88c27c to 491111b
Signatures can be defined dynamically when the Op is initialized; that's how we do it for solve. For this Op you can do something like we do here: pytensor/pytensor/tensor/blockwise.py, line 29 in e180927

signature="(m0,n0),(m1,n1),...(mn,nn)->(m,n)"

You can add the number of inputs as a parameter of the Op, so it's known at initialization. The inputs don't have to have the same shape, right, just ndim? Could be …
Yes, you're right about sizes. Do I need to re-write the JAX/Numba overloads to handle batch dims, or is it just unsupported so far?
JAX Blockwise will work with vmap; you only need to dispatch the base case like you did. The point of Blockwise is exactly that: since the batch dims always work the same way, we only need to bother specifying the core case. Numba doesn't yet have support for Blockwise. We should be able to do something simple with guvectorize, although it has some limitations, like only working with single outputs.
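As a rough illustration of the guvectorize idea (a generic sketch, not PyTensor's actual Numba dispatch): the core signature is written once and Numba applies it over any leading batch dimensions, which is the same thing Blockwise expresses symbolically.

```python
import numpy as np
from numba import guvectorize

# Core case: elementwise add of two vectors; the "(n),(n)->(n)" signature is
# looped over any leading batch dimensions automatically.
@guvectorize(["void(float64[:], float64[:], float64[:])"], "(n),(n)->(n)")
def batched_add(a, b, out):
    for i in range(a.shape[0]):
        out[i] = a[i] + b[i]

x = np.random.normal(size=(5, 3))
y = np.random.normal(size=(5, 3))
assert np.allclose(batched_add(x, y), x + y)
```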
I guess we can't Blockwise sparse Ops? Blockwise uses …
Only nitpicks left, address what you want and close what you don't!
pytensor/tensor/slinalg.py (Outdated)
```
Parameters
----------
A, B, C ... : tensors
    Input matrices to form the block diagonal matrix. Each matrix should have the same number of dimensions, and the …
```
This is not correct. Blockwise accepts different numbers of batch dims and also broadcasts when they have length 1.
I got errors when I tried tensors with different batch dims, but I didn't try broadcasting to dimensions with size 1.
Do you still see errors? Blockwise should introduce expand dims, so the only failure case would be broadcasting?
Broadcasting works, I added a test for it. It was failing when I tried different batch sizes, which doesn't make sense anyway I think.
What do you mean different batch sizes? Blockwise adds expand dims automatically to align the number of batch dims, so that shouldn't be possible?
This errors:

```python
import numpy as np
from pytensor import config
from pytensor.tensor.slinalg import block_diag  # import path added in this PR

batch_size = 5
# Different batch sizes
A = np.random.normal(size=(batch_size + 3, 2, 2)).astype(config.floatX)
B = np.random.normal(size=(batch_size, 4, 4)).astype(config.floatX)
result = block_diag(A, B).eval()
```

with:

```
ValueError: Incompatible Blockwise batch input shapes [(8, 2, 2), (5, 4, 4)]
```
But I think it's supposed to. What does it even mean to batch those two together?
Yeah, that's invalid, batch shapes must be broadcastable. 8 and 5 are not broadcastable.
I thought you were saying that inputs with different numbers of dimensions were failing.
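For reference, a batch dimension of length 1 does broadcast, which is what the added test covers. A minimal sketch, assuming the import path added by this PR:

```python
import numpy as np
from pytensor import config
from pytensor.tensor.slinalg import block_diag

A = np.random.normal(size=(1, 2, 2)).astype(config.floatX)  # length-1 batch dim broadcasts
B = np.random.normal(size=(5, 4, 4)).astype(config.floatX)
out = block_diag(A, B).eval()
print(out.shape)  # expected: (5, 6, 6)
```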
We don't have a sparse type with …
Changed the title: "… `pymc.math` to pytensor" → "Add `linalg.block_diag` and sparse equivalent"
Remove `Matrix` from `BlockDiagonal` and `SparseBlockDiagonal` `Op` names
Correct errors in docstrings
Move input validation to a shared class method
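The "shared class method" mentioned in this commit could look roughly like the hypothetical sketch below (illustrative only; the PR's actual validation logic may differ):

```python
class BaseBlockDiagonal:
    """Hypothetical shared base for the dense and sparse block-diagonal Ops."""

    @classmethod
    def _validate_inputs(cls, matrices):
        # Both Ops need at least one input and a 2D core case;
        # batch dimensions are handled by Blockwise, not here.
        if len(matrices) == 0:
            raise ValueError("block_diag requires at least one input matrix")
        if any(m.ndim != 2 for m in matrices):
            raise TypeError("All core inputs must be 2D matrices")
        return matrices
```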
Looks pretty neat! Two not-so-nitpick comments.
Also nice catch with the missing test for grad
Co-authored-by: Ricardo Vieira <[email protected]>
Description
There are a number of generic math-type operations in `pymc.math` that make more sense in `pytensor`. This PR moves one of them -- `block_diag` -- to `pytensor.tensor.slinalg`. There are others that could likely be moved as well.

`block_diag` uses a few support functions that also present opportunities to implement some missing numpy functions in pytensor. I re-wrote `largest_common_dtype` to use `np.promote_types`, but there should probably be a `pt.promote_types` and a `pt.result_type`. The latter should just replace this helper function entirely (see the small promotion example below).

There was also a function `ix`, which is an implementation of `np.ix_`. We're missing all of these numpy quick constructors: `np.r_`, `np.c_`, `np.ogrid`, `np.mgrid`. Plus `np.meshgrid`, which is an actual function people use. `floatX` and `intX` should probably also be pytensor functions instead of PyMC functions.

I tagged this as relevant to #573 because I'm interested in experimenting with linear algebra rewrites out of the `block_diag` `Op`.
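A small example of the numpy promotion functions mentioned above (plain numpy, shown only to motivate a `pt.promote_types` / `pt.result_type`):

```python
import numpy as np

# promote_types works on a pair of dtypes
print(np.promote_types(np.float32, np.int16))     # float32

# result_type accepts any number of dtypes/arrays/scalars, which is why a
# pt.result_type could replace the largest_common_dtype helper entirely
print(np.result_type(np.float32, np.int64, 1.0))  # float64
```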
Related Issue
Checklist
Type of change