pass names from RV dims to change_rv_size #5931

Closed
13 changes: 12 additions & 1 deletion pymc/aesaraf.py
@@ -145,6 +145,7 @@ def convert_observed_data(data):
def change_rv_size(
rv: TensorVariable,
new_size: PotentialShapeType,
new_size_dims: Optional[tuple] = (None,),
ricardoV94 (Member), Jun 26, 2022:

The name info should not be handled by this function. I think somewhere in the model we create constants or shared variables for dim sizes, and that is where we should set the names.

flo-schu (Author) replied:

Could you give me a hint roughly where this should be? Before or after the call to change_rv_size?

ricardoV94 (Member), Jun 27, 2022:

Could be done here, where the dims variables are created (pymc/model.py, lines 1146 to 1149 in be048a4):

    if mutable:
        length = aesara.shared(length)
    else:
        length = aesara.tensor.constant(length)

Both constants and shared variables accept names during construction.

flo-schu (Author), Jul 4, 2022:

@ricardoV94, I looked into it and when I set the names in the suggested lines, I receive only the following output from dprint:

normal_rv{0, (0, 0), floatX, False}.1 [id A] 'x'
 |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F99BD4B6960>) [id B]
 |TensorConstant{(1,) of 5} [id C]
 |TensorConstant{11} [id D]
 |TensorConstant{0} [id E]
 |TensorConstant{1.0} [id F]

Only after passing the name to new_size in aesaraf.py::change_rv_size do I get the change:

    # before
    new_size = at.as_tensor(new_size, ndim=1, dtype="int64")
    # after
    new_size = at.as_tensor(new_size, ndim=1, dtype="int64", name=new_size[0].name)
normal_rv{0, (0, 0), floatX, False}.1 [id A] 'x'
 |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FAD2928A960>) [id B]
 |cities_dim{(1,) of 5} [id C]
 |TensorConstant{11} [id D]
 |TensorConstant{0} [id E]
 |TensorConstant{1.0} [id F]

If this is okay, I'll implement the changes in this PR

ricardoV94 (Member), Oct 27, 2022:

This PR added a name for the mutable case, but not the constant one: f45ca4a, I think adding it to the constant is all that's needed.

flo-schu (Author) replied:

Thanks, I will try to give it another go this week.

flo-schu (Author):

@ricardoV94, I've been dancing around this PR for quite a while now, but I'm not finding the time to finish it. At the beginning of next year I would have more time, but if it bothers you to have a PR open for this long, or if someone else wants to take over, I won't mind not finishing it myself.

ricardoV94 (Member) replied:

No worries. The original issue lost some of its relevance when we switched from defaulting to resizable dims to constant dims. It's still nice to have, but not urgent by any means. Feel free to come back to it whenever you find the time.

expand: Optional[bool] = False,
) -> TensorVariable:
"""Change or expand the size of a `RandomVariable`.
@@ -155,6 +156,8 @@ def change_rv_size(
The old `RandomVariable` output.
new_size
The new size.
new_size_dims
Optional dimension names for the new size vector.
expand:
Expand the existing size by `new_size`.

@@ -166,6 +169,10 @@
elif new_size_ndim == 0:
new_size = (new_size,)

# Wrap None in a tuple if new_size_dims is None
if new_size_dims is None:
new_size_dims = (None,)

# Extract the RV node that is to be resized, together with its inputs, name and tag
assert rv.owner.op is not None
if isinstance(rv.owner.op, SpecifyShape):
@@ -180,9 +187,13 @@
size = shape[: len(shape) - rv_node.op.ndim_supp]
new_size = tuple(new_size) + tuple(size)

# Create the name of the RV's resizing tensor
# TODO: add information about where the dim is coming from (observed, prior, ...)
Member:

Is this todo going to be implemented in this PR or is this for later?

flo-schu (Author) replied:

Yes, I think it would be helpful. I just wanted to make sure I'm on the right track before I spend more time on it.

new_size_name = f"Broadcast to {new_size_dims[0]}_dim"

# Make sure the new size is a tensor. This dtype-aware conversion helps
# to not unnecessarily pick up a `Cast` in some cases (see #4652).
new_size = at.as_tensor(new_size, ndim=1, dtype="int64")
new_size = at.as_tensor(new_size, ndim=1, dtype="int64", name=new_size_name)

new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params)
new_rv = new_rv_node.outputs[-1]
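As a quick illustration of the naming scheme in this hunk (the "cities" entry in new_size_dims is hypothetical):

```python
# The first entry of `new_size_dims` is interpolated into the name
# given to the size tensor that change_rv_size constructs.
new_size_dims = ("cities",)
new_size_name = f"Broadcast to {new_size_dims[0]}_dim"
print(new_size_name)  # -> Broadcast to cities_dim
```

Note that when `new_size_dims` defaults to `(None,)`, this scheme produces the literal name "Broadcast to None_dim", which may or may not be the intended fallback.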
4 changes: 3 additions & 1 deletion pymc/distributions/distribution.py
@@ -266,7 +266,9 @@ def __new__(

if resize_shape:
# A batch size was specified through `dims`, or implied by `observed`.
rv_out = change_rv_size(rv=rv_out, new_size=resize_shape, expand=True)
rv_out = change_rv_size(
rv=rv_out, new_size=resize_shape, new_size_dims=dims, expand=True
)

rv_out = model.register_rv(
rv_out,