Skip to content

Commit c7279b5

Browse files
authored
Remove shape functions (#6556)
1 parent 0334994 commit c7279b5

File tree

4 files changed

+12
-364
lines changed

4 files changed

+12
-364
lines changed

docs/source/api/shape_utils.rst

-4
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,6 @@ This module introduces functions that are made aware of the requested `size_tupl
1414
:toctree: generated/
1515

1616
to_tuple
17-
shapes_broadcasting
1817
broadcast_dist_samples_shape
19-
get_broadcastable_dist_samples
20-
broadcast_distribution_samples
21-
broadcast_dist_samples_to
2218
rv_size_is_none
2319
change_dist_size

pymc/distributions/multivariate.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
)
6060
from pymc.distributions.shape_utils import (
6161
_change_dist_size,
62-
broadcast_dist_samples_to,
62+
broadcast_dist_samples_shape,
6363
change_dist_size,
6464
get_support_shape,
6565
rv_size_is_none,
@@ -1651,7 +1651,9 @@ def rng_fn(cls, rng, mu, rowchol, colchol, size=None):
16511651
output_shape = size + dist_shape
16521652

16531653
# Broadcasting all parameters
1654-
(mu,) = broadcast_dist_samples_to(to_shape=output_shape, samples=[mu], size=size)
1654+
shapes = [mu.shape, output_shape]
1655+
broadcastable_shape = broadcast_dist_samples_shape(shapes, size=size)
1656+
mu = np.broadcast_to(mu, shape=broadcastable_shape)
16551657
rowchol = np.broadcast_to(rowchol, shape=size + rowchol.shape[-2:])
16561658

16571659
colchol = np.broadcast_to(colchol, shape=size + colchol.shape[-2:])

pymc/distributions/shape_utils.py

+4-259
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,8 @@
3838
from pymc.pytensorf import convert_observed_data
3939

4040
__all__ = [
41-
"to_tuple",
42-
"shapes_broadcasting",
4341
"broadcast_dist_samples_shape",
44-
"get_broadcastable_dist_samples",
45-
"broadcast_distribution_samples",
46-
"broadcast_dist_samples_to",
42+
"to_tuple",
4743
"rv_size_is_none",
4844
"change_dist_size",
4945
]
@@ -91,47 +87,6 @@ def _check_shape_type(shape):
9187
return tuple(out)
9288

9389

94-
def shapes_broadcasting(*args, raise_exception=False):
95-
"""Return the shape resulting from broadcasting multiple shapes.
96-
Represents numpy's broadcasting rules.
97-
98-
Parameters
99-
----------
100-
*args: array-like of int
101-
Tuples or arrays or lists representing the shapes of arrays to be
102-
broadcast.
103-
raise_exception: bool (optional)
104-
Controls whether to raise an exception or simply return `None` if
105-
the broadcasting fails.
106-
107-
Returns
108-
-------
109-
Resulting shape. If broadcasting is not possible and `raise_exception` is
110-
False, then `None` is returned. If `raise_exception` is `True`, a
111-
`ValueError` is raised.
112-
"""
113-
x = list(_check_shape_type(args[0])) if args else ()
114-
for arg in args[1:]:
115-
y = list(_check_shape_type(arg))
116-
if len(x) < len(y):
117-
x, y = y, x
118-
if len(y) > 0:
119-
x[-len(y) :] = [
120-
j if i == 1 else i if j == 1 else i if i == j else 0
121-
for i, j in zip(x[-len(y) :], y)
122-
]
123-
if not all(x):
124-
if raise_exception:
125-
raise ValueError(
126-
"Supplied shapes {} do not broadcast together".format(
127-
", ".join([f"{a}" for a in args])
128-
)
129-
)
130-
else:
131-
return None
132-
return tuple(x)
133-
134-
13590
def broadcast_dist_samples_shape(shapes, size=None):
13691
"""Apply shape broadcasting to shape tuples but assuming that the shapes
13792
correspond to draws from random variables, with the `size` tuple possibly
@@ -152,27 +107,22 @@ def broadcast_dist_samples_shape(shapes, size=None):
152107
Examples
153108
--------
154109
.. code-block:: python
155-
156110
size = 100
157111
shape0 = (size,)
158112
shape1 = (size, 5)
159113
shape2 = (size, 4, 5)
160114
out = broadcast_dist_samples_shape([shape0, shape1, shape2],
161115
size=size)
162116
assert out == (size, 4, 5)
163-
164117
.. code-block:: python
165-
166118
size = 100
167119
shape0 = (size,)
168120
shape1 = (5,)
169121
shape2 = (4, 5)
170122
out = broadcast_dist_samples_shape([shape0, shape1, shape2],
171123
size=size)
172124
assert out == (size, 4, 5)
173-
174125
.. code-block:: python
175-
176126
size = 100
177127
shape0 = (1,)
178128
shape1 = (5,)
@@ -182,7 +132,7 @@ def broadcast_dist_samples_shape(shapes, size=None):
182132
assert out == (4, 5)
183133
"""
184134
if size is None:
185-
broadcasted_shape = shapes_broadcasting(*shapes)
135+
broadcasted_shape = np.broadcast_shapes(*shapes)
186136
if broadcasted_shape is None:
187137
raise ValueError(
188138
"Cannot broadcast provided shapes {} given size: {}".format(
@@ -195,7 +145,7 @@ def broadcast_dist_samples_shape(shapes, size=None):
195145
# samples shapes without the size prepend
196146
sp_shapes = [s[len(_size) :] if _size == s[: min([len(_size), len(s)])] else s for s in shapes]
197147
try:
198-
broadcast_shape = shapes_broadcasting(*sp_shapes, raise_exception=True)
148+
broadcast_shape = np.broadcast_shapes(*sp_shapes)
199149
except ValueError:
200150
raise ValueError(
201151
"Cannot broadcast provided shapes {} given size: {}".format(
@@ -215,212 +165,7 @@ def broadcast_dist_samples_shape(shapes, size=None):
215165
else:
216166
p_shape = shape
217167
broadcastable_shapes.append(p_shape)
218-
return shapes_broadcasting(*broadcastable_shapes, raise_exception=True)
219-
220-
221-
def get_broadcastable_dist_samples(
222-
samples, size=None, must_bcast_with=None, return_out_shape=False
223-
):
224-
"""Get a view of the samples drawn from distributions which adds new axes
225-
in between the `size` prepend and the distribution's `shape`. These views
226-
should be able to broadcast the samples from the distrubtions taking into
227-
account the `size` (i.e. the number of samples) of the draw, which is
228-
prepended to the sample's `shape`. Optionally, one can supply an extra
229-
`must_bcast_with` to try to force samples to be able to broadcast with a
230-
given shape. A `ValueError` is raised if it is not possible to broadcast
231-
the provided samples.
232-
233-
Parameters
234-
----------
235-
samples: Iterable of ndarrays holding the sampled values
236-
size: None, int or tuple (optional)
237-
size of the sample set requested.
238-
must_bcast_with: None, int or tuple (optional)
239-
Tuple shape to which the samples must be able to broadcast
240-
return_out_shape: bool (optional)
241-
If `True`, this function also returns the output's shape and not only
242-
samples views.
243-
244-
Returns
245-
-------
246-
broadcastable_samples: List of the broadcasted sample arrays
247-
broadcast_shape: If `return_out_shape` is `True`, the resulting broadcast
248-
shape is returned.
249-
250-
Examples
251-
--------
252-
.. code-block:: python
253-
254-
must_bcast_with = (3, 1, 5)
255-
size = 100
256-
sample0 = np.random.randn(size)
257-
sample1 = np.random.randn(size, 5)
258-
sample2 = np.random.randn(size, 4, 5)
259-
out = broadcast_dist_samples_to(
260-
[sample0, sample1, sample2],
261-
size=size,
262-
must_bcast_with=must_bcast_with,
263-
)
264-
assert out[0].shape == (size, 1, 1, 1)
265-
assert out[1].shape == (size, 1, 1, 5)
266-
assert out[2].shape == (size, 1, 4, 5)
267-
assert np.all(sample0[:, None, None, None] == out[0])
268-
assert np.all(sample1[:, None, None] == out[1])
269-
assert np.all(sample2[:, None] == out[2])
270-
271-
.. code-block:: python
272-
273-
size = 100
274-
must_bcast_with = (3, 1, 5)
275-
sample0 = np.random.randn(size)
276-
sample1 = np.random.randn(5)
277-
sample2 = np.random.randn(4, 5)
278-
out = broadcast_dist_samples_to(
279-
[sample0, sample1, sample2],
280-
size=size,
281-
must_bcast_with=must_bcast_with,
282-
)
283-
assert out[0].shape == (size, 1, 1, 1)
284-
assert out[1].shape == (5,)
285-
assert out[2].shape == (4, 5)
286-
assert np.all(sample0[:, None, None, None] == out[0])
287-
assert np.all(sample1 == out[1])
288-
assert np.all(sample2 == out[2])
289-
"""
290-
samples = [np.asarray(p) for p in samples]
291-
_size = to_tuple(size)
292-
must_bcast_with = to_tuple(must_bcast_with)
293-
# Raw samples shapes
294-
p_shapes = [p.shape for p in samples] + [_check_shape_type(must_bcast_with)]
295-
out_shape = broadcast_dist_samples_shape(p_shapes, size=size)
296-
# samples shapes without the size prepend
297-
sp_shapes = [
298-
s[len(_size) :] if _size == s[: min([len(_size), len(s)])] else s for s in p_shapes
299-
]
300-
broadcast_shape = shapes_broadcasting(*sp_shapes, raise_exception=True)
301-
broadcastable_samples = []
302-
for param, p_shape, sp_shape in zip(samples, p_shapes, sp_shapes):
303-
if _size == p_shape[: min([len(_size), len(p_shape)])]:
304-
# If size prepends the shape, then we have to add broadcasting axis
305-
# in the middle
306-
slicer_head = [slice(None)] * len(_size)
307-
slicer_tail = [np.newaxis] * (len(broadcast_shape) - len(sp_shape)) + [
308-
slice(None)
309-
] * len(sp_shape)
310-
else:
311-
# If size does not prepend the shape, then we have leave the
312-
# parameter as is
313-
slicer_head = []
314-
slicer_tail = [slice(None)] * len(sp_shape)
315-
broadcastable_samples.append(param[tuple(slicer_head + slicer_tail)])
316-
if return_out_shape:
317-
return broadcastable_samples, out_shape
318-
else:
319-
return broadcastable_samples
320-
321-
322-
def broadcast_distribution_samples(samples, size=None):
323-
"""Broadcast samples drawn from distributions taking into account the
324-
size (i.e. the number of samples) of the draw, which is prepended to
325-
the sample's shape.
326-
327-
Parameters
328-
----------
329-
samples: Iterable of ndarrays holding the sampled values
330-
size: None, int or tuple (optional)
331-
size of the sample set requested.
332-
333-
Returns
334-
-------
335-
List of broadcasted sample arrays
336-
337-
Examples
338-
--------
339-
.. code-block:: python
340-
341-
size = 100
342-
sample0 = np.random.randn(size)
343-
sample1 = np.random.randn(size, 5)
344-
sample2 = np.random.randn(size, 4, 5)
345-
out = broadcast_distribution_samples([sample0, sample1, sample2],
346-
size=size)
347-
assert all((o.shape == (size, 4, 5) for o in out))
348-
assert np.all(sample0[:, None, None] == out[0])
349-
assert np.all(sample1[:, None, :] == out[1])
350-
assert np.all(sample2 == out[2])
351-
352-
.. code-block:: python
353-
354-
size = 100
355-
sample0 = np.random.randn(size)
356-
sample1 = np.random.randn(5)
357-
sample2 = np.random.randn(4, 5)
358-
out = broadcast_distribution_samples([sample0, sample1, sample2],
359-
size=size)
360-
assert all((o.shape == (size, 4, 5) for o in out))
361-
assert np.all(sample0[:, None, None] == out[0])
362-
assert np.all(sample1 == out[1])
363-
assert np.all(sample2 == out[2])
364-
"""
365-
return np.broadcast_arrays(*get_broadcastable_dist_samples(samples, size=size))
366-
367-
368-
def broadcast_dist_samples_to(to_shape, samples, size=None):
369-
"""Broadcast samples drawn from distributions to a given shape, taking into
370-
account the size (i.e. the number of samples) of the draw, which is
371-
prepended to the sample's shape.
372-
373-
Parameters
374-
----------
375-
to_shape: Tuple shape onto which the samples must be able to broadcast
376-
samples: Iterable of ndarrays holding the sampled values
377-
size: None, int or tuple (optional)
378-
size of the sample set requested.
379-
380-
Returns
381-
-------
382-
List of the broadcasted sample arrays
383-
384-
Examples
385-
--------
386-
.. code-block:: python
387-
388-
to_shape = (3, 1, 5)
389-
size = 100
390-
sample0 = np.random.randn(size)
391-
sample1 = np.random.randn(size, 5)
392-
sample2 = np.random.randn(size, 4, 5)
393-
out = broadcast_dist_samples_to(
394-
to_shape,
395-
[sample0, sample1, sample2],
396-
size=size
397-
)
398-
assert np.all((o.shape == (size, 3, 4, 5) for o in out))
399-
assert np.all(sample0[:, None, None, None] == out[0])
400-
assert np.all(sample1[:, None, None] == out[1])
401-
assert np.all(sample2[:, None] == out[2])
402-
403-
.. code-block:: python
404-
405-
size = 100
406-
to_shape = (3, 1, 5)
407-
sample0 = np.random.randn(size)
408-
sample1 = np.random.randn(5)
409-
sample2 = np.random.randn(4, 5)
410-
out = broadcast_dist_samples_to(
411-
to_shape,
412-
[sample0, sample1, sample2],
413-
size=size
414-
)
415-
assert np.all((o.shape == (size, 3, 4, 5) for o in out))
416-
assert np.all(sample0[:, None, None, None] == out[0])
417-
assert np.all(sample1 == out[1])
418-
assert np.all(sample2 == out[2])
419-
"""
420-
samples, to_shape = get_broadcastable_dist_samples(
421-
samples, size=size, must_bcast_with=to_shape, return_out_shape=True
422-
)
423-
return [np.broadcast_to(o, to_shape) for o in samples]
168+
return np.broadcast_shapes(*broadcastable_shapes)
424169

425170

426171
# User-provided can be lazily specified as scalars

0 commit comments

Comments
 (0)