47
47
ignore_logprob ,
48
48
logcdf ,
49
49
logp ,
50
+ reconsider_logprob ,
50
51
)
51
52
from pymc .logprob .abstract import get_measurable_outputs
52
53
from pymc .model import Model , Potential
@@ -315,7 +316,7 @@ def test_unexpected_rvs():
315
316
model .logp ()
316
317
317
318
318
- def test_ignore_logprob_basic ():
319
+ def test_ignore_reconsider_logprob_basic ():
319
320
x = Normal .dist ()
320
321
(measurable_x_out ,) = get_measurable_outputs (x .owner .op , x .owner )
321
322
assert measurable_x_out is x .owner .outputs [1 ]
@@ -328,18 +329,34 @@ def test_ignore_logprob_basic():
328
329
assert get_measurable_outputs (new_x .owner .op , new_x .owner ) == []
329
330
330
331
# Test that it will not clone a variable that is already unmeasurable
331
- new_new_x = ignore_logprob (new_x )
332
- assert new_new_x is new_x
333
-
334
-
335
- def test_ignore_logprob_model ():
336
- # logp that does not depend on input
337
- def logp (value , x ):
338
- return value
332
+ assert ignore_logprob (new_x ) is new_x
333
+
334
+ orig_x = reconsider_logprob (new_x )
335
+ assert orig_x is not new_x
336
+ assert isinstance (orig_x .owner .op , Normal )
337
+ assert type (orig_x .owner .op ).__name__ == "NormalRV"
338
+ # Confirm that it has measurable outputs again
339
+ assert get_measurable_outputs (orig_x .owner .op , orig_x .owner ) == [orig_x .owner .outputs [1 ]]
340
+
341
+ # Test that will not clone a variable that is already measurable
342
+ assert reconsider_logprob (x ) is x
343
+ assert reconsider_logprob (orig_x ) is orig_x
344
+
345
+
346
+ def test_ignore_reconsider_logprob_model ():
347
+ def custom_logp (value , x ):
348
+ # custom_logp is just the logp of x at value
349
+ x = reconsider_logprob (x )
350
+ return _joint_logp (
351
+ [x ],
352
+ rvs_to_values = {x : value },
353
+ rvs_to_transforms = {},
354
+ rvs_to_total_sizes = {},
355
+ )
339
356
340
357
with Model () as m :
341
358
x = Normal .dist ()
342
- y = CustomDist ("y" , x , logp = logp )
359
+ y = CustomDist ("y" , x , logp = custom_logp )
343
360
with pytest .warns (
344
361
UserWarning ,
345
362
match = "Found a random variable that was neither among the observations "
@@ -355,7 +372,7 @@ def logp(value, x):
355
372
# The above warning should go away with ignore_logprob.
356
373
with Model () as m :
357
374
x = ignore_logprob (Normal .dist ())
358
- y = CustomDist ("y" , x , logp = logp )
375
+ y = CustomDist ("y" , x , logp = custom_logp )
359
376
with warnings .catch_warnings ():
360
377
warnings .simplefilter ("error" )
361
378
assert _joint_logp (
0 commit comments