-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
pass names from RV dims to change_rv_size #5931
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -145,6 +145,7 @@ def convert_observed_data(data): | |
def change_rv_size( | ||
rv: TensorVariable, | ||
new_size: PotentialShapeType, | ||
new_size_dims: Optional[tuple] = (None,), | ||
expand: Optional[bool] = False, | ||
) -> TensorVariable: | ||
"""Change or expand the size of a `RandomVariable`. | ||
|
@@ -155,6 +156,8 @@ def change_rv_size( | |
The old `RandomVariable` output. | ||
new_size | ||
The new size. | ||
new_size_dims | ||
dim names of the new size vector | ||
expand: | ||
Expand the existing size by `new_size`. | ||
|
||
|
@@ -166,6 +169,10 @@ def change_rv_size( | |
elif new_size_ndim == 0: | ||
new_size = (new_size,) | ||
|
||
# wrap None in tuple, if new_size_dims are None | ||
if new_size_dims is None: | ||
new_size_dims = (None,) | ||
|
||
# Extract the RV node that is to be resized, together with its inputs, name and tag | ||
assert rv.owner.op is not None | ||
if isinstance(rv.owner.op, SpecifyShape): | ||
|
@@ -180,9 +187,13 @@ def change_rv_size( | |
size = shape[: len(shape) - rv_node.op.ndim_supp] | ||
new_size = tuple(new_size) + tuple(size) | ||
|
||
# create the name of the RV's resizing tensor | ||
# TODO: add information where the dim is coming from (obseverd, prior, ...) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this todo going to be implemented in this PR or is this for later? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, I think it would be helpful. Just wanted to make sure I'm on the right track before I spend more time on it |
||
new_size_name = f"Broadcast to {new_size_dims[0]}_dim" | ||
|
||
# Make sure the new size is a tensor. This dtype-aware conversion helps | ||
# to not unnecessarily pick up a `Cast` in some cases (see #4652). | ||
new_size = at.as_tensor(new_size, ndim=1, dtype="int64") | ||
new_size = at.as_tensor(new_size, ndim=1, dtype="int64", name=new_size_name) | ||
|
||
new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params) | ||
new_rv = new_rv_node.outputs[-1] | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The name info should not be handled by this function. I think somewhere in the model we create constants or shared variables for dim sizes and there is where we should set the names.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you give me a hint roughly where this should be? Before or after the call to change_rv_size?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could be done here where the dims variables are created:
pymc/pymc/model.py
Lines 1146 to 1149 in be048a4
Both constants and shared variables accept names during construction.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ricardoV94, I looked into it and when I set the names in the suggested lines, I receive only the following output from dprint:
only after passing the name to new_size in aesaraf.py::change_rv_size, I get the change
pymc/pymc/aesaraf.py
Line 185 in be048a4
If this is okay, I'll implement the changes in this PR
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR added a name for the mutable case, but not the constant one: f45ca4a, I think adding it to the constant is all that's needed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I will try to give it another go in this week
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ricardoV94, I'm dancing around this PR for quite a bit now, but I'm not finding the time to do it. In the beginning of next year I would have more time, but if it bothers you to have an open PR for such a long time or someone else takes over, I won't mind if I don't get to finish it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No worries. The original issue lost some of it's relevance when we switched from defaulting resizable dims to constant dims. It's still nice but not urgent by any means. Feel free to come back to it whenever you find the time.