Skip to content

Commit e8f9602

Browse files
committed
Implement logp for add and mul ops involving one unregistered random variable
1 parent 79245ce commit e8f9602

File tree

2 files changed

+269
-2
lines changed

2 files changed

+269
-2
lines changed

pymc3/distributions/logp.py

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from aesara.graph.fg import FunctionGraph
2626
from aesara.graph.op import Op, compute_test_value
2727
from aesara.graph.type import CType
28+
from aesara.scalar.basic import Add, Mul
29+
from aesara.tensor.elemwise import Elemwise
2830
from aesara.tensor.random.op import RandomVariable
2931
from aesara.tensor.random.opt import local_subtensor_rv_lift
3032
from aesara.tensor.subtensor import (
@@ -37,7 +39,12 @@
3739
)
3840
from aesara.tensor.var import TensorVariable
3941

40-
from pymc3.aesaraf import extract_rv_and_value_vars, floatX, rvs_to_value_vars
42+
from pymc3.aesaraf import (
43+
extract_rv_and_value_vars,
44+
floatX,
45+
rvs_to_value_vars,
46+
walk_model,
47+
)
4148

4249

4350
@singledispatch
@@ -260,6 +267,130 @@ def _logp(
260267
"""
261268
value_var = rvs_to_values.get(var, var)
262269
return at.zeros_like(value_var)
270+
# raise NotImplementedError(f"Logp cannot be computed for op {op}")
271+
272+
273+
@_logp.register(Elemwise)
274+
def logp_elemwise(op, *args, **kwargs):
275+
if hasattr(op, "scalar_op"):
276+
return _logp(op.scalar_op, *args, **kwargs)
277+
raise NotImplementedError
278+
279+
280+
# TODO: Implement DimShuffle logp?
281+
# @_logp.register(DimShuffle)
282+
# def logp_dimshuffle(op, var, *args, **kwargs):
283+
# if var.owner and len(var.owner.inputs) == 1:
284+
# inp = var.owner.inputs[0]
285+
# if inp.owner and hasattr(inp.owner, 'op'):
286+
# return _logp(inp.owner.op, inp, *args, **kwargs)
287+
# raise NotImplementedError
288+
289+
290+
def find_rv_branch(inputs):
291+
"""
292+
Helper function to find which input branch(es) contain unregistered random variables
293+
"""
294+
rv_branch = []
295+
no_rv_branch = []
296+
297+
for inp in inputs:
298+
res_ancestors = list(walk_model((inp,), walk_past_rvs=True))
299+
# unregistered variables do not contain a value_var tag
300+
res_unregistered_ancestors = [
301+
v
302+
for v in res_ancestors
303+
if v.owner
304+
and isinstance(v.owner.op, RandomVariable)
305+
and not getattr(v.tag, "value_var", False)
306+
]
307+
if res_unregistered_ancestors:
308+
rv_branch.append(inp)
309+
else:
310+
no_rv_branch.append(inp)
311+
312+
return rv_branch, no_rv_branch
313+
314+
315+
@_logp.register(Add)
316+
def add_logp(op, var, rvs_to_values, *add_inputs, **kwargs):
317+
318+
if len(add_inputs) != 2:
319+
raise ValueError(f"Expected 2 inputs but got: {len(add_inputs)}")
320+
321+
rv, loc = find_rv_branch(add_inputs)
322+
323+
if len(rv) != 1:
324+
raise NotImplementedError(
325+
f"Logp of addition requires one branch with an unregistered RandomVariable but got {len(rv)}"
326+
)
327+
328+
rv = rv[0]
329+
rv_value = rvs_to_values.get(rv, getattr(rv.tag, "value_var", rv))
330+
loc = loc[0]
331+
loc_value = rvs_to_values.get(loc, getattr(loc.tag, "value_var", loc))
332+
333+
new_rvs_to_values = rvs_to_values.copy()
334+
new_rvs_to_values[rv] = rv_value
335+
336+
logp_rv = logpt(rv, new_rvs_to_values, **kwargs)
337+
fgraph = FunctionGraph(
338+
[i for i in graph_inputs((logp_rv,)) if not isinstance(i, Constant)],
339+
[logp_rv],
340+
clone=False,
341+
)
342+
343+
var_value = rvs_to_values.get(var, var)
344+
345+
fgraph.add_input(loc_value)
346+
fgraph.add_input(var_value)
347+
fgraph.replace(rv_value, var_value - loc_value)
348+
349+
logp_rv.name = f"__logp_{var.name}"
350+
351+
return logp_rv
352+
353+
354+
@_logp.register(Mul)
355+
def mul_logp(op, var, rvs_to_values, *mul_inputs, **kwargs):
356+
357+
if len(mul_inputs) != 2:
358+
raise ValueError(f"Expected 2 inputs but got: {len(mul_inputs)}")
359+
360+
rv, scale = find_rv_branch(mul_inputs)
361+
362+
if len(rv) != 1:
363+
raise NotImplementedError(
364+
f"Logp of product requires one branch with an unregistered RandomVariable but got {len(rv)}"
365+
)
366+
367+
rv = rv[0]
368+
rv_value = rvs_to_values.get(rv, getattr(rv.tag, "value_var", rv))
369+
scale = scale[0]
370+
scale_value = rvs_to_values.get(scale, getattr(scale.tag, "value_var", scale))
371+
372+
new_rvs_to_values = rvs_to_values.copy()
373+
new_rvs_to_values[rv] = rv_value
374+
375+
logp_rv = logpt(rv, new_rvs_to_values, **kwargs)
376+
fgraph = FunctionGraph(
377+
[i for i in graph_inputs((logp_rv,)) if not isinstance(i, Constant)],
378+
[logp_rv],
379+
clone=False,
380+
)
381+
382+
var_value = rvs_to_values.get(var, var)
383+
384+
fgraph.add_input(scale_value)
385+
fgraph.add_input(var_value)
386+
# TODO: This is not correct for discrete variables
387+
# TODO: Undefined behavior for scale = 0
388+
fgraph.replace(rv_value, var_value / scale_value)
389+
390+
logp_rv = fgraph.outputs[0] - at.log(at.abs_(scale_value))
391+
logp_rv.name = f"__logp_{var.name}"
392+
393+
return logp_rv
263394

264395

265396
def convert_indices(indices, entry):

pymc3/tests/test_logp.py

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from contextlib import ExitStack as does_not_raise
15+
1416
import aesara
1517
import aesara.tensor as at
1618
import numpy as np
@@ -31,7 +33,7 @@
3133
)
3234

3335
from pymc3.aesaraf import floatX, walk_model
34-
from pymc3.distributions.continuous import Normal, Uniform
36+
from pymc3.distributions.continuous import Exponential, Normal, Uniform
3537
from pymc3.distributions.discrete import Bernoulli
3638
from pymc3.distributions.logp import logpt
3739
from pymc3.model import Model
@@ -69,6 +71,140 @@ def test_logpt_basic():
6971
assert a_value_var in res_ancestors
7072

7173

74+
def test_logpt_add():
75+
"""
76+
Mare sure we can compute a log-likelihood for ``loc + Y`` where ``Y`` is an unregistered
77+
random variable and ``loc`` is an tensor variable or a registered random variable
78+
"""
79+
with Model() as m:
80+
loc = Uniform("loc", 0, 1)
81+
x = Normal.dist(0, 1) + loc
82+
m.register_rv(x, "x")
83+
84+
loc_value_var = m.rvs_to_values[loc]
85+
x_value_var = m.rvs_to_values[x]
86+
87+
x_logp = logpt(x, m.rvs_to_values[x])
88+
89+
res_ancestors = list(walk_model((x_logp,), walk_past_rvs=True))
90+
res_rv_ancestors = [
91+
v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable)
92+
]
93+
94+
# There shouldn't be any `RandomVariable`s in the resulting graph
95+
assert len(res_rv_ancestors) == 0
96+
assert loc_value_var in res_ancestors
97+
assert x_value_var in res_ancestors
98+
99+
# Test logp is correct
100+
f_logp = aesara.function([x_value_var, loc_value_var], x_logp)
101+
np.testing.assert_almost_equal(f_logp(50, 50), sp.norm(50, 1).logpdf(50))
102+
np.testing.assert_almost_equal(f_logp(50, 0), sp.norm(0, 1).logpdf(50), decimal=5)
103+
104+
105+
def test_logpt_mul():
106+
"""
107+
Mare sure we can compute a log-likelihood for ``scale * Y`` where ``Y`` is an unregistered
108+
random variable and ``scale`` is an tensor variable or a registered random variable
109+
"""
110+
with Model() as m:
111+
scale = Uniform("scale", 0, 1)
112+
x = Exponential.dist(1) * scale
113+
m.register_rv(x, "x")
114+
115+
scale_value_var = m.rvs_to_values[scale]
116+
x_value_var = m.rvs_to_values[x]
117+
118+
x_logp = logpt(x, m.rvs_to_values[x])
119+
120+
res_ancestors = list(walk_model((x_logp,), walk_past_rvs=True))
121+
res_rv_ancestors = [
122+
v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable)
123+
]
124+
125+
# There shouldn't be any `RandomVariable`s in the resulting graph
126+
assert len(res_rv_ancestors) == 0
127+
assert scale_value_var in res_ancestors
128+
assert x_value_var in res_ancestors
129+
130+
# Test logp is correct
131+
f_logp = aesara.function([x_value_var, scale_value_var], x_logp)
132+
np.testing.assert_almost_equal(f_logp(0, 5), sp.expon(scale=5).logpdf(0))
133+
np.testing.assert_almost_equal(f_logp(-2, -2), sp.expon(scale=2).logpdf(2))
134+
assert f_logp(2, -2) == -np.inf
135+
136+
137+
def test_logpt_mul_add():
138+
"""
139+
Mare sure we can compute a log-likelihood for ``loc + scale * Y`` where ``Y`` is an unregistered
140+
random variable and ``loc`` and ``scale`` are tensor variables or registered random variables
141+
"""
142+
with Model() as m:
143+
loc = Uniform("loc", 0, 1)
144+
scale = Uniform("scale", 0, 1)
145+
x = loc + scale * Normal.dist(0, 1)
146+
m.register_rv(x, "x")
147+
148+
loc_value_var = m.rvs_to_values[loc]
149+
scale_value_var = m.rvs_to_values[scale]
150+
x_value_var = m.rvs_to_values[x]
151+
152+
x_logp = logpt(x, m.rvs_to_values[x])
153+
154+
res_ancestors = list(walk_model((x_logp,), walk_past_rvs=True))
155+
res_rv_ancestors = [
156+
v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable)
157+
]
158+
159+
# There shouldn't be any `RandomVariable`s in the resulting graph
160+
assert len(res_rv_ancestors) == 0
161+
assert loc_value_var in res_ancestors
162+
assert scale_value_var in res_ancestors
163+
assert x_value_var in res_ancestors
164+
165+
# Test logp is correct
166+
f_logp = aesara.function([x_value_var, loc_value_var, scale_value_var], x_logp)
167+
np.testing.assert_almost_equal(f_logp(-1, 0, 2), sp.norm(0, 2).logpdf(-1))
168+
np.testing.assert_almost_equal(f_logp(95, 100, 15), sp.norm(100, 15).logpdf(95), decimal=6)
169+
170+
171+
def test_logpt_not_implemented():
172+
"""Test that logpt for add and mul fail if inputs are 0 or 2 unregistered rvs"""
173+
174+
with Model() as m:
175+
variable1 = at.as_tensor_variable(1, "variable1")
176+
variable2 = at.scalar("variable2")
177+
unregistered1 = Normal.dist(0, 1)
178+
unregistered2 = Normal.dist(0, 1)
179+
registered1 = Normal("registered1", 0, 1)
180+
registered2 = Normal("registered2", 0, 1)
181+
182+
x_fail1 = variable1 + variable2
183+
x_fail2 = unregistered1 + unregistered2
184+
x_fail3 = registered1 + variable1
185+
x_fail4 = registered1 + registered2
186+
187+
x_pass1 = variable1 + unregistered2
188+
x_pass2 = unregistered1 + variable2
189+
x_pass3 = registered1 + unregistered1
190+
191+
m.register_rv(x_fail1, "x_fail1")
192+
m.register_rv(x_fail2, "x_fail2")
193+
m.register_rv(x_fail3, "x_fail3")
194+
m.register_rv(x_fail4, "x_fail4")
195+
m.register_rv(x_pass1, "x_pass1")
196+
m.register_rv(x_pass2, "x_pass2")
197+
m.register_rv(x_pass3, "x_pass3")
198+
199+
for rv, value_var in m.rvs_to_values.items():
200+
if "fail" in rv.name:
201+
with pytest.raises(NotImplementedError):
202+
logpt(rv, value_var)
203+
else:
204+
with does_not_raise():
205+
logpt(rv, value_var)
206+
207+
72208
@pytest.mark.parametrize(
73209
"indices, size",
74210
[

0 commit comments

Comments
 (0)