
FlatSwitch Op for logprob derivation of arbitrary censoring #6949
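For context, here is a hedged plain-Python sketch (not this PR's PyTensor implementation; the helper name `censored_normal_logp` is hypothetical) of the log-probability this PR derives for a doubly censored graph such as `pt.switch(x < -1, -1, pt.switch(x > 1, 1, x))` with `x` standard normal: a value equal to an encoding absorbs the probability mass of its interval, while values strictly inside the uncensored interval keep the base density.

```python
import math

def censored_normal_logp(value, lower=-1.0, upper=1.0, mu=0.0, sigma=1.0):
    """Reference logp for pt.switch(x < lower, lower, pt.switch(x > upper, upper, x))."""
    def logcdf(v):
        return math.log(0.5 * (1.0 + math.erf((v - mu) / (sigma * math.sqrt(2.0)))))

    def logpdf(v):
        return -0.5 * ((v - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2.0 * math.pi))

    if value == lower:  # encoding: all mass below `lower` collapses here, P = CDF(lower)
        return logcdf(lower)
    if value == upper:  # encoding: all mass above `upper`, P = 1 - CDF(upper)
        return math.log1p(-math.exp(logcdf(upper)))
    if lower < value < upper:  # uncensored branch: plain density
        return logpdf(value)
    return -math.inf  # impossible value under this censoring
```

The symmetric bounds make the two encoding masses equal for a standard normal, which is a convenient sanity check.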


Status: Open. Wants to merge 39 commits into base: main.

Commits (39)
f302dad
create methods for interval extraction and overlap adjustment
shreyas3156 Oct 3, 2023
4e17948
add broadcastability and measurability checks for switch condition
shreyas3156 Oct 3, 2023
197fecb
add checks for any measurable components in the branches
shreyas3156 Oct 3, 2023
53ae39c
get_measurability_source returns the set of all sources
shreyas3156 Oct 12, 2023
ae1808d
Add SymbolicRandomVariable as an ancestor_var candidate
shreyas3156 Oct 12, 2023
15d07f6
allow only one base_rv in the entire graph
shreyas3156 Oct 12, 2023
f42db78
disallow discrete RVs
shreyas3156 Oct 12, 2023
5641e14
remove check for non-empty base_rv
shreyas3156 Oct 12, 2023
993ca5d
add single broadcastability check
shreyas3156 Oct 12, 2023
0706789
configure the output dtype of FlatSwitch Op
shreyas3156 Oct 12, 2023
4b09202
broadcastability check in every recursion not necessary
shreyas3156 Oct 12, 2023
5c15ed2
temporary print statements
shreyas3156 Oct 12, 2023
a7b19f3
fetch base_var from the switch condition
shreyas3156 Oct 12, 2023
190eb5f
fix issue with adding encoding when one of the branches is switch
shreyas3156 Oct 12, 2023
b4a4b9d
unpack intervals and encodings as the inputs to FlatSwitch Op
shreyas3156 Oct 12, 2023
b7090c1
verify if the base RVs are the same in all branches
shreyas3156 Oct 12, 2023
2bd2ff8
specify output dtype and shape of FlatSwitch Op
shreyas3156 Oct 24, 2023
79bd2d4
remove redundant checks for source of measurability and base_rv
shreyas3156 Oct 24, 2023
d175906
remove meta_info since info is passed as the op inputs
shreyas3156 Oct 24, 2023
2ceec95
Add logprob calculations
shreyas3156 Oct 31, 2023
f89be08
Vectorize logcdf for all the interval bounds
shreyas3156 Nov 14, 2023
d52cd26
Broadcast the intervals
Nov 20, 2023
5771620
Add indices of branches with RVs in Op property
Nov 23, 2023
5440095
Tests for single and double switch
Nov 23, 2023
d2789b0
test for arbitrary censoring with 3 switch branches
Nov 23, 2023
2000795
Refactor logp
ricardoV94 Nov 28, 2023
5e03ddb
Add comment
ricardoV94 Nov 28, 2023
6070bc7
Modify logp calculation to check for rv branch before encodings
shreyas3156 Jul 19, 2023
3d0a309
add rv_idx to props
Dec 21, 2023
318206d
Fix axis in pt.unique to ignore broadcasted encodings
Dec 21, 2023
cf99a3f
Remove pytest parametrize to only compile the logp once
Dec 21, 2023
3955280
Add broadcastability tests
Dec 21, 2023
5022313
Add tests to check measurability source and denying discrete
Dec 22, 2023
7cb7ff6
Handle None returns from flat_switch_helper()
Dec 22, 2023
e7530ef
Modify default logprob when there is no RV in any branch
Jan 5, 2024
252e742
Change the output dtype of FlatSwitch Op
Jan 5, 2024
9d85980
Tests for no measurable variable in any branch
Jan 5, 2024
db20cef
Test fix-up for nested branches
Jan 5, 2024
b5f26a4
Pre-commit fix
Mar 6, 2024
311 changes: 306 additions & 5 deletions pymc/logprob/censoring.py

@@ -37,20 +37,42 @@
from typing import Optional

import numpy as np
import pytensor
import pytensor.tensor as pt

from pytensor.graph import Op
from pytensor.graph.basic import Apply, Node
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.raise_op import Assert
from pytensor.scalar.basic import (
    GE,
    GT,
    LE,
    LT,
    Ceil,
    Clip,
    Floor,
    RoundHalfToEven,
    Switch,
)
from pytensor.scalar.basic import clip as scalar_clip
from pytensor.tensor import TensorType, TensorVariable
from pytensor.tensor.basic import switch as switch
from pytensor.tensor.math import ceil, clip, floor, round_half_to_even
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.variable import TensorConstant

from pymc.logprob.abstract import (
    MeasurableElemwise,
    MeasurableVariable,
    _logcdf,
    _logcdf_helper,
    _logprob,
    _logprob_helper,
)
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from pymc.logprob.utils import CheckParameterValue, check_potential_measurability


class MeasurableClip(MeasurableElemwise):
@@ -238,3 +260,282 @@ def round_logprob(op, values, base_rv, **kwargs):
    from pymc.math import logdiffexp

    return logdiffexp(logcdf_upper, logcdf_lower)


class FlatSwitches(Op):
Review comment (Member): Add docstrings explaining what this Op does and where it is used. Most importantly, document the IR representation this Op uses, and for what kind of original graphs. This can be done either here or in the main rewrite; if in the main rewrite, just mention here to check out the docstring in the rewrite.

Review comment (Member): May be more intuitive?

Suggested change
-class FlatSwitches(Op):
+class NestedEncodingSwitches(Op):

    __props__ = ("out_dtype", "rv_idx")

    def __init__(self, *args, out_dtype, rv_idx, **kwargs):
        super().__init__(*args, **kwargs)
        self.out_dtype = out_dtype
        self.rv_idx = rv_idx

    def make_node(self, *inputs):
        return Apply(
            self, list(inputs), [TensorType(dtype=self.out_dtype, shape=inputs[0].type.shape)()]
        )

    def perform(self, *args, **kwargs):
        raise NotImplementedError("This Op should not be evaluated")


MeasurableVariable.register(FlatSwitches)


def get_intervals(binary_node, valued_rvs):
Review comment (Member): Add docstrings explaining what this does, possibly also input/output type hints.

    """Return the [interval_true, interval_false] pair implied by a binary comparison.

    Handles both "x > 1" and "1 < x" expressions.
    """
    measurable_inputs = [
        inp for inp in binary_node.inputs if check_potential_measurability([inp], valued_rvs)
    ]

    if len(measurable_inputs) != 1:
        return None

    measurable_var = measurable_inputs[0]
    measurable_var_idx = binary_node.inputs.index(measurable_var)

    const = binary_node.inputs[(measurable_var_idx + 1) % 2]

    # Whether it is a lower or an upper bound depends on measurable_var_idx and the binary Op
    is_gt_or_ge = isinstance(binary_node.op.scalar_op, (GT, GE))
    is_lt_or_le = isinstance(binary_node.op.scalar_op, (LT, LE))

    if not is_lt_or_le and not is_gt_or_ge:
        # Switch condition was not defined using binary Ops
        return None

    intervals = [(-np.inf, const), (const, np.inf)]

    # interval_true corresponds to the true branch of the 'Switch'; interval_false to the false branch
    if measurable_var_idx == 0:
        # e.g. "x < 1" and "x > 1"
        interval_true, interval_false = intervals if is_lt_or_le else intervals[::-1]
    else:
        # e.g. "1 > x" and "1 < x": the comparison is mirrored, so the ordering flips
        interval_true, interval_false = intervals if is_gt_or_ge else intervals[::-1]

    return [interval_true, interval_false]
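The intended interval selection can be sketched in plain Python (the helper name `interval_order` and the string op names are hypothetical, standing in for the scalar comparison Ops): when the measurable variable sits on the right-hand side of the comparison, the true/false ordering flips relative to the left-hand case.

```python
import numpy as np

def interval_order(measurable_var_idx, op_name, const):
    """Return (interval_true, interval_false) for a condition like `x < const`."""
    is_gt_or_ge = op_name in ("gt", "ge")
    is_lt_or_le = op_name in ("lt", "le")
    if not (is_gt_or_ge or is_lt_or_le):
        return None  # condition is not a supported binary comparison
    intervals = [(-np.inf, const), (const, np.inf)]
    if measurable_var_idx == 0:  # e.g. "x < 1" or "x > 1"
        ordered = intervals if is_lt_or_le else intervals[::-1]
    else:  # e.g. "1 > x" or "1 < x": mirrored comparison
        ordered = intervals if is_gt_or_ge else intervals[::-1]
    return tuple(ordered)
```

For example, `1 > x` is true exactly when `x < 1`, so its true-branch interval is `(-inf, 1)` even though the Op is a greater-than.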


def adjust_intervals(intervals, outer_interval):
Review comment (Member): Add docstrings and possibly type hints.

    """Intersect the pair of branch intervals with the enclosing outer interval."""
    for i in range(2):
        current = intervals[i]
        lower = pt.maximum(current[0], outer_interval[0])
        upper = pt.minimum(current[1], outer_interval[1])

        intervals[i] = (lower, upper)

    return intervals
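The adjustment above is just interval intersection done symbolically with `pt.maximum`/`pt.minimum`. A minimal NumPy sketch (hypothetical helper name `intersect`) of the same operation:

```python
import numpy as np

def intersect(inner, outer):
    """Intersection of two (lower, upper) intervals via elementwise max/min."""
    return (np.maximum(inner[0], outer[0]), np.minimum(inner[1], outer[1]))
```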


def flat_switch_helper(node, valued_rvs, encoding_list, outer_interval, base_rv):
    """
    Carries out the main recursion through the switch branches to fetch the encodings and their
    respective intervals, and to adjust any overlaps. It also performs several checks on the
    switch condition and the measurable components.
    """
    from pymc.distributions.distribution import SymbolicRandomVariable

    switch_cond, *components = node.inputs

    # Deny broadcasting of the switch condition
    if switch_cond.type.broadcastable != node.outputs[0].type.broadcastable:
        return None

    measurable_var_switch = [
        var for var in switch_cond.owner.inputs if check_potential_measurability([var], valued_rvs)
    ]

    if len(measurable_var_switch) != 1:
        return None

    current_base_var = measurable_var_switch[0]
    # Deny cases where base_var is some function of 'x', e.g. pt.exp(x), while the measurable
    # var in the current switch is x. Also check that all sources of measurability are the
    # same RV, e.g. disallow mixing x1 and x2
    if current_base_var is not base_rv:
        return None

    measurable_var_idx = []
    switch_comp_idx = []

    # Get the indices of the switch and other measurable components for further recursion
    for idx, component in enumerate(components):
        if check_potential_measurability([component], valued_rvs):
            if isinstance(
                component.owner.op, (RandomVariable, SymbolicRandomVariable)
            ) or not isinstance(component.owner.op.scalar_op, Switch):
                measurable_var_idx.append(idx)
            else:
                switch_comp_idx.append(idx)

    # Check that the measurability source and the component itself coincide for all measurable components
    if any(components[i] is not base_rv for i in measurable_var_idx):
        return None

    # Get intervals for the true and false components from the condition
    intervals = get_intervals(switch_cond.owner, valued_rvs)
    adjusted_intervals = adjust_intervals(intervals, outer_interval)

    # Base case of the recursion: there is no more switch in either of the components
    if not switch_comp_idx:
        # Insert the two components and their respective intervals into encoding_list
        for i in range(2):
            switch_dict = {
                "lower": adjusted_intervals[i][0],
                "upper": adjusted_intervals[i][1],
                "encoding": components[i],
            }
            encoding_list.append(switch_dict)

        return encoding_list

    else:
        for i in range(2):
            if i in switch_comp_idx:
                # Recurse through the switch component(s)
                encoding_list = flat_switch_helper(
                    components[i].owner, valued_rvs, encoding_list, adjusted_intervals[i], base_rv
                )
                if encoding_list is None:
                    return None
            else:
                switch_dict = {
                    "lower": adjusted_intervals[i][0],
                    "upper": adjusted_intervals[i][1],
                    "encoding": components[i],
                }
                encoding_list.append(switch_dict)

        return encoding_list
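The flattening performed by this recursion can be sketched in plain Python (the tuple encoding of a switch graph and the helper name `flatten_switch` are hypothetical; conditions are assumed to be of the `x < cut` form): nested switch trees become a flat list of `{"lower", "upper", "encoding"}` dicts whose bounds are clipped by the enclosing interval.

```python
import numpy as np

def flatten_switch(tree, outer=(-np.inf, np.inf)):
    """Flatten nested ("switch", cut, true_comp, false_comp) trees into interval/encoding dicts."""
    if not (isinstance(tree, tuple) and tree[0] == "switch"):
        # Leaf: an encoding constant, or a stand-in for the base RV itself
        return [{"lower": outer[0], "upper": outer[1], "encoding": tree}]
    _, cut, true_comp, false_comp = tree
    true_iv = (outer[0], min(outer[1], cut))   # condition x < cut holds
    false_iv = (max(outer[0], cut), outer[1])  # condition x < cut fails
    return flatten_switch(true_comp, true_iv) + flatten_switch(false_comp, false_iv)
```

For example, `switch(x < -1, -1, switch(x < 1, x, 1))` yields three branches: the encodings -1 and 1 on the tails, and the RV itself on (-1, 1).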


@node_rewriter(tracks=[switch])
def find_measurable_flat_switch_encoding(fgraph: FunctionGraph, node: Node):
Review comment (Member): Suggestion, because "flat" is the IR output, not what is being found?

Suggested change
-def find_measurable_flat_switch_encoding(fgraph: FunctionGraph, node: Node):
+def find_nested_switch_encoding(fgraph: FunctionGraph, node: Node):

    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)

    if rv_map_feature is None:
        return None  # pragma: no cover

    valued_rvs = rv_map_feature.rv_values.keys()
    switch_cond = node.inputs[0]

    encoding_list = []
    initial_interval = (-np.inf, np.inf)

    # Fetch base_var as the only measurable input to the logical op in the switch condition
    measurable_switch_inp = [
        component
        for component in switch_cond.owner.inputs
        if check_potential_measurability([component], valued_rvs)
    ]

    if len(measurable_switch_inp) != 1:
        return None

    base_rv = measurable_switch_inp[0]

    # We do not allow discrete RVs yet
    if base_rv.dtype.startswith("int"):
        return None

    # Since we verify the source of measurability to be the same for the switch conditions
    # and all measurable components, denying broadcastability of the base_var is enough
    if base_rv.type.broadcastable != node.outputs[0].type.broadcastable:
        return None

    encoding_list = flat_switch_helper(node, valued_rvs, encoding_list, initial_interval, base_rv)
    if encoding_list is None:
        return None

    encodings, intervals = [], []
    rv_idx = ()

    # TODO: Some alternative cleaner way to do this
Review comment (Member): What is dirty about this approach? If so, add a comment; otherwise remove the TODO?

    for idx, item in enumerate(encoding_list):
        encoding = item["encoding"]
        # Indices of intervals having base_rv as their "encoding"
        if encoding == base_rv:
            rv_idx += (idx,)

        encodings.append(encoding)
        intervals.extend((item["lower"], item["upper"]))

    flat_switch_op = FlatSwitches(out_dtype=node.outputs[0].dtype, rv_idx=rv_idx)

    new_outs = flat_switch_op.make_node(base_rv, *intervals, *encodings).default_output()
    return [new_outs]


@_logprob.register(FlatSwitches)
def flat_switches_logprob(op, values, base_rv, *inputs, **kwargs):
Review comment (Member): Suggested change
-def flat_switches_logprob(op, values, base_rv, *inputs, **kwargs):
+def nested_switch_encoding_logprob(op, values, base_rv, *inputs, **kwargs):

Review comment (Member): Also add some comment in the docstrings about the kind of logp graphs we are generating?

    from pymc.math import logdiffexp

    (value,) = values

    encodings_count = len(inputs) // 3
    # 'inputs' is of the form (lower1, upper1, lower2, upper2, encoding1, encoding2)
    # Possible TODO:
    # encodings = op.get_encodings_from_inputs(inputs)
Review comment (Member, on lines +479 to +480): Either implement or remove the TODO since it's not a high priority anyway?

    encodings = inputs[2 * encodings_count : 3 * encodings_count]
    encodings = pt.broadcast_arrays(*encodings)

    encodings = Assert(msg="all encodings should be unique")(
        encodings, pt.eq(pt.unique(encodings, axis=0).shape[0], len(encodings))
    )

    # TODO: We do not support the encoding graphs of discrete RVs yet
Review comment (Member): Remove the easy-to-miss TODO. Either mention it in the docstrings, or raise a NotImplementedError to make sure we don't forget to update the logp function?


    interval_bounds = pt.broadcast_arrays(*inputs[0 : 2 * encodings_count])
    lower_interval_bounds = interval_bounds[::2]
    upper_interval_bounds = interval_bounds[1::2]

    lower_interval_bounds = pt.concatenate([i[None] for i in lower_interval_bounds], axis=0)
    upper_interval_bounds = pt.concatenate([j[None] for j in upper_interval_bounds], axis=0)

    interval_bounds = pt.concatenate(
        [lower_interval_bounds[None], upper_interval_bounds[None]], axis=0
    )

    # Define a logcdf map on a scalar, then vectorize it to evaluate the 2D interval bounds
    scalar_interval_bound = pt.scalar("scalar_interval_bound", dtype=base_rv.dtype)
    logcdf_scalar_interval_bound = _logcdf_helper(base_rv, scalar_interval_bound, **kwargs)
    logcdf_interval_bounds = pytensor.graph.replace.vectorize(
        logcdf_scalar_interval_bound, replace={scalar_interval_bound: interval_bounds}
    )
    logcdf_intervals = logdiffexp(
        logcdf_interval_bounds[1, ...], logcdf_interval_bounds[0, ...]
    )  # (encoding, *base_rv.shape)

    # The default logprob is -inf when there is no RV in any branch
Review comment (Member): Explain what's happening in the first if branch as well?

    if op.rv_idx:
        # Some branch returns the base RV itself: start from its density, then restrict
        # the support to the branch interval(s) it occupies
        logprob = _logprob_helper(base_rv, value, **kwargs)

        for i in op.rv_idx:
            logprob = pt.where(
                pt.and_(value <= upper_interval_bounds[i], value >= lower_interval_bounds[i]),
                logprob,
                -np.inf,
            )
    else:
        logprob = -np.inf

    for i in range(encodings_count):
        # If the value matches an encoding with interval (lower, upper), then prob = CDF(upper) - CDF(lower)
        logprob = pt.where(
            pt.eq(value, encodings[i]),
            logcdf_intervals[i],
            logprob,
        )

    return logprob
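The graph built above can be mirrored numerically. Below is a hedged NumPy reference (the helper name `flat_switch_logp` is hypothetical, and it takes explicit `logcdf`/`logpdf` callables instead of dispatching on the RV) for the flattened inputs `(lower1, upper1, ..., enc1, enc2, ...)`: encoded branches get the interval mass `CDF(upper) - CDF(lower)`, and branches in `rv_idx` get the base density restricted to their interval.

```python
import math
import numpy as np

def flat_switch_logp(value, lowers, uppers, encodings, rv_idx, logcdf, logpdf):
    """NumPy sketch of the FlatSwitches logp for a single scalar value."""
    # Interval mass log(CDF(upper) - CDF(lower)) for every branch
    logcdf_intervals = [
        float(np.log(np.exp(logcdf(u)) - np.exp(logcdf(l)))) for l, u in zip(lowers, uppers)
    ]
    logprob = -np.inf  # default when no branch matches
    for i in rv_idx:  # uncensored branch: density if the value lies inside its interval
        if lowers[i] <= value <= uppers[i]:
            logprob = logpdf(value)
    for i, enc in enumerate(encodings):  # encoded branches override, as in the graph above
        if i not in rv_idx and value == enc:
            logprob = logcdf_intervals[i]
    return logprob
```

With a standard normal base RV and branches `(-inf, -1] -> -1`, `(-1, 1) -> x`, `[1, inf) -> 1`, the two encoding masses are equal by symmetry.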


measurable_ir_rewrites_db.register(
    "find_measurable_flat_switch_encoding",
    find_measurable_flat_switch_encoding,
    "basic",
    "censoring",
)