Add linalg.block_diag and sparse equivalent #576


Merged: 16 commits into pymc-devs:main on Jan 7, 2024

Conversation

@jessegrabowski (Member) commented on Jan 6, 2024

Description

There are a number of generic math-type operations in pymc.math that make more sense in pytensor. This PR moves one of them, block_diag, to pytensor.tensor.slinalg. There are others that could likely be moved as well.

block_diag uses a few support functions that also present opportunities to implement some missing numpy functions in pytensor. I rewrote largest_common_dtype to use np.promote_types, but there should probably be a pt.promote_types and a pt.result_type; the latter could replace this helper function entirely.
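For reference, a doctest-style sketch of the numpy behavior such helpers would mirror (not code from this PR):

    >>> import numpy as np
    >>> np.promote_types(np.float32, np.int16)
    dtype('float32')
    >>> np.result_type(np.int32, np.float32)
    dtype('float64')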

There was also a function ix, which is an implementation of np.ix_. We're missing all of these numpy quick constructors: np.r_, np.c_, np.ogrid, np.mgrid. Plus np.meshgrid, which is an actual function people use.

floatX and intX should probably also be pytensor functions instead of PyMC functions.

I tagged this as relevant to #573 because I'm interested in experimenting with linear algebra rewrites out of the block_diag Op.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@codecov-commenter commented on Jan 6, 2024

Codecov Report

Attention: 24 lines in your changes are missing coverage. Please review.

Comparison is base (e180927) 80.92% compared to head (a9893b8) 80.92%.
Report is 4 commits behind head on main.

Additional details and impacted files


@@           Coverage Diff            @@
##             main     #576    +/-   ##
========================================
  Coverage   80.92%   80.92%            
========================================
  Files         162      162            
  Lines       46524    46644   +120     
  Branches    11375    11401    +26     
========================================
+ Hits        37648    37746    +98     
- Misses       6653     6668    +15     
- Partials     2223     2230     +7     
Files Coverage Δ
pytensor/link/jax/dispatch/slinalg.py 93.54% <100.00%> (+1.24%) ⬆️
pytensor/tensor/basic.py 88.32% <100.00%> (-0.15%) ⬇️
pytensor/sparse/basic.py 82.57% <81.81%> (+0.08%) ⬆️
pytensor/tensor/slinalg.py 93.70% <85.00%> (-0.94%) ⬇️
pytensor/link/numba/dispatch/slinalg.py 45.60% <17.64%> (-3.20%) ⬇️

... and 6 files with indirect coverage changes

@ricardoV94 (Member) commented on Jan 6, 2024

The floatX functions are actually pretty crappy and should be phased out completely. For instance they fail when you pass a python list of TensorVariables.

PyTensor has tools (including config flags) to handle type promotion; these should be relied upon instead, or refactored if they don't satisfy our current needs:

https://pytensor.readthedocs.io/en/latest/library/config.html#config.cast_policy
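For reference, a minimal way to inspect the flags in question (illustrative only; see the linked docs for the accepted values):

    import pytensor

    # Both flags live on pytensor.config; their values depend on your configuration.
    print(pytensor.config.floatX)       # e.g. 'float64' or 'float32'
    print(pytensor.config.cast_policy)  # controls how numeric types are promoted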

@ricardoV94 (Member) left a review:

Looks good, just some small tweaks.

@ricardoV94 (Member) commented:

Also, are the JAX/Numba dispatches trivial enough that we could support them already?

@ricardoV94 (Member) commented on Jan 6, 2024

One thing I don't know is whether mixing the sparse and dense in a single Op is the best internal API. Maybe that's fine.

I guess it mostly depends on whether rewrites can ignore this info or whether they will have to always reason differently depending on it.

However, I think the helper function should follow scipy API, so there would be a tensor.linalg.block_diag that builds the dense case (or in this case can just use a pre-built Op) and a tensor.sparse.block_diag for the sparse cases.

@jessegrabowski (Member, Author) commented on Jan 6, 2024

I think the numba/jax dispatches should be easy; I'll work on them. The other thing would be to make it Blockwise.

With respect to the API, I totally agree about splitting sparse and dense into pytensor.tensor.slinalg.block_diag and pytensor.sparse.block_diag. I will also refactor the function to match the scipy signature, which takes the matrices as args instead of a single list.
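For illustration, a minimal sketch of the dense entry point after that refactor (import path taken from the discussion above; treat it as an assumption and check the merged code):

    import numpy as np
    from pytensor.tensor.slinalg import block_diag

    A = np.eye(2)
    B = np.ones((3, 3))

    # scipy-style signature: matrices are passed as separate arguments,
    # not as a single list.
    C = block_diag(A, B)
    print(C.eval().shape)  # (5, 5)

The sparse counterpart discussed above would live under pytensor.sparse and return a scipy.sparse-backed result.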

@ricardoV94 (Member) commented:

Can you open an issue in PyMC to deprecate the code there?

Commit: Closely follow scipy function signature for `block_diag`
@jessegrabowski (Member, Author) commented:
I'm going to need some handholding with turning it into Blockwise. I think the gfunc_sig should just be "(n,m)->(o,p)", but there are an arbitrary number of inputs all of the same size, so do I need to represent that somehow?

@ricardoV94 (Member) commented on Jan 6, 2024

Signatures can be defined dynamically when the Op is initialized; that's how we do it for solve.

For this Op you can do something like we do for safe_signature:

    signature = "(m0,n0),(m1,n1),...(mn,nn)->(m,n)"

You can add the number of inputs as a parameter of the Op, so it's known at initialization.
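A minimal sketch of that idea (a standalone helper for illustration; the names here are assumptions, not necessarily what the PR ends up using):

    # Hypothetical helper: build the gufunc signature from the number of
    # input matrices, which is known when the Op is initialized.
    def block_diag_signature(n_inputs: int) -> str:
        inputs = ",".join(f"(m{i},n{i})" for i in range(n_inputs))
        return f"{inputs}->(m,n)"

    print(block_diag_signature(3))  # (m0,n0),(m1,n1),(m2,n2)->(m,n)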

The inputs don't have to have the same shape, right, just the same ndim? E.g. foo([[0]], [[1,2,3],[4,5,6]])?

@jessegrabowski (Member, Author) commented on Jan 6, 2024

Yes, you're right about sizes.

Do I need to rewrite the jax/numba overloads to handle batch dims, or is that just unsupported so far?

@ricardoV94 (Member) commented on Jan 6, 2024

The JAX Blockwise will work via vmap; we only need to dispatch the base case, like you did.

That is exactly the point of Blockwise: since the batch dims always work the same way, we only need to specify the core case.

Numba doesn't have Blockwise support yet. We should be able to do something simple with guvectorize, although it has limitations, such as only working with single outputs.
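To make the vmap point concrete, here is a hedged sketch of the core-vs-batched relationship in plain JAX (not the actual dispatch code in this PR):

    import jax
    import jax.numpy as jnp
    from jax.scipy.linalg import block_diag  # core (non-batched) case

    # Blockwise-style batching is just vmap over a shared leading batch axis.
    batched = jax.vmap(block_diag, in_axes=(0, 0))

    A = jnp.ones((5, 2, 2))
    B = jnp.ones((5, 3, 3))
    print(batched(A, B).shape)  # (5, 5, 5)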

@jessegrabowski (Member, Author) commented:

I guess we can't blockwise Sparse Ops? Blockwise uses pt.as_tensor, which breaks if the inputs are sparse. Will need to add some logic to handle that case. This would be the first function it would work with, I think. Otherwise, though, I think this is done.

@ricardoV94 (Member) left a review:

Only nitpicks left, address what you want and close what you don't!

Review thread on this docstring excerpt:

    Parameters
    ----------
    A, B, C ... : tensors
        Input matrices to form the block diagonal matrix. Each matrix should have the same number of dimensions, and the

@ricardoV94 (Member): This is not correct. Blockwise accepts a different number of batch dims and also broadcasts when they have length 1.

@jessegrabowski (Member, Author) commented on Jan 6, 2024:

I got errors when I tried tensors with different batch dims, but I didn't try broadcasting to dimensions with size 1.

@ricardoV94 (Member):

Do you still see errors? Blockwise should introduce expand dims, so the only failure case would be broadcasting?

@jessegrabowski (Member, Author):

Broadcasting works; I added a test for it. It was failing when I tried different batch sizes, which I don't think makes sense anyway.

@ricardoV94 (Member):

What do you mean different batch sizes? Blockwise adds expand dims automatically to align the number of batch dims, so that shouldn't be possible?

@jessegrabowski (Member, Author) commented on Jan 7, 2024:

This errors:

    # Different batch sizes
    A = np.random.normal(size=(batch_size + 3, 2, 2)).astype(config.floatX)
    B = np.random.normal(size=(batch_size, 4, 4)).astype(config.floatX)
    result = block_diag(A, B).eval()

with:

E           ValueError: Incompatible Blockwise batch input shapes [(8, 2, 2), (5, 4, 4)]

But I think it's supposed to error. What does it even mean to batch those two together?

@ricardoV94 (Member) commented on Jan 7, 2024:

Yeah, that's invalid: batch shapes must be broadcastable, and 8 and 5 are not broadcastable.

I thought you were saying inputs with different numbers of dimensions were failing.
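For reference, a sketch of the batching rules described above (shapes chosen for illustration, assuming the merged block_diag is wrapped in Blockwise as discussed; not taken from the test suite):

    import numpy as np
    from pytensor.tensor.slinalg import block_diag

    B = np.random.normal(size=(5, 4, 4))   # one batch dim of length 5

    # Different numbers of batch dims: Blockwise left-pads with expand_dims.
    A = np.eye(2)                          # no batch dims
    print(block_diag(A, B).eval().shape)   # (5, 6, 6)

    # Length-1 batch dims broadcast against longer ones.
    A = np.random.normal(size=(1, 2, 2))
    print(block_diag(A, B).eval().shape)   # (5, 6, 6)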

@ricardoV94 (Member), replying to the sparse Blockwise question above:

We don't have a sparse type with ndim != 2, so that's only the tip of the iceberg.

@ricardoV94 changed the title from "Migrate generic pytensor functions from pymc.math to pytensor" to "Add linalg.block_diag and sparse equivalent" on Jan 7, 2024
Commits:
  • Remove `Matrix` from `BlockDiagonal` and `SparseBlockDiagonal` `Op` names
  • Correct errors in docstrings
  • Move input validation to a shared class method
@ricardoV94 (Member) left a review:

Looks pretty neat! Two not-so-nitpick comments.

Also, nice catch with the missing test for grad.

@jessegrabowski merged commit c4ae6e3 into pymc-devs:main on Jan 7, 2024
@jessegrabowski deleted the block-diag branch on January 7, 2024 at 03:39