Fix broadcasting via observed and dims #6063
Conversation
Force-pushed from 18f5c8b to 2460b1e
Started a discussion for removing dims on the fly: #6065
I seem to have broken some things according to the failing tests. Seems to be used by Minibatch, cc @ferrine
Codecov Report

@@            Coverage Diff             @@
##             main    #6063      +/-   ##
==========================================
+ Coverage   89.52%   89.53%   +0.01%
==========================================
  Files          72       72
  Lines       12944    12929      -15
==========================================
- Hits        11588    11576      -12
+ Misses       1356     1353       -3
Force-pushed from 9e348dc to f173cdd
Force-pushed from b907dc1 to ed9e0ff
Tests are passing
👍 I will review this tonight!
I'll wait for it!
Force-pushed from ed9e0ff to b48a049
This removes the visual dependency between observed data and the likelihood that arises from the flow of shape information.
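For context, a minimal sketch of how one might inspect the affected graph; the model itself is made up for illustration, while pm.MutableData and pm.model_to_graphviz are existing PyMC helpers:

import numpy as np
import pymc as pm

with pm.Model() as model:
    data = pm.MutableData("data", np.zeros(3))
    mu = pm.Normal("mu")
    pm.Normal("y", mu=mu, sigma=1.0, observed=data)

# Displaying this object in a notebook shows whether an edge from the data
# node to the likelihood node is drawn purely because of shape information.
graph = pm.model_to_graphviz(model)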
Force-pushed from b48a049 to f3fc406
y = pm.Normal("y", observed=[0, 0, 0], dims="ddata")
assert pmodel.RV_dims["y"] == ("ddata",)
assert y.eval().shape == (3,)
Don't you want to broadcast in this test too?
Suggested change (adding a broadcasting case):
y = pm.Normal("y", observed=[[0, 0, 0]], dims=("dobs", "ddata"))
assert pmodel.RV_dims["y"] == ("dobs", "ddata")
assert y.eval().shape == (1, 3)
Not sure, isn't this already covered by test_define_dims_on_the_fly_from_observed? If it is the same, I would still like to keep this one separate, in case we decide to drop support for dims on the fly.
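For readers following along, a self-contained sketch of what the snippet above exercises, assuming the dims-on-the-fly behavior under discussion; the model here is made up, and "ddata" is not pre-declared in coords but registered from the shape of observed:

import pymc as pm

with pm.Model() as pmodel:
    # mu and sigma default to 0 and 1; the dim name is taken from `dims`
    # and its length from the observed data.
    y = pm.Normal("y", observed=[0, 0, 0], dims="ddata")

assert pmodel.RV_dims["y"] == ("ddata",)
assert y.eval().shape == (3,)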
pymc/distributions/shape_utils.py (outdated)
# Auto-complete the dims tuple to the full length.
# We don't have a way to know the names of implied
# dimensions, so they will be `None`.
dims = (*dims[:-1], *[None] * ndim_implied)
sdims = cast(StrongDims, dims)
Signature violation: the previous dims: WeakDims parameter was already a sequence, but with dims: Dims the new signature allows for dims="bla". This is not accounted for in the implementation; the cast(StrongDims, dims) is not safe yet.
Removed the cast, now that the signature of shape_from_dims requires StrongDims.
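For readers, a minimal sketch of the concern and its resolution; these type aliases and the helper are illustrative, not PyMC's actual definitions. A bare string such as dims="bla" has to be normalized into a tuple before slicing, otherwise an expression like dims[:-1] would slice the string itself:

from typing import Optional, Sequence, Tuple, Union

WeakDims = Tuple[Optional[str], ...]        # assumed: a tuple, possibly containing None entries
Dims = Union[str, Sequence[Optional[str]]]  # assumed: additionally allows a single dim name

def normalize_dims(dims: Dims) -> WeakDims:
    # Wrap a lone dim name into a tuple so downstream code can safely slice it.
    if isinstance(dims, str):
        return (dims,)
    return tuple(dims)

assert normalize_dims("bla") == ("bla",)
assert normalize_dims(["group", None]) == ("group", None)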
@michaelosthege Do you see a difference between broadcasting by dims and by observed? All the same changes would have been needed if the tests had used dims instead of observed and then assumed you could resize a variable indirectly by resizing its arguments.
I see your point, but I think for most people the intuition is that shape, size and dims are inputs of the RV, whereas observed is something that the RV points to. That's probably why I don't like to broadcast automatically based on observed.
The graphviz is a bit strange because we use 2 nodes (RV and Value) to represent what is done by a single node for unobserved RVs. Anyway, I think:
I agree it's more cumbersome, but it does push towards being more explicit about shapes, which we are already doing anyway when we suggest users use dims. I agree with making it a
Force-pushed from f3fc406 to 095bd87
Force-pushed from 095bd87 to c9660b6
I can't review super-thoroughly right now, but I trust that @ricardoV94 addressed my concerns. I'm still not super happy about prioritizing observed.shape over the implied dimensionality, but we'll see how the users take it. I know enough about PyMC that my own models won't blow up because of it ¯\_(ツ)_/¯
Fixes #5993
This PR also removes support for Ellipsis in dims and shape. This feature was not widely advertised, so hopefully this won't cause too much pain.
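For reference, a sketch of the kind of usage that is going away, to my understanding of the Ellipsis feature; the old line is left commented out, and the names and coords are illustrative:

import numpy as np
import pymc as pm

mu = np.zeros(3)

with pm.Model(coords={"repeat": range(5)}) as model:
    # Previously possible: name only the leading dim and let the rest be implied.
    # x = pm.Normal("x", mu=mu, dims=("repeat", ...))
    # After this PR, the full shape (or dims) must be spelled out:
    x = pm.Normal("x", mu=mu, shape=(5, 3))

assert x.eval().shape == (5, 3)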
These tests illustrate the bug fix
Checklist
Major / Breaking Changes
- Broadcasting the likelihood to the shape of its parameters now has to be requested explicitly, e.g. pm.Normal("likelihood", mu=mu, sigma=sigma, observed=data, dims="data", shape=mu.shape)
- Removed support for Ellipsis (...) in shape and dims

Bugfixes / New features
- model_graph no longer shows a dependency between observed data and the likelihood that only reflects the flow of shape information
Docs / Maintenance
- _make_nice_attr_error now suggests pm.draw instead of .eval()
- _make_nice_attr_error added to SymbolicDistributions
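As a worked illustration of the breaking change above, a minimal sketch; the priors, coords and data are made up, while the explicit shape=mu.shape is the pattern from the example in the description:

import numpy as np
import pymc as pm

observed_data = np.zeros(3)

with pm.Model(coords={"data": range(3)}) as model:
    mu = pm.Normal("mu", dims="data")
    sigma = pm.HalfNormal("sigma")
    # The explicit shape ties the likelihood to mu's shape instead of relying
    # on observed or dims to resize it.
    pm.Normal("likelihood", mu=mu, sigma=sigma, observed=observed_data, dims="data", shape=mu.shape)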