Small fixups to xtensor type and XRV #1503
Conversation
Mostly looks good to me (though the 4x nested max thing sort of has my head spinning).
There seems to be a bug/oversight with double assignment of `broadcastable`.
```diff
@@ -392,6 +392,13 @@ def make_node(self, rng, size, *dist_params):
         out_type = TensorType(dtype=self.dtype, shape=static_shape)
         outputs = (rng.type(), out_type())

+        if self.dtype == "floatX":
+            # Commit to a specific float type if the Op is still using "floatX"
```
Is `dtype="floatX"` being deprecated? (I'm trying to guess what "still" means here.)
When you create a RandomVariable Op you can specify `dtype="floatX"` at the Op level. But when we make an actual node we need to commit to one dtype, since floatX is not a real dtype. If you call `__call__` we already commit to a dtype, and this is where users can specify a custom one. But if you call `make_node` directly, like XRV does, it doesn't go through this step. It's a quirk of how we are wrapping RV Ops in xtensor, but in theory if you have an Op you should always be able to call `make_node` and get a valid graph.
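A minimal sketch of that commitment step, assuming the standard `pytensor.config.floatX` setting; `resolve_dtype` is a hypothetical helper, and the exact body of the branch in the PR may differ:

```python
import pytensor

def resolve_dtype(dtype: str) -> str:
    # Hypothetical helper: "floatX" is a config-level placeholder, not a
    # real numpy dtype, so commit to the configured precision before
    # building the output TensorType.
    if dtype == "floatX":
        return pytensor.config.floatX  # e.g. "float64" or "float32"
    return dtype

print(resolve_dtype("floatX"))  # the configured float dtype
print(resolve_dtype("int64"))   # concrete dtypes pass through unchanged
```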
```diff
@@ -41,7 +42,14 @@ def _as_xrv(
     core_out_dims_map = tuple(range(core_op.ndim_supp))

     core_dims_needed = max(
-        (*(len(i) for i in core_inps_dims_map), len(core_out_dims_map)), default=0
+        max(
```
Just to check my understanding:
This returns how many core dims the "broadcasting" between the inputs and outputs will have? For each input "map", it's returning the largest core dim index, then the largest core dim among all inputs, then the largest between the inputs and the outputs.
Not quite. The mapping tells, if the user passes a list of n core dims (say 2 for the MvNormal, the covariance dims), which of these correspond to each input/output, positionally.
From this it is trivial to infer how many core dims the user has to pass, so we can give a useful error message automatically. With zero-based indices, you need to pass a sequence as long as the largest index + 1. The problem is that there is a difference between 0 and empty in this case, which we weren't handling correctly before.
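A standalone sketch of that counting logic, assuming the dims-map structures described above (`core_dims_needed_from_maps` is a hypothetical helper; the nested maxes mirror the diff, but the exact PR code may differ):

```python
def core_dims_needed_from_maps(core_inps_dims_map, core_out_dims_map):
    # Each map holds zero-based indices into the user-supplied core dims,
    # so the user must pass at least (largest index + 1) of them.
    # max(..., default=0) keeps "no core dims referenced" (empty map,
    # needs 0) distinct from "references index 0" (needs 1).
    return max(
        max(
            (max((i + 1 for i in m), default=0) for m in core_inps_dims_map),
            default=0,
        ),
        max((i + 1 for i in core_out_dims_map), default=0),
    )

# MvNormal-like signature: mean uses core dim 0, cov uses dims 0 and 1,
# and the draw uses dim 0 -> the user must pass 2 core dims.
assert core_dims_needed_from_maps([(0,), (0, 1)], (0,)) == 2
# Nothing references any core dim -> 0 needed (not confused with index 0).
assert core_dims_needed_from_maps([(), ()], ()) == 0
```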
Also, float32 tests are failing, but I guess you know that :)
* Fix core_dims_needed calculation
* Handle lazy dtype
* Nicer `__str__` with use of `name`
The failing tests are due to the new scipy, I think.
Codecov Report
Attention: Patch coverage is …
Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1503      +/-   ##
==========================================
- Coverage   81.99%   81.98%   -0.01%
==========================================
  Files         231      231
  Lines       52253    52274      +21
  Branches     9203     9206       +3
==========================================
+ Hits        42843    42856      +13
- Misses       7099     7106       +7
- Partials     2311     2312      +1
```
Showed up when trying to integrate with PyMC in pymc-devs/pymc#7820
📚 Documentation preview 📚: https://pytensor--1503.org.readthedocs.build/en/1503/