Skip to content

Commit 69f5caa

Browse files
committed
Combine Add and Mul logps
1 parent 4cb0d1b commit 69f5caa

File tree

1 file changed

+33
-79
lines changed

1 file changed

+33
-79
lines changed

pymc3/distributions/logp.py

+33-79
Original file line numberDiff line numberDiff line change
@@ -271,10 +271,8 @@ def _logp(
271271

272272

273273
@_logp.register(Elemwise)
274-
def logp_elemwise(op, *args, **kwargs):
275-
if hasattr(op, "scalar_op"):
276-
return _logp(op.scalar_op, *args, **kwargs)
277-
raise NotImplementedError
274+
def elemwise_logp(op, *args, **kwargs):
275+
return _logp(op.scalar_op, *args, **kwargs)
278276

279277

280278
# TODO: Implement DimShuffle logp?
@@ -287,14 +285,17 @@ def logp_elemwise(op, *args, **kwargs):
287285
# raise NotImplementedError
288286

289287

290-
def find_rv_branch(inputs):
291-
"""
292-
Helper function to find which input branch(es) contain unregistered random variables
293-
"""
294-
rv_branch = []
295-
no_rv_branch = []
288+
@_logp.register(Add)
289+
@_logp.register(Mul)
290+
def linear_logp(op, var, rvs_to_values, *linear_inputs, **kwargs):
291+
292+
if len(linear_inputs) != 2:
293+
raise ValueError(f"Expected 2 inputs but got: {len(linear_inputs)}")
296294

297-
for inp in inputs:
295+
# Find base_rv and constant inputs
296+
base_rv = []
297+
constant = []
298+
for inp in linear_inputs:
298299
res_ancestors = list(walk_model((inp,), walk_past_rvs=True))
299300
# unregistered variables do not contain a value_var tag
300301
res_unregistered_ancestors = [
@@ -305,94 +306,47 @@ def find_rv_branch(inputs):
305306
and not getattr(v.tag, "value_var", False)
306307
]
307308
if res_unregistered_ancestors:
308-
rv_branch.append(inp)
309+
base_rv.append(inp)
309310
else:
310-
no_rv_branch.append(inp)
311-
312-
return rv_branch, no_rv_branch
313-
314-
315-
@_logp.register(Add)
316-
def add_logp(op, var, rvs_to_values, *add_inputs, **kwargs):
317-
318-
if len(add_inputs) != 2:
319-
raise ValueError(f"Expected 2 inputs but got: {len(add_inputs)}")
320-
321-
base_rv, loc = find_rv_branch(add_inputs)
311+
constant.append(inp)
322312

323313
if len(base_rv) != 1:
324314
raise NotImplementedError(
325-
f"Logp of addition requires one branch with an unregistered RandomVariable but got {len(base_rv)}"
315+
f"Logp of linear transform requires one branch with an unregistered RandomVariable but got {len(base_rv)}"
326316
)
327317

328-
var_value = rvs_to_values.get(var, var)
329-
loc = loc[0]
330318
base_rv = base_rv[0]
331-
base_value = base_rv.type()
332-
333-
logp_base_rv = logpt(base_rv, {base_rv: base_value}, **kwargs)
334-
fgraph = FunctionGraph(
335-
[i for i in graph_inputs((logp_base_rv,)) if not isinstance(i, Constant)],
336-
[logp_base_rv],
337-
clone=False,
338-
)
339-
fgraph.replace(base_value, var_value - loc, import_missing=True)
340-
logp_add_rv = fgraph.outputs[0]
341-
342-
# Replace rvs in graph
343-
# TODO: This shouldn't be here
344-
(logp_add_rv,), _ = rvs_to_value_vars(
345-
(logp_add_rv,),
346-
apply_transforms=True, # Change this
347-
initial_replacements=None,
348-
)
349-
350-
logp_add_rv.name = f"__logp_{var.name}"
351-
352-
return logp_add_rv
353-
354-
355-
@_logp.register(Mul)
356-
def mul_logp(op, var, rvs_to_values, *mul_inputs, **kwargs):
357-
358-
if len(mul_inputs) != 2:
359-
raise ValueError(f"Expected 2 inputs but got: {len(mul_inputs)}")
360-
361-
base_rv, scale = find_rv_branch(mul_inputs)
362-
363-
if len(base_rv) != 1:
364-
raise NotImplementedError(
365-
f"Logp of product requires one branch with an unregistered RandomVariable but got {len(base_rv)}"
366-
)
367-
319+
constant = constant[0]
368320
var_value = rvs_to_values.get(var, var)
369-
scale = scale[0]
370-
base_rv = base_rv[0]
371-
base_value = base_rv.type()
372321

322+
# Get logp of base_rv
323+
base_value = base_rv.type()
373324
logp_base_rv = logpt(base_rv, {base_rv: base_value}, **kwargs)
374325
fgraph = FunctionGraph(
375326
[i for i in graph_inputs((logp_base_rv,)) if not isinstance(i, Constant)],
376-
[logp_base_rv],
327+
outputs=[logp_base_rv],
377328
clone=False,
378329
)
379330

380-
# TODO: This is not correct for discrete variables
381-
# TODO: Undefined behavior for scale = 0
382-
fgraph.replace(base_value, var_value / scale, import_missing=True)
383-
logp_mul_rv = fgraph.outputs[0] - at.log(at.abs_(scale))
331+
# Transform base_rv and apply jacobian correction (for continuous rvs)
332+
if isinstance(op, Add):
333+
fgraph.replace(base_value, var_value - constant, import_missing=True)
334+
logp_linear_rv = fgraph.outputs[0]
335+
elif isinstance(op, Mul):
336+
fgraph.replace(base_value, var_value / constant, import_missing=True)
337+
logp_linear_rv = fgraph.outputs[0]
338+
if "float" in base_rv.dtype:
339+
logp_linear_rv -= at.log(at.abs_(constant))
384340

385341
# Replace rvs in graph
386-
# TODO: This shouldn't be here
387-
(logp_mul_rv,), _ = rvs_to_value_vars(
388-
(logp_mul_rv,),
389-
apply_transforms=True, # Change this
342+
(logp_linear_rv,), _ = rvs_to_value_vars(
343+
(logp_linear_rv,),
344+
apply_transforms=kwargs.get("transformed", True),
390345
initial_replacements=None,
391346
)
392347

393-
logp_mul_rv.name = f"__logp_{var.name}"
394-
395-
return logp_mul_rv
348+
logp_linear_rv.name = f"__logp_{var.name}"
349+
return logp_linear_rv
396350

397351

398352
def convert_indices(indices, entry):

0 commit comments

Comments
 (0)