Fix broadcasting via observed and dims #6063

Merged on Aug 29, 2022 (6 commits)

Conversation

ricardoV94
Member

@ricardoV94 ricardoV94 commented Aug 25, 2022

Fixes #5993

This PR also removes support for Ellipsis in dims and shape. This feature was not widely advertised, so hopefully this won't cause too much pain.

These tests illustrate the bug fix

    def test_broadcast_by_dims(self):
        with pm.Model(coords={"broadcast_dim": range(3)}) as m:
            x = pm.Normal("x", mu=np.zeros((1,)), dims=("broadcast_dim",))
            assert x.eval().shape == (3,)

    def test_broadcast_by_observed(self):
        with pm.Model() as m:
            x = pm.Normal("x", mu=np.zeros((1,)), observed=np.zeros((3,)))
            assert x.eval().shape == (3,)
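
The shape both tests expect follows from ordinary NumPy-style broadcasting between the parameter shape `(1,)` and the length-3 shape given by dims or observed. A minimal pure-Python sketch of that rule is below; `broadcast_shape` is a hypothetical helper for illustration, not part of PyMC.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Return the NumPy-style broadcast of several static shapes (illustrative only)."""
    result = []
    # Align shapes on their trailing axes; missing leading axes count as length 1.
    for lengths in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_unit = {n for n in lengths if n != 1}
        if len(non_unit) > 1:
            raise ValueError(f"shapes {shapes} cannot be broadcast together")
        result.append(non_unit.pop() if non_unit else 1)
    return tuple(reversed(result))


# Mirrors the tests above: mu has shape (1,), dims/observed imply length 3.
print(broadcast_shape((1,), (3,)))  # (3,)
```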

Checklist

Major / Breaking Changes

  • When shape is not provided, dims or observed are used to define the shape of a variable. If you want a variable that has dims or is observed to automatically resize when its inputs change, you must specify that explicitly via the shape argument. For example, pm.Normal("likelihood", mu=mu, sigma=sigma, observed=data, dims="data", shape=mu.shape)
  • Remove support of Ellipsis (...) in shape and dims
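
As a rough mental model of the first breaking change, the order in which the shape is picked can be sketched in plain Python. This assumes the priority described in this conversation (shape first, then dims, then observed); `resolve_shape` and its parameters are hypothetical names, not PyMC's actual implementation.

```python
def resolve_shape(shape=None, dims=None, observed_shape=None, coords=None):
    """Hypothetical helper: decide which input defines a variable's shape.

    Assumed priority: explicit shape, then dims, then observed.
    Returns None when the shape is left to be implied by the parameters.
    """
    coords = coords or {}
    if shape is not None:
        return tuple(shape)
    if dims is not None:
        if isinstance(dims, str):
            dims = (dims,)  # a single dim name is treated as a 1-tuple
        return tuple(len(coords[d]) for d in dims)
    if observed_shape is not None:
        return tuple(observed_shape)
    return None


print(resolve_shape(dims="broadcast_dim", coords={"broadcast_dim": range(3)}))  # (3,)
print(resolve_shape(observed_shape=(3,)))  # (3,)
```

Under this model, passing `shape=mu.shape` explicitly is what ties the variable's shape back to its inputs, because the explicit shape is consulted before dims or observed.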

Bugfixes / New features

  • Fix bug where distribution shape would not be broadcasted by dims or observed
  • Allow specifying dims on the fly from observed
  • Do not show shape-related dependencies of RandomVariables in model_graph

Docs / Maintenance

  • Update _make_nice_attr_error to suggest pm.draw instead of .eval()
  • Add _make_nice_attr_error to SymbolicDistributions

@ricardoV94 ricardoV94 changed the title Fix observed dims resize Fix broadcasting via observed and dims Aug 25, 2022
@ricardoV94 ricardoV94 force-pushed the fix_observed_dims_resize branch from 18f5c8b to 2460b1e Compare August 25, 2022 14:47
@ricardoV94
Member Author

ricardoV94 commented Aug 25, 2022

Started a discussion for removing dims on the fly: #6065

@ricardoV94
Member Author

ricardoV94 commented Aug 25, 2022

I seem to have broken some things, according to the failing test_density_scaling_with_generator. It seems that we support generators as observed data, which means we can only query the data once (and then we must pass it around). Is this something we really need?

Seems to be used by Minibatch, cc @ferrine
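
The constraint with generators can be seen in plain Python: once an item has been pulled to inspect it, the generator no longer yields it. The sketch below uses a hypothetical `batches` generator as a stand-in for generator-based observed data; it does not use PyMC's own generator machinery.

```python
def batches():
    """Hypothetical stand-in for a minibatch generator passed as `observed`."""
    for start in range(0, 6, 3):
        yield list(range(start, start + 3))


data = batches()
first = next(data)      # peek at the first batch, e.g. to learn its length
print(len(first))       # 3
remaining = list(data)  # only what is left: the peeked batch is gone
print(len(remaining))   # 1
```

Any code that peeks at the data therefore has to keep the peeked value and pass it along, instead of asking the generator again.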

@codecov

codecov bot commented Aug 25, 2022

Codecov Report

Merging #6063 (c9660b6) into main (fa5e441) will increase coverage by 0.01%.
The diff coverage is 96.66%.


@@            Coverage Diff             @@
##             main    #6063      +/-   ##
==========================================
+ Coverage   89.52%   89.53%   +0.01%     
==========================================
  Files          72       72              
  Lines       12944    12929      -15     
==========================================
- Hits        11588    11576      -12     
+ Misses       1356     1353       -3     

| Impacted Files | Coverage Δ |
|---|---|
| pymc/distributions/distribution.py | 91.02% <94.73%> (+0.18%) ⬆️ |
| pymc/distributions/shape_utils.py | 99.27% <100.00%> (+2.19%) ⬆️ |
| pymc/model.py | 88.23% <100.00%> (+0.03%) ⬆️ |
| pymc/model_graph.py | 78.94% <100.00%> (+0.63%) ⬆️ |
| pymc/gp/gp.py | 92.73% <0.00%> (-0.45%) ⬇️ |
| pymc/distributions/multivariate.py | 92.03% <0.00%> (+0.02%) ⬆️ |
| pymc/step_methods/hmc/base_hmc.py | 90.55% <0.00%> (+0.78%) ⬆️ |

@ricardoV94 ricardoV94 force-pushed the fix_observed_dims_resize branch 6 times, most recently from 9e348dc to f173cdd Compare August 25, 2022 17:52
@ricardoV94 ricardoV94 force-pushed the fix_observed_dims_resize branch 2 times, most recently from b907dc1 to ed9e0ff Compare August 26, 2022 06:28
@ricardoV94
Member Author

Tests are passing

@michaelosthege
Member

> Tests are passing

👍 I will review this tonight!

@ricardoV94
Member Author

> > Tests are passing
>
> +1 I will review this tonight!

I'll wait for it!

@ricardoV94 ricardoV94 mentioned this pull request Aug 26, 2022
5 tasks
@ricardoV94 ricardoV94 force-pushed the fix_observed_dims_resize branch from ed9e0ff to b48a049 Compare August 26, 2022 18:27
This removes visual dependencies between observed data and likelihood, due to flow of shape information
@ricardoV94 ricardoV94 force-pushed the fix_observed_dims_resize branch from b48a049 to f3fc406 Compare August 26, 2022 18:36
Comment on lines +300 to +302
y = pm.Normal("y", observed=[0, 0, 0], dims="ddata")
assert pmodel.RV_dims["y"] == ("ddata",)
assert y.eval().shape == (3,)
Member

Don't you want to broadcast in this test too?

Suggested change
- y = pm.Normal("y", observed=[0, 0, 0], dims="ddata")
- assert pmodel.RV_dims["y"] == ("ddata",)
- assert y.eval().shape == (3,)
+ y = pm.Normal("y", observed=[[0, 0, 0]], dims=("dobs", "ddata"))
+ assert pmodel.RV_dims["y"] == ("dobs", "ddata")
+ assert y.eval().shape == (1, 3)

Member Author

@ricardoV94 ricardoV94 Aug 29, 2022

Not sure, is this not already covered by test_define_dims_on_the_fly_from_observed? If it is the same I would like to keep it separate, in case we decide to drop support for dims on the fly.

# Auto-complete the dims tuple to the full length.
# We don't have a way to know the names of implied
# dimensions, so they will be `None`.
dims = (*dims[:-1], *[None] * ndim_implied)
sdims = cast(StrongDims, dims)
Member

Signature violation: The previous dims: WeakDims parameter was already a sequence, but with dims: Dims the new signature allows for dims="bla".
This is not accounted for in the implementation; the cast(StrongDims, dims) is not safe yet.
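
To illustrate the concern, here is a small sketch of what the auto-completion line does when it receives a plain string. `typing.cast` performs no runtime check, so the string is silently split into characters. `complete_dims` and the `StrongDims` alias below are hypothetical stand-ins written for this example, and the tuple input assumes the last entry is a trailing Ellipsis placeholder.

```python
from typing import Optional, Sequence, cast

StrongDims = Sequence[Optional[str]]  # assumed alias, for illustration only


def complete_dims(dims, ndim_implied):
    """Mirror the auto-completion line from the diff under review."""
    completed = (*dims[:-1], *[None] * ndim_implied)
    return cast(StrongDims, completed)  # cast only informs the type checker


print(complete_dims(("a", ...), 2))  # ('a', None, None)
print(complete_dims("bla", 2))       # ('b', 'l', None, None)
```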

Member Author

Removed the cast, now that the signature of shape_from_dims requires StrongDims

@ricardoV94
Member Author

@michaelosthege Do you see a difference between broadcasting by dims and by observed? Because all the same changes would have been needed if the tests had used dims instead of observed and then assumed you could just resize a variable indirectly by resizing its arguments.

@michaelosthege
Member

> @michaelosthege Do you see a difference between broadcasting by dims and by observed? Because all the same changes would have been needed if the tests had used dims instead of observed and then assumed you could just resize a variable indirectly by resizing its arguments.

I see your point, but I think for most people the intuition is that shape, size and dims are inputs of the RV, whereas the observed is something that the RV points to.
That's also how we draw it in pm.model_to_graphviz and it resembles the direction of the generative process.

That's probably why I don't like to broadcast automatically based on observed. (I found that strange and complicated when I worked on these implementations in the past..)

@ricardoV94
Member Author

ricardoV94 commented Aug 27, 2022

The graphviz is a bit strange because we use 2 nodes (RV and Value) to represent what is done by a single one for unobserved RVs.

Anyway, I think:

  1. The output of test_broadcast_by_observed would be very surprising before this PR. It would also be very surprising if we ignored extra dimensions implied by observed (which we don't).
  2. We should make it clear (in the dimensionality notebook?) that in the generative graph, observed is used as a shortcut for shape=observed.shape unless dims or shape is provided. I hope this provides a good mental picture so users don't expect indirect resizing by default.
  3. We should make it clear that the shape argument is more powerful than before. It can take arbitrary symbolic expressions. Hopefully it will become intuitive to specify shape=mu.shape or whatever as we start adding it to official examples.

I agree it's more cumbersome, but it does push towards being more explicit about shapes, which we are already doing anyway when we suggest users use dims.

I agree with making it a bigger (minor? what's the name for the intermediate one?) release.

@ricardoV94 ricardoV94 marked this pull request as draft August 28, 2022 11:34
@michaelosthege michaelosthege added the bug, shape problem, and major (Include in major changes release notes section) labels Aug 28, 2022
@ricardoV94 ricardoV94 force-pushed the fix_observed_dims_resize branch from f3fc406 to 095bd87 Compare August 29, 2022 08:36
@ricardoV94 ricardoV94 marked this pull request as ready for review August 29, 2022 08:37
@ricardoV94 ricardoV94 force-pushed the fix_observed_dims_resize branch from 095bd87 to c9660b6 Compare August 29, 2022 08:54
Member

@michaelosthege michaelosthege left a comment

I can't review super-thoroughly right now, but I trust that @ricardoV94 addressed my concerns.

I'm still not super happy about prioritizing observed.shape over the implied dimensionality, but we'll see how the users take it.
I know enough about PyMC so my own models don't blow up because of it ¯_(ツ)_/¯

@ricardoV94 ricardoV94 merged commit 92ce135 into pymc-devs:main Aug 29, 2022
@ricardoV94 ricardoV94 deleted the fix_observed_dims_resize branch June 6, 2023 03:00
Labels: bug, major (Include in major changes release notes section), shape problem
Development

Successfully merging this pull request may close these issues.

Implied dimensions of RVs are not broadcasted by dims or observed
4 participants