
Commit 6007e84

Add helper to undo ignore_logprob
1 parent 3c01e65 commit 6007e84

File tree

3 files changed: +50 -13 lines


pymc/distributions/logprob.py (+21 -2)

@@ -11,8 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
+from copy import copy
 from typing import Dict, List, Sequence, Union
 
 import numpy as np
@@ -210,3 +209,23 @@ def ignore_logprob(rv: TensorVariable) -> TensorVariable:
         return rv
     new_node = assign_custom_measurable_outputs(node, type_prefix=prefix)
     return new_node.outputs[node.outputs.index(rv)]
+
+
+def reconsider_logprob(rv: TensorVariable) -> TensorVariable:
+    """Return a duplicated variable that is considered when creating logprob graphs.
+
+    This undoes the effect of `ignore_logprob`.
+
+    If a variable was not ignored, it is returned directly.
+    """
+    prefix = "Unmeasurable"
+    node = rv.owner
+    op_type = type(node.op)
+    if not op_type.__name__.startswith(prefix):
+        return rv
+
+    new_node = node.clone()
+    original_op_type = new_node.op.original_op_type
+    new_node.op = copy(new_node.op)
+    new_node.op.__class__ = original_op_type
+    return new_node.outputs[node.outputs.index(rv)]
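For orientation, here is a minimal usage sketch of the round trip the new helper enables, mirroring the assertions added to the tests further down (the import paths are assumed from the files touched in this commit):

from pymc import Normal
from pymc.distributions.logprob import ignore_logprob, reconsider_logprob

x = Normal.dist()
hidden_x = ignore_logprob(x)  # op type becomes "UnmeasurableNormalRV"
restored_x = reconsider_logprob(hidden_x)  # op type is "NormalRV" again

assert type(restored_x.owner.op).__name__ == "NormalRV"
assert reconsider_logprob(x) is x  # already-measurable variables pass through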

pymc/logprob/abstract.py (+1 -0)

@@ -220,6 +220,7 @@ def assign_custom_measurable_outputs(
 
     new_op_dict = op_type.__dict__.copy()
     new_op_dict["id_obj"] = (new_node.op, measurable_outputs_fn)
+    new_op_dict.setdefault("original_op_type", op_type)
 
     new_op_type = type(
         f"{type_prefix}{op_type.__name__}", (op_type, UnmeasurableVariable), new_op_dict

pymc/tests/distributions/test_logprob.py (+28 -11)

@@ -47,6 +47,7 @@
     ignore_logprob,
     logcdf,
     logp,
+    reconsider_logprob,
 )
 from pymc.logprob.abstract import get_measurable_outputs
 from pymc.model import Model, Potential
@@ -315,7 +316,7 @@ def test_unexpected_rvs():
         model.logp()
 
 
-def test_ignore_logprob_basic():
+def test_ignore_reconsider_logprob_basic():
     x = Normal.dist()
     (measurable_x_out,) = get_measurable_outputs(x.owner.op, x.owner)
     assert measurable_x_out is x.owner.outputs[1]
@@ -328,18 +329,34 @@ def test_ignore_logprob_basic():
     assert get_measurable_outputs(new_x.owner.op, new_x.owner) == []
 
     # Test that it will not clone a variable that is already unmeasurable
-    new_new_x = ignore_logprob(new_x)
-    assert new_new_x is new_x
-
-
-def test_ignore_logprob_model():
-    # logp that does not depend on input
-    def logp(value, x):
-        return value
+    assert ignore_logprob(new_x) is new_x
+
+    orig_x = reconsider_logprob(new_x)
+    assert orig_x is not new_x
+    assert isinstance(orig_x.owner.op, Normal)
+    assert type(orig_x.owner.op).__name__ == "NormalRV"
+    # Confirm that it has measurable outputs again
+    assert get_measurable_outputs(orig_x.owner.op, orig_x.owner) == [orig_x.owner.outputs[1]]
+
+    # Test that it will not clone a variable that is already measurable
+    assert reconsider_logprob(x) is x
+    assert reconsider_logprob(orig_x) is orig_x
+
+
+def test_ignore_reconsider_logprob_model():
+    def custom_logp(value, x):
+        # custom_logp is just the logp of x at value
+        x = reconsider_logprob(x)
+        return _joint_logp(
+            [x],
+            rvs_to_values={x: value},
+            rvs_to_transforms={},
+            rvs_to_total_sizes={},
+        )
 
     with Model() as m:
         x = Normal.dist()
-        y = CustomDist("y", x, logp=logp)
+        y = CustomDist("y", x, logp=custom_logp)
     with pytest.warns(
         UserWarning,
         match="Found a random variable that was neither among the observations "
@@ -355,7 +372,7 @@ def logp(value, x):
     # The above warning should go away with ignore_logprob.
     with Model() as m:
         x = ignore_logprob(Normal.dist())
-        y = CustomDist("y", x, logp=logp)
+        y = CustomDist("y", x, logp=custom_logp)
     with warnings.catch_warnings():
         warnings.simplefilter("error")
         assert _joint_logp(
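Taken together, the reworked model test exercises both halves of the API: ignore_logprob marks x as unmeasurable when the model is built, so PyMC neither warns about an unexpected random variable nor tries to derive a logp term for x itself, while reconsider_logprob inside custom_logp restores measurability locally, letting _joint_logp compute the density of x at the supplied value on behalf of the CustomDist.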
