Skip to content

Commit 6585f90

Browse files
committed
pass names from RV dims to change_rv_size
- adds argument new_size_dims to change_rv_size in aesaraf.py - tests if dims is None and wraps in tuple if RV has no dim - creates new_size_name by adding dim name and decorations - passes new_size_name as name argument to at.as_tensor method
1 parent be048a4 commit 6585f90

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

pymc/aesaraf.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def convert_observed_data(data):
145145
def change_rv_size(
146146
rv: TensorVariable,
147147
new_size: PotentialShapeType,
148+
new_size_dims: Optional[tuple] = (None,),
148149
expand: Optional[bool] = False,
149150
) -> TensorVariable:
150151
"""Change or expand the size of a `RandomVariable`.
@@ -155,6 +156,8 @@ def change_rv_size(
155156
The old `RandomVariable` output.
156157
new_size
157158
The new size.
159+
new_size_dims
160+
dim names of the new size vector
158161
expand:
159162
Expand the existing size by `new_size`.
160163
@@ -166,6 +169,10 @@ def change_rv_size(
166169
elif new_size_ndim == 0:
167170
new_size = (new_size,)
168171

172+
# wrap None in tuple, if new_size_dims are None
173+
if new_size_dims is None:
174+
new_size_dims = (None,)
175+
169176
# Extract the RV node that is to be resized, together with its inputs, name and tag
170177
assert rv.owner.op is not None
171178
if isinstance(rv.owner.op, SpecifyShape):
@@ -180,9 +187,13 @@ def change_rv_size(
180187
size = shape[: len(shape) - rv_node.op.ndim_supp]
181188
new_size = tuple(new_size) + tuple(size)
182189

190+
# create the name of the RV's resizing tensor
191+
# TODO: add information where the dim is coming from (obseverd, prior, ...)
192+
new_size_name = f"Broadcast to {new_size_dims[0]}_dim"
193+
183194
# Make sure the new size is a tensor. This dtype-aware conversion helps
184195
# to not unnecessarily pick up a `Cast` in some cases (see #4652).
185-
new_size = at.as_tensor(new_size, ndim=1, dtype="int64")
196+
new_size = at.as_tensor(new_size, ndim=1, dtype="int64", name=new_size_name)
186197

187198
new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params)
188199
new_rv = new_rv_node.outputs[-1]

pymc/distributions/distribution.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,9 @@ def __new__(
266266

267267
if resize_shape:
268268
# A batch size was specified through `dims`, or implied by `observed`.
269-
rv_out = change_rv_size(rv=rv_out, new_size=resize_shape, expand=True)
269+
rv_out = change_rv_size(
270+
rv=rv_out, new_size=resize_shape, new_size_dims=dims, expand=True
271+
)
270272

271273
rv_out = model.register_rv(
272274
rv_out,

0 commit comments

Comments
 (0)