Skip to content

Commit 632f1a1

Browse files
committed
Return scalar variables instead of 1D
1 parent de70d25 commit 632f1a1

File tree

2 files changed

+42
-28
lines changed

2 files changed

+42
-28
lines changed

pymc/sampling.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2143,17 +2143,19 @@ def draw(
21432143
assert draws[2].shape == (num_draws, 5)
21442144
"""
21452145

2146-
if not isinstance(vars, (list, tuple)):
2147-
vars = [vars]
2148-
21492146
draw_fn = compile_pymc(inputs=[], outputs=vars, mode=mode, **kwargs)
2150-
drawn_values = zip(*(draw_fn() for _ in range(draws)))
2151-
drawn_values = [np.stack(v) for v in drawn_values]
21522147

2153-
# If only one variable, return the numpy array instead of a list of numpy arrays
21542148
if draws == 1:
2155-
return drawn_values[0]
2156-
return drawn_values
2149+
return draw_fn()
2150+
2151+
# Single variable output
2152+
if not isinstance(vars, (list, tuple)):
2153+
drawn_values = (draw_fn() for _ in range(draws))
2154+
return np.stack(drawn_values)
2155+
2156+
# Multiple variable output
2157+
drawn_values = zip(*(draw_fn() for _ in range(draws)))
2158+
return [np.stack(v) for v in drawn_values]
21572159

21582160

21592161
def _init_jitter(

pymc/tests/test_sampling.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,44 +1215,56 @@ def test_sample_deterministic():
12151215

12161216

12171217
class TestDraw(SeededTest):
1218-
def test_draw_one_variable(self):
1218+
def test_univariate(self):
12191219
with pm.Model():
12201220
x = pm.Normal("x")
12211221

12221222
x_draws = pm.draw(x)
1223-
assert x_draws.shape == (1,)
1223+
assert x_draws.shape == ()
12241224

1225-
def test_draw_several_variables(self):
1225+
(x_draws,) = pm.draw([x])
1226+
assert x_draws.shape == ()
1227+
1228+
x_draws = pm.draw(x, draws=10)
1229+
assert x_draws.shape == (10,)
1230+
1231+
(x_draws,) = pm.draw([x], draws=10)
1232+
assert x_draws.shape == (10,)
1233+
1234+
def test_multivariate(self):
1235+
with pm.Model():
1236+
mln = pm.Multinomial("mln", n=5, p=np.array([0.25, 0.25, 0.25, 0.25]))
1237+
1238+
mln_draws = pm.draw(mln, draws=1)
1239+
assert mln_draws.shape == (4,)
1240+
1241+
(mln_draws,) = pm.draw([mln], draws=1)
1242+
assert mln_draws.shape == (4,)
1243+
1244+
mln_draws = pm.draw(mln, draws=10)
1245+
assert mln_draws.shape == (10, 4)
1246+
1247+
(mln_draws,) = pm.draw([mln], draws=10)
1248+
assert mln_draws.shape == (10, 4)
1249+
1250+
def test_multiple_variables(self):
12261251
with pm.Model():
12271252
x = pm.Normal("x")
12281253
y = pm.Normal("y", shape=10)
12291254
z = pm.Uniform("z", shape=5)
1255+
w = pm.Dirichlet("w", a=[1, 1, 1])
12301256

1231-
num_draws = 1000
1232-
# Draw samples of a list variables
1233-
draws = pm.draw([x, y, z], draws=num_draws)
1234-
assert draws[0].shape == (num_draws,)
1235-
assert draws[1].shape == (num_draws, 10)
1236-
assert draws[2].shape == (num_draws, 5)
1237-
1238-
# Draw samples of a tuple variables
1239-
draws = pm.draw((x, y, z), draws=num_draws)
1257+
num_draws = 100
1258+
draws = pm.draw((x, y, z, w), draws=num_draws)
12401259
assert draws[0].shape == (num_draws,)
12411260
assert draws[1].shape == (num_draws, 10)
12421261
assert draws[2].shape == (num_draws, 5)
1243-
1244-
def test_multivariate(self):
1245-
with pm.Model():
1246-
mln = pm.Multinomial("mln", n=5, p=np.array([0.25, 0.25, 0.25, 0.25]))
1247-
1248-
mln_draws = pm.draw(mln, draws=100)
1249-
assert mln_draws.shape == (100, 4)
1262+
assert draws[3].shape == (num_draws, 3)
12501263

12511264
def test_draw_different_samples(self):
12521265
with pm.Model():
12531266
x = pm.Normal("x")
12541267

12551268
x_draws_1 = pm.draw(x, 100)
12561269
x_draws_2 = pm.draw(x, 100)
1257-
# Check if the draw function will draw different samples each time
12581270
assert not np.all(np.isclose(x_draws_1, x_draws_2))

0 commit comments

Comments
 (0)