FlatSwitch Op for logprob derivation of arbitrary censoring #6949

Changes from all commits (base: main)

@@ -37,20 +37,42 @@
 from typing import Optional

 import numpy as np
+import pytensor
 import pytensor.tensor as pt

-from pytensor.graph.basic import Node
+from pytensor.graph import Op
+from pytensor.graph.basic import Apply, Node
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import node_rewriter
-from pytensor.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven
+from pytensor.raise_op import Assert
+from pytensor.scalar.basic import (
+    GE,
+    GT,
+    LE,
+    LT,
+    Ceil,
+    Clip,
+    Floor,
+    RoundHalfToEven,
+    Switch,
+)
 from pytensor.scalar.basic import clip as scalar_clip
-from pytensor.tensor import TensorVariable
+from pytensor.tensor import TensorType, TensorVariable
+from pytensor.tensor.basic import switch as switch
 from pytensor.tensor.math import ceil, clip, floor, round_half_to_even
+from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.variable import TensorConstant

-from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob
+from pymc.logprob.abstract import (
+    MeasurableElemwise,
+    MeasurableVariable,
+    _logcdf,
+    _logcdf_helper,
+    _logprob,
+    _logprob_helper,
+)
 from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
-from pymc.logprob.utils import CheckParameterValue
+from pymc.logprob.utils import CheckParameterValue, check_potential_measurability


 class MeasurableClip(MeasurableElemwise):
@@ -238,3 +260,282 @@ def round_logprob(op, values, base_rv, **kwargs):
    from pymc.math import logdiffexp

    return logdiffexp(logcdf_upper, logcdf_lower)


class FlatSwitches(Op):

> Review comment: May be more intuitive? (suggested change)

    __props__ = ("out_dtype", "rv_idx")

    def __init__(self, *args, out_dtype, rv_idx, **kwargs):
        super().__init__(*args, **kwargs)
        self.out_dtype = out_dtype
        self.rv_idx = rv_idx

    def make_node(self, *inputs):
        return Apply(
            self, list(inputs), [TensorType(dtype=self.out_dtype, shape=inputs[0].type.shape)()]
        )

    def perform(self, *args, **kwargs):
        raise NotImplementedError("This Op should not be evaluated")


MeasurableVariable.register(FlatSwitches)
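
A sketch of the flat IR this Op encodes, assuming it is applied to a simple censoring graph (the values and variable names are illustrative; `FlatSwitches` and its input convention are taken from this diff):

    import numpy as np
    import pytensor.tensor as pt

    x = pt.random.normal(size=())
    y = pt.switch(x < 1.0, -1.0, x)  # values below 1 are collapsed to the code -1
    # The rewrite below should turn `y` into a single flat node, conceptually:
    #     FlatSwitches(out_dtype="float64", rv_idx=(1,))(
    #         x,              # base_rv
    #         -np.inf, 1.0,   # (lower, upper) of the first branch (encoding -1.0)
    #         1.0, np.inf,    # (lower, upper) of the second branch (the RV itself)
    #         -1.0, x,        # one encoding per branch
    #     )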


def get_intervals(binary_node, valued_rvs):

> Review comment: Add docstrings explaining what this does, possibly also input/output type
> hints.

    """
    Handles both "x > 1" and "1 < x" expressions.
    """

    measurable_inputs = [
        inp for inp in binary_node.inputs if check_potential_measurability([inp], valued_rvs)
    ]

    if len(measurable_inputs) != 1:
        return None

    measurable_var = measurable_inputs[0]
    measurable_var_idx = binary_node.inputs.index(measurable_var)

    const = binary_node.inputs[(measurable_var_idx + 1) % 2]

    # Whether `const` is a lower or an upper bound depends on measurable_var_idx and the binary Op
    is_gt_or_ge = isinstance(binary_node.op.scalar_op, (GT, GE))
    is_lt_or_le = isinstance(binary_node.op.scalar_op, (LT, LE))

    if not is_lt_or_le and not is_gt_or_ge:
        # The switch condition was not defined using binary comparison Ops
        return None

    intervals = [(-np.inf, const), (const, np.inf)]

    # interval_true corresponds to the true branch of the 'Switch', interval_false to the false branch
    if measurable_var_idx == 0:
        # e.g. "x < 1" and "x > 1"
        interval_true, interval_false = intervals if is_lt_or_le else intervals[::-1]
    else:
        # e.g. "1 > x" and "1 < x": the comparison is flipped relative to x, so "1 > x" holds on
        # (-inf, const) and "1 < x" holds on (const, inf)
        interval_true, interval_false = intervals if is_gt_or_ge else intervals[::-1]

    return [interval_true, interval_false]
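
For intuition, a minimal sketch of the two spellings this handles (illustrative; `valued_rvs` would come from the rewrite's rv_map_feature):

    import pytensor.tensor as pt

    x = pt.vector("x")
    cond = x < 1.0  # LT with the measurable input x at index 0
    # get_intervals(cond.owner, valued_rvs) pairs the true branch with (-inf, 1.0)
    # and the false branch with (1.0, inf). For the flipped spelling `1.0 < x`,
    # x sits at index 1 and the true branch instead pairs with (1.0, inf).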


def adjust_intervals(intervals, outer_interval):

> Review comment: Add docstrings and possibly type hints.

    for i in range(2):
        current = intervals[i]
        lower = pt.maximum(current[0], outer_interval[0])
        upper = pt.minimum(current[1], outer_interval[1])

        intervals[i] = (lower, upper)

    return intervals
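
A worked numeric analogue of this intersection (plain Python in place of the symbolic pt.maximum / pt.minimum; the values are illustrative):

    import numpy as np

    # The enclosing switch restricted the RV to (0, inf); the inner condition splits at 2
    intervals = [(-np.inf, 2.0), (2.0, np.inf)]
    outer = (0.0, np.inf)
    adjusted = [(max(lo, outer[0]), min(hi, outer[1])) for lo, hi in intervals]
    assert adjusted == [(0.0, 2.0), (2.0, np.inf)]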


def flat_switch_helper(node, valued_rvs, encoding_list, outer_interval, base_rv):
    """
    Carries out the main recursion through the branches to fetch the encodings and their
    respective intervals, and to adjust any overlaps. It also performs several checks on the
    switch condition and the measurable components.
    """
    from pymc.distributions.distribution import SymbolicRandomVariable

    switch_cond, *components = node.inputs

    # Deny broadcasting of the switch condition
    if switch_cond.type.broadcastable != node.outputs[0].type.broadcastable:
        return None

    measurable_var_switch = [
        var for var in switch_cond.owner.inputs if check_potential_measurability([var], valued_rvs)
    ]

    if len(measurable_var_switch) != 1:
        return None

    current_base_var = measurable_var_switch[0]
    # Deny cases where base_var is some function of 'x', e.g. pt.exp(x), while the measurable
    # var in the current switch is x. Also check that all sources of measurability are the same
    # RV, e.g. not a mix of x1 and x2.
    if current_base_var is not base_rv:
        return None

    measurable_var_idx = []
    switch_comp_idx = []

    # Get the indices of the switch and other measurable components for further recursion
    for idx, component in enumerate(components):
        if check_potential_measurability([component], valued_rvs):
            if isinstance(
                component.owner.op, (RandomVariable, SymbolicRandomVariable)
            ) or not isinstance(component.owner.op.scalar_op, Switch):
                measurable_var_idx.append(idx)
            else:
                switch_comp_idx.append(idx)

    # Check that the measurability source and the component itself are the same for all
    # measurable components
    if any(components[i] is not base_rv for i in measurable_var_idx):
        return None

    # Get intervals for the true and false components from the condition
    intervals = get_intervals(switch_cond.owner, valued_rvs)
    adjusted_intervals = adjust_intervals(intervals, outer_interval)

    # Base case for the recursion: there is no more switch in either of the components
    if not switch_comp_idx:
        # Insert the two components and their respective intervals into encoding_list
        for i in range(2):
            switch_dict = {
                "lower": adjusted_intervals[i][0],
                "upper": adjusted_intervals[i][1],
                "encoding": components[i],
            }
            encoding_list.append(switch_dict)

        return encoding_list

    else:
        for i in range(2):
            if i in switch_comp_idx:
                # Recurse through the switch component(s)
                encoding_list = flat_switch_helper(
                    components[i].owner, valued_rvs, encoding_list, adjusted_intervals[i], base_rv
                )

            else:
                switch_dict = {
                    "lower": adjusted_intervals[i][0],
                    "upper": adjusted_intervals[i][1],
                    "encoding": components[i],
                }
                encoding_list.append(switch_dict)

        return encoding_list
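
To make the recursion concrete, a sketch of how a doubly nested censoring graph would be flattened (assuming the helper is entered at the outer switch node; values illustrative):

    import pytensor.tensor as pt

    x = pt.random.normal(size=())
    y = pt.switch(x < 0.0, -1.0, pt.switch(x > 1.0, 2.0, x))
    # flat_switch_helper(y.owner, ...) should return an encoding list like:
    #     [{"lower": -inf, "upper": 0,   "encoding": -1.0},  # outer true branch
    #      {"lower": 1,    "upper": inf, "encoding": 2.0},   # inner true branch
    #      {"lower": 0,    "upper": 1,   "encoding": x}]     # inner false branch (the RV)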


@node_rewriter(tracks=[switch])
def find_measurable_flat_switch_encoding(fgraph: FunctionGraph, node: Node):

> Review comment: Suggestion, because "flat" is the IR output, not what is being found?
> (suggested change)

    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)

    if rv_map_feature is None:
        return None  # pragma: no cover

    valued_rvs = rv_map_feature.rv_values.keys()
    switch_cond = node.inputs[0]

    encoding_list = []
    initial_interval = (-np.inf, np.inf)

    # Fetch base_rv as the only measurable input to the logical Op in the switch condition
    measurable_switch_inp = [
        component
        for component in switch_cond.owner.inputs
        if check_potential_measurability([component], valued_rvs)
    ]

    if len(measurable_switch_inp) != 1:
        return None

    base_rv = measurable_switch_inp[0]

    # We do not allow discrete RVs yet
    if base_rv.dtype.startswith("int"):
        return None

    # Since we verify that the source of measurability is the same for the switch conditions
    # and all measurable components, denying broadcastability of base_rv is enough
    if base_rv.type.broadcastable != node.outputs[0].type.broadcastable:
        return None

    encoding_list = flat_switch_helper(node, valued_rvs, encoding_list, initial_interval, base_rv)
    if encoding_list is None:
        return None

    encodings, intervals = [], []
    rv_idx = ()

    # TODO: Some alternative cleaner way to do this

> Review comment: What is dirty about this approach? If so add a comment, otherwise remove the
> TODO?

    for idx, item in enumerate(encoding_list):
        encoding = item["encoding"]
        # Indices of the intervals having base_rv as their "encoding"
        if encoding == base_rv:
            rv_idx += (idx,)

        encodings.append(encoding)
        intervals.extend((item["lower"], item["upper"]))

    flat_switch_op = FlatSwitches(out_dtype=node.outputs[0].dtype, rv_idx=rv_idx)

    new_outs = flat_switch_op.make_node(base_rv, *intervals, *encodings).default_output()
    return [new_outs]
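
End to end, and assuming this PR behaves as intended (the exact numbers are illustrative), the registered rewrite should make `pm.logp` work on arbitrary switch-encoding graphs:

    import pymc as pm
    import pytensor.tensor as pt

    x = pm.Normal.dist(0.0, 1.0)
    # Censor below 0 to the code -1 and above 1 to the code 2
    y = pt.switch(x < 0.0, -1.0, pt.switch(x > 1.0, 2.0, x))

    pm.logp(y, -1.0).eval()  # log P(x < 0) = log(0.5)
    pm.logp(y, 0.5).eval()   # Normal(0, 1) logpdf at 0.5, since 0.5 lies in the kept interval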


@_logprob.register(FlatSwitches)
def flat_switches_logprob(op, values, base_rv, *inputs, **kwargs):

> Review comment: (suggested change)
> Review comment: Also add some comment in the docstrings about the kind of logp graphs we are
> generating?

    from pymc.math import logdiffexp

    (value,) = values

    encodings_count = len(inputs) // 3
    # 'inputs' is of the form (lower1, upper1, lower2, upper2, encoding1, encoding2)
    # Possible TODO:
    # encodings = op.get_encodings_from_inputs(inputs)

> Review comment (on lines +479 to +480): Either implement or remove the TODO, since it's not a
> high priority anyway?

    encodings = inputs[2 * encodings_count : 3 * encodings_count]
    encodings = pt.broadcast_arrays(*encodings)

    encodings = Assert(msg="all encodings should be unique")(
        encodings, pt.eq(pt.unique(encodings, axis=0).shape[0], len(encodings))
    )

    # TODO: We do not support the encoding graphs of discrete RVs yet

> Review comment: Remove this easy-to-miss TODO. Either mention it in the docstrings, or add a
> NotImplementedError that will make sure we don't forget to update the logp function?

    interval_bounds = pt.broadcast_arrays(*inputs[0 : 2 * encodings_count])
    lower_interval_bounds = interval_bounds[::2]
    upper_interval_bounds = interval_bounds[1::2]

    lower_interval_bounds = pt.concatenate([i[None] for i in lower_interval_bounds], axis=0)
    upper_interval_bounds = pt.concatenate([j[None] for j in upper_interval_bounds], axis=0)

    interval_bounds = pt.concatenate(
        [lower_interval_bounds[None], upper_interval_bounds[None]], axis=0
    )

    # Define the logcdf on a scalar, then vectorize it to evaluate the 2D interval bounds
    scalar_interval_bound = pt.scalar("scalar_interval_bound", dtype=base_rv.dtype)
    logcdf_scalar_interval_bound = _logcdf_helper(base_rv, scalar_interval_bound, **kwargs)
    logcdf_interval_bounds = pytensor.graph.replace.vectorize(
        logcdf_scalar_interval_bound, replace={scalar_interval_bound: interval_bounds}
    )
    logcdf_intervals = logdiffexp(
        logcdf_interval_bounds[1, ...], logcdf_interval_bounds[0, ...]
    )  # (encoding, *base_rv.shape)

    # The default logprob is -inf if there is no RV among the branches

> Review comment: Explain what's happening in the first if branch as well?

    if op.rv_idx:
        logprob = _logprob_helper(base_rv, value, **kwargs)

        # Add the RV branch (and check whether the value is possible there)
        for i in op.rv_idx:
            logprob = pt.where(
                pt.and_(value <= upper_interval_bounds[i], value >= lower_interval_bounds[i]),
                logprob,
                -np.inf,
            )
    else:
        logprob = -np.inf

    for i in range(encodings_count):
        # If the value matches encoding i, whose interval is (lower, upper), then
        # P(value) = CDF(upper) - CDF(lower)
        logprob = pt.where(
            pt.eq(value, encodings[i]),
            logcdf_intervals[i],
            logprob,
        )

    return logprob
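
A sanity check of the piecewise math this graph encodes, done with scipy's Normal rather than this code path (the numbers are illustrative):

    import numpy as np
    from scipy import stats

    # For x ~ Normal(0, 1) and y = switch(x < 0, -1, switch(x > 1, 2, x)):
    logp_code_low = np.log(stats.norm.cdf(0.0) - stats.norm.cdf(-np.inf))  # encoding -1
    logp_code_high = np.log(stats.norm.cdf(np.inf) - stats.norm.cdf(1.0))  # encoding 2
    logp_kept = stats.norm.logpdf(0.5)  # value inside the uncensored interval (0, 1)
    assert np.isclose(logp_code_low, np.log(0.5))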


measurable_ir_rewrites_db.register(
    "find_measurable_flat_switch_encoding",
    find_measurable_flat_switch_encoding,
    "basic",
    "censoring",
)

> Review comment (on FlatSwitches): Add docstrings explaining what this Op does / where it is
> used. Most importantly, what is the IR representation this Op uses for what kind of original
> graphs? This can be done either here or in the main rewrite. If on the main rewrite, just
> mention here to check out the docstring in the rewrite.