Skip to content

Commit 2296350

Browse files
committed
Remove support for "dims-on-the-fly"
1 parent 91cbebd commit 2296350

File tree

5 files changed

+38
-107
lines changed

5 files changed

+38
-107
lines changed

pymc/distributions/distribution.py

Lines changed: 9 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from aesara import tensor as at
3030
from aesara.compile.builders import OpFromGraph
3131
from aesara.graph import node_rewriter
32-
from aesara.graph.basic import Node, Variable, clone_replace
32+
from aesara.graph.basic import Node, clone_replace
3333
from aesara.graph.rewriting.basic import in2out
3434
from aesara.graph.utils import MetaType
3535
from aesara.tensor.basic import as_tensor_variable
@@ -42,9 +42,6 @@
4242
from pymc.distributions.shape_utils import (
4343
Dims,
4444
Shape,
45-
StrongDims,
46-
StrongShape,
47-
change_dist_size,
4845
convert_dims,
4946
convert_shape,
5047
convert_size,
@@ -154,35 +151,6 @@ def fn(*args, **kwargs):
154151
return fn
155152

156153

157-
def _make_rv_and_resize_shape_from_dims(
158-
*,
159-
cls,
160-
dims: Optional[StrongDims],
161-
model,
162-
observed,
163-
args,
164-
**kwargs,
165-
) -> Tuple[Variable, StrongShape]:
166-
"""Creates the RV, possibly using dims or observed to determine a resize shape (if needed)."""
167-
resize_shape_from_dims = None
168-
size_or_shape = kwargs.get("size") or kwargs.get("shape")
169-
170-
# Preference is given to size or shape. If not specified, we rely on dims and
171-
# finally, observed, to determine the shape of the variable. Because dims can be
172-
# specified on the fly, we need a two-step process where we first create the RV
173-
# without dims information and then resize it.
174-
if not size_or_shape and observed is not None:
175-
kwargs["shape"] = tuple(observed.shape)
176-
177-
# Create the RV without dims information
178-
rv_out = cls.dist(*args, **kwargs)
179-
180-
if not size_or_shape and dims is not None:
181-
resize_shape_from_dims = shape_from_dims(dims, tuple(rv_out.shape), model)
182-
183-
return rv_out, resize_shape_from_dims
184-
185-
186154
class SymbolicRandomVariable(OpFromGraph):
187155
"""Symbolic Random Variable
188156
@@ -311,17 +279,15 @@ def __new__(
311279
if observed is not None:
312280
observed = convert_observed_data(observed)
313281

314-
# Create the RV, without taking `dims` into consideration
315-
rv_out, resize_shape_from_dims = _make_rv_and_resize_shape_from_dims(
316-
cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs
317-
)
282+
# Preference is given to size or shape. If not specified, we rely on dims and
283+
# finally, observed, to determine the shape of the variable.
284+
if not ("size" in kwargs or "shape" in kwargs):
285+
if dims is not None:
286+
kwargs["shape"] = shape_from_dims(dims, model)
287+
elif observed is not None:
288+
kwargs["shape"] = tuple(observed.shape)
318289

319-
# Resize variable based on `dims` information
320-
if resize_shape_from_dims:
321-
resize_size_from_dims = find_size(
322-
shape=resize_shape_from_dims, size=None, ndim_supp=rv_out.owner.op.ndim_supp
323-
)
324-
rv_out = change_dist_size(dist=rv_out, new_size=resize_size_from_dims, expand=False)
290+
rv_out = cls.dist(*args, **kwargs)
325291

326292
rv_out = model.register_rv(
327293
rv_out,

pymc/distributions/shape_utils.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -480,17 +480,13 @@ def convert_size(size: Size) -> Optional[StrongSize]:
480480
return size
481481

482482

483-
def shape_from_dims(
484-
dims: StrongDims, shape_implied: Sequence[TensorVariable], model
485-
) -> StrongShape:
483+
def shape_from_dims(dims: StrongDims, model) -> StrongShape:
486484
"""Determines shape from a `dims` tuple.
487485
488486
Parameters
489487
----------
490488
dims : array-like
491489
A vector of dimension names or None.
492-
shape_implied : tensor_like of int
493-
Shape of RV implied from its inputs alone.
494490
model : pm.Model
495491
The current model on stack.
496492
@@ -499,20 +495,15 @@ def shape_from_dims(
499495
dims : tuple of (str or None)
500496
Names or None for all RV dimensions.
501497
"""
502-
ndim_resize = len(dims) - len(shape_implied)
503498

504-
# Dims must be known already or be inferrable from implied dimensions of the RV
505-
unknowndim_resize_dims = set(dims[:ndim_resize]) - set(model.dim_lengths)
506-
if unknowndim_resize_dims:
499+
# Dims must be known already
500+
unknowndim_dims = set(dims) - set(model.dim_lengths)
501+
if unknowndim_dims:
507502
raise KeyError(
508-
f"Dimensions {unknowndim_resize_dims} are unknown to the model and cannot be used to specify a `size`."
503+
f"Dimensions {unknowndim_dims} are unknown to the model and cannot be used to specify a `shape`."
509504
)
510505

511-
# The numeric/symbolic resize tuple can be created using model.RV_dim_lengths
512-
return tuple(
513-
model.dim_lengths[dname] if dname in model.dim_lengths else shape_implied[i]
514-
for i, dname in enumerate(dims)
515-
)
506+
return tuple(model.dim_lengths[dname] for dname in dims)
516507

517508

518509
def find_size(

pymc/tests/distributions/test_shape_utils.py

Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -312,39 +312,16 @@ def test_simultaneous_dims_and_observed(self):
312312
assert pmodel.RV_dims["y"] == ("ddata",)
313313
assert y.eval().shape == (3,)
314314

315-
def test_define_dims_on_the_fly(self):
315+
def test_define_dims_on_the_fly_raises(self):
316+
# Check that trying to use dims that are not pre-specified fails, even if their
317+
# length could be inferred from the shape of the variables
318+
msg = "Dimensions {'patient'} are unknown to the model"
316319
with pm.Model() as pmodel:
317-
agedata = aesara.shared(np.array([10, 20, 30]))
320+
with pytest.raises(KeyError, match=msg):
321+
pm.Normal("x", [0, 1, 2], dims=("patient",))
318322

319-
# Associate the "patient" dim with an implied dimension
320-
age = pm.Normal("age", agedata, dims=("patient",))
321-
assert "patient" in pmodel.dim_lengths
322-
assert pmodel.dim_lengths["patient"].eval() == 3
323-
324-
# Use the dim to replicate a new RV
325-
effect = pm.Normal("effect", 0, dims=("patient",))
326-
assert effect.ndim == 1
327-
assert effect.eval().shape == (3,)
328-
329-
# Now change the length of the implied dimension
330-
agedata.set_value([1, 2, 3, 4])
331-
# The change should propagate all the way through
332-
assert effect.eval().shape == (4,)
333-
334-
def test_define_dims_on_the_fly_from_observed(self):
335-
with pm.Model() as pmodel:
336-
data = aesara.shared(np.zeros((4, 5)))
337-
x = pm.Normal("x", observed=data, dims=("patient", "trials"))
338-
assert pmodel.dim_lengths["patient"].eval() == 4
339-
assert pmodel.dim_lengths["trials"].eval() == 5
340-
341-
# Use dim to create a new RV
342-
x_noisy = pm.Normal("x_noisy", 0, dims=("patient", "trials"))
343-
assert x_noisy.eval().shape == (4, 5)
344-
345-
# Change data patient dims
346-
data.set_value(np.zeros((10, 6)))
347-
assert x_noisy.eval().shape == (10, 6)
323+
with pytest.raises(KeyError, match=msg):
324+
pm.Normal("x", observed=[0, 1, 2], dims=("patient",))
348325

349326
def test_can_resize_data_defined_size(self):
350327
with pm.Model() as pmodel:

pymc/tests/test_data.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import pymc as pm
2929

3030
from pymc.aesaraf import GeneratorOp, floatX
31-
from pymc.exceptions import ShapeError
3231
from pymc.tests.helpers import SeededTest, select_by_precision
3332

3433

@@ -371,20 +370,6 @@ def test_symbolic_coords(self):
371370
assert pmodel.dim_lengths["row"].eval() == 4
372371
assert pmodel.dim_lengths["column"].eval() == 5
373372

374-
def test_no_resize_of_implied_dimensions(self):
375-
with pm.Model() as pmodel:
376-
# Imply a dimension through RV params
377-
pm.Normal("n", mu=[1, 2, 3], dims="city")
378-
# _Use_ the dimension for a data variable
379-
inhabitants = pm.MutableData("inhabitants", [100, 200, 300], dims="city")
380-
381-
# Attempting to re-size the dimension through the data variable would
382-
# cause shape problems in InferenceData conversion, because the RV remains (3,).
383-
with pytest.raises(
384-
ShapeError, match="was initialized from 'n' which is not a shared variable"
385-
):
386-
pmodel.set_data("inhabitants", [1, 2, 3, 4])
387-
388373
def test_implicit_coords_series(self):
389374
pd = pytest.importorskip("pandas")
390375
ser_sales = pd.Series(

pymc/tests/test_model.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,7 @@ def test_shapeerror_from_resize_immutable_dim_from_RV():
742742
Even if the variable being updated is a SharedVariable and has other
743743
dimensions that are mutable.
744744
"""
745-
with pm.Model() as pmodel:
745+
with pm.Model(coords={"fixed": range(3)}) as pmodel:
746746
pm.Normal("a", mu=[1, 2, 3], dims="fixed")
747747
assert isinstance(pmodel.dim_lengths["fixed"], TensorVariable)
748748

@@ -751,7 +751,8 @@ def test_shapeerror_from_resize_immutable_dim_from_RV():
751751
# This is fine because the "fixed" dim is not resized
752752
pmodel.set_data("m", [[1, 2, 3], [3, 4, 5]])
753753

754-
with pytest.raises(ShapeError, match="was initialized from 'a'"):
754+
msg = "The 'm' variable already had 3 coord values defined for its fixed dimension"
755+
with pytest.raises(ValueError, match=msg):
755756
# Can't work because the "fixed" dimension is linked to a
756757
# TensorVariable with constant shape.
757758
# Note that the new data tries to change both dimensions
@@ -826,7 +827,7 @@ def test_set_dim():
826827

827828

828829
def test_set_dim_with_coords():
829-
"""Test the concious re-sizing of dims created through add_coord() with coord value."""
830+
"""Test the conscious re-sizing of dims created through add_coord() with coord value."""
830831
with pm.Model() as pmodel:
831832
pmodel.add_coord("mdim", mutable=True, length=2, values=["A", "B"])
832833
a = pm.Normal("a", dims="mdim")
@@ -904,6 +905,17 @@ def test_set_data_indirect_resize_with_coords():
904905
pmodel.set_data("mdata", [1, 2], coords=dict(mdim=[1, 2, 3]))
905906

906907

908+
def test_set_data_constant_shape_error():
909+
with pm.Model() as pmodel:
910+
x = pm.Normal("x", size=7)
911+
pmodel.add_coord("weekday", length=x.shape[0])
912+
pm.MutableData("y", np.arange(7), dims="weekday")
913+
914+
msg = "because the dimension was initialized from 'x' which is not a shared variable"
915+
with pytest.raises(ShapeError, match=msg):
916+
pmodel.set_data("y", np.arange(10))
917+
918+
907919
def test_model_logpt_deprecation_warning():
908920
with pm.Model() as m:
909921
x = pm.Normal("x", 0, 1, size=2)

0 commit comments

Comments
 (0)