Skip to content

Commit 1e5867f

Browse files
Junpeng LaoColCarroll
Junpeng Lao
authored andcommitted
Improve random for ObservedRV (#3036)
To generate random array for ObservedRV in sample_ppc and sample_prior_predictive a shape kwarg sometimes is needed to make sure the generated array has the same shape as the observed. This PR improved that to make sure the shape is inferred during random number generation. Previously I tried to add the shape attr to ObservedRV at the model creation, but turns out it breaks the sample_ppc for theano.shared observation. I try a different approach here to always reset the distribution shape at generation.
1 parent 47982b4 commit 1e5867f

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

pymc3/distributions/distribution.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,17 @@ def _draw_value(param, point=None, givens=None, size=None):
383383
elif (hasattr(param, 'distribution') and
384384
hasattr(param.distribution, 'random') and
385385
param.distribution.random is not None):
386-
return param.distribution.random(point=point, size=size)
386+
# reset the dist shape for ObservedRV
387+
if hasattr(param, 'observations'):
388+
dist_tmp = param.distribution
389+
try:
390+
distshape = param.observations.shape.eval()
391+
except AttributeError:
392+
distshape = param.observations.shape
393+
dist_tmp.shape = distshape
394+
return dist_tmp.random(point=point, size=size)
395+
else:
396+
return param.distribution.random(point=point, size=size)
387397
else:
388398
if givens:
389399
variables, values = list(zip(*givens))

pymc3/tests/test_sampling.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ def test_ignores_observed(self):
353353
assert (prior['mu'] < 90).all()
354354
assert (prior['positive_mu'] > 90).all()
355355
assert (prior['x_obs'] < 90).all()
356+
assert prior['x_obs'].shape == (500, 200)
356357
npt.assert_array_almost_equal(prior['positive_mu'], np.abs(prior['mu']), decimal=4)
357358

358359
def test_respects_shape(self):
@@ -395,9 +396,28 @@ def test_transformed(self):
395396

396397
thetas = pm.Beta('thetas', alpha=phi*kappa, beta=(1.0-phi)*kappa, shape=n)
397398

398-
y = pm.Binomial('y', n=at_bats, p=thetas, shape=n, observed=hits)
399+
y = pm.Binomial('y', n=at_bats, p=thetas, observed=hits)
399400
gen = pm.sample_prior_predictive(draws)
400401

401402
assert gen['phi'].shape == (draws,)
402403
assert gen['y'].shape == (draws, n)
403-
assert 'thetas_logodds__' in gen
404+
assert 'thetas_logodds__' in gen
405+
406+
def test_shared(self):
407+
n1 = 10
408+
obs = shared(np.random.rand(n1) < .5)
409+
draws = 50
410+
411+
with pm.Model() as m:
412+
p = pm.Beta('p', 1., 1.)
413+
y = pm.Bernoulli('y', p, observed=obs)
414+
gen1 = pm.sample_prior_predictive(draws)
415+
416+
assert gen1['y'].shape == (draws, n1)
417+
418+
n2 = 20
419+
obs.set_value(np.random.rand(n2) < .5)
420+
with m:
421+
gen2 = pm.sample_prior_predictive(draws)
422+
423+
assert gen2['y'].shape == (draws, n2)

0 commit comments

Comments
 (0)