Skip to content

Commit b5a5b56

Browse files
Remove size from Distribution signatures (#5788)
* Remove size from .dist() signature Closes #5754 * Align (Half)Flat signatures with superclass * Don't mention size in the docstring Co-authored-by: Ricardo Vieira <[email protected]> * Revert changes in `rng_fn` signatures Co-authored-by: Ricardo Vieira <[email protected]> Co-authored-by: Ricardo Vieira <[email protected]>
1 parent d0af6b1 commit b5a5b56

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

pymc/distributions/continuous.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -364,8 +364,8 @@ def __new__(cls, *args, **kwargs):
364364
return super().__new__(cls, *args, **kwargs)
365365

366366
@classmethod
367-
def dist(cls, *, size=None, **kwargs):
368-
res = super().dist([], size=size, **kwargs)
367+
def dist(cls, **kwargs):
368+
res = super().dist([], **kwargs)
369369
return res
370370

371371
def moment(rv, size):
@@ -432,8 +432,8 @@ def __new__(cls, *args, **kwargs):
432432
return super().__new__(cls, *args, **kwargs)
433433

434434
@classmethod
435-
def dist(cls, *, size=None, **kwargs):
436-
res = super().dist([], size=size, **kwargs)
435+
def dist(cls, **kwargs):
436+
res = super().dist([], **kwargs)
437437
return res
438438

439439
def moment(rv, size):

pymc/distributions/distribution.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,8 @@ def __new__(
226226
transform : optional
227227
See ``Model.register_rv``.
228228
**kwargs
229-
Keyword arguments that will be forwarded to ``.dist()``.
230-
Most prominently: ``shape`` and ``size``
229+
Keyword arguments that will be forwarded to ``.dist()`` or the Aesara RV Op.
230+
Most prominently: ``shape`` for ``.dist()`` or ``dtype`` for the Op.
231231
232232
Returns
233233
-------
@@ -298,7 +298,6 @@ def dist(
298298
dist_params,
299299
*,
300300
shape: Optional[Shape] = None,
301-
size: Optional[Size] = None,
302301
**kwargs,
303302
) -> RandomVariable:
304303
"""Creates a RandomVariable corresponding to the `cls` distribution.
@@ -312,8 +311,9 @@ def dist(
312311
313312
An Ellipsis (...) may be inserted in the last position to short-hand refer to
314313
all the dimensions that the RV would get if no shape/size/dims were passed at all.
315-
size : int, tuple, Variable, optional
316-
For creating the RV like in Aesara/NumPy.
314+
**kwargs
315+
Keyword arguments that will be forwarded to the Aesara RV Op.
316+
Most prominently: ``size`` or ``dtype``.
317317
318318
Returns
319319
-------
@@ -337,6 +337,7 @@ def dist(
337337

338338
if "dims" in kwargs:
339339
raise NotImplementedError("The use of a `.dist(dims=...)` API is not supported.")
340+
size = kwargs.pop("size", None)
340341
if shape is not None and size is not None:
341342
raise ValueError(
342343
f"Passing both `shape` ({shape}) and `size` ({size}) is not supported!"

0 commit comments

Comments
 (0)