FlatSwitch Op for logprob derivation of arbitrary censoring #6949


Open: shreyas3156 wants to merge 39 commits into main from logprob-flatswitch-arbitrary-censoring

Conversation

shreyas3156 (Contributor) commented Oct 12, 2023

What is this PR about?
This PR defines a FlatSwitch Op that aims to extract the intervals and their respective encodings required to infer the logprob of arbitrary censored distributions. It achieves this in the following steps:

  1. Recursively extract the intervals defined by the conditions in the nested pt.switch() calls.
  2. Adjust/clip these intervals to eliminate the overlap with the outer switch.
  3. Identify which interval each of the true and false branches corresponds to, since each condition splits the space into two parts.

It then checks that the switch condition and the measurable branches are not broadcast, and that all measurable components share the same source of measurability. The logic for these checks is based on #6834. (A schematic sketch of the interval extraction is shown below.)
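Below is a schematic, pure-Python sketch of steps 1-3, using a toy nested-tuple representation of the switch instead of the actual PyTensor graph traversal; `flatten_switch` and the tuple encoding are purely illustrative and are not the Op's implementation:

```python
import math

def flatten_switch(expr, lower=-math.inf, upper=math.inf):
    """Return (lower, upper, encoding) triples for a toy nested switch."""
    if not (isinstance(expr, tuple) and expr[0] == "lt"):
        # Leaf: either a constant encoding or the base RV itself.
        return [(lower, upper, expr)]
    _, threshold, true_branch, false_branch = expr
    # The condition `base_rv < threshold` splits the current interval in two;
    # clipping against (lower, upper) removes the overlap with the outer switch.
    return flatten_switch(true_branch, lower, min(upper, threshold)) + flatten_switch(
        false_branch, max(lower, threshold), upper
    )

# Mirrors the rv2 example below: switch(x < -1, -1, switch(x < 1, 1, x))
print(flatten_switch(("lt", -1, -1, ("lt", 1, 1, "base_rv"))))
# [(-inf, -1, -1), (-1, 1, 1), (1, inf, 'base_rv')]
```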

Once the intervals and their respective encodings are known, they can be used to calculate the log-probability. So, on running something like

rv2 = pt.switch(
    base_rv < -1,
    -1,
    pt.switch(
        base_rv < 1,  # -inf to 1, 1 to inf
        1,
        base_rv
    ),
)

we get something like:

lower: -1.0
upper: 1.0
encoding: 1 

lower: 1.0
upper: inf
encoding: normal_rv{0, (0, 0), floatX, False}.out 
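
For intuition, here is a minimal sketch (not the Op's actual logprob implementation) of how the encoded interval (-1, 1) and the uncensored interval (1, inf) from the example above could contribute to the log-probability, using `pm.logcdf`, `pm.logp`, and `pymc.math.logdiffexp` (the latter also appears in the diff below); the variable names are illustrative only:

```python
import pymc as pm
import pytensor.tensor as pt
from pymc.math import logdiffexp

base_rv = pm.Normal.dist(0, 1)
value = pt.scalar("value")

# Encoded interval (-1, 1) collapsed to the constant 1: all of its probability
# mass, P(-1 < base_rv < 1), is assigned to the point value == 1.
encoded_logp = logdiffexp(pm.logcdf(base_rv, 1.0), pm.logcdf(base_rv, -1.0))

# The interval (1, inf) keeps base_rv itself, so the usual density applies there.
uncensored_logp = pm.logp(base_rv, value)

logp = pt.switch(pt.eq(value, 1.0), encoded_logp, uncensored_logp)
```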

TO-DO:

  • The checks should work not only on pt.switch(x>0, x, a) but also on pt.switch(pt.exp(x)>0, pt.exp(x), b), where a and b are some encodings.
  • In the FlatSwitch Op, add base_rv, the list of intervals, and the corresponding encodings as inputs to the node so that they can be unpacked in the logprob calculation (a rough sketch follows).
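
A rough, hypothetical sketch of what the second item could look like; the class name, input layout, and `perform` behaviour here are assumptions for illustration, not the PR's final design:

```python
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op


class FlatSwitchesSketch(Op):
    """Hypothetical Op packing base_rv, interval bounds and encodings as inputs."""

    def make_node(self, switch_out, base_rv, interval_bounds, *encodings):
        # interval_bounds: assumed (2, n_intervals) tensor of lower/upper edges;
        # encodings: one variable per interval (a constant or base_rv itself).
        inputs = [
            pt.as_tensor_variable(switch_out),
            pt.as_tensor_variable(base_rv),
            pt.as_tensor_variable(interval_bounds),
            *(pt.as_tensor_variable(e) for e in encodings),
        ]
        return Apply(self, inputs, [inputs[0].type()])

    def perform(self, node, inputs, output_storage):
        # The rewritten node exists so the logprob registration can unpack its
        # inputs; evaluation simply forwards the original switch output.
        output_storage[0][0] = inputs[0]
```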

Checklist

@ricardoV94 @larryshamalama

shreyas3156 marked this pull request as a draft on October 12, 2023 06:48
codecov bot commented Oct 12, 2023

Codecov Report

Attention: Patch coverage is 17.91045%, with 110 lines in your changes missing coverage. Please review.

Project coverage is 87.21%. Comparing base (244fb97) to head (b5f26a4).

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #6949      +/-   ##
==========================================
- Coverage   92.26%   87.21%   -5.06%     
==========================================
  Files         100      100              
  Lines       16880    17009     +129     
==========================================
- Hits        15574    14834     -740     
- Misses       1306     2175     +869     
Files                        Coverage Δ
pymc/logprob/censoring.py    34.23% <17.91%> (-61.47%) ⬇️

... and 21 files with indirect coverage changes

shreyas3156 force-pushed the logprob-flatswitch-arbitrary-censoring branch from 86ae712 to 89d4635 on November 23, 2023 15:39
shreyas3156 force-pushed the logprob-flatswitch-arbitrary-censoring branch from 44e5423 to 96e0bb0 on December 12, 2023 16:11
shreyas3156 marked this pull request as ready for review on January 5, 2024 05:39
ricardoV94 changed the title from "FlatSwitch Op for logprob derivation of arbitrary censoring" to "[WIP] FlatSwitch Op for logprob derivation of arbitrary censoring" on March 4, 2024
ricardoV94 (Member) commented Mar 4, 2024

@shreyas3156 could you solve the conflicts issue? I'll finally review this one :)

shreyas3156 force-pushed the logprob-flatswitch-arbitrary-censoring branch from ebcee22 to db20cef on March 6, 2024 08:24
ricardoV94 (Member) commented:

One of the pre-existing tests is failing; I'm not sure if it's due to these changes, but I would guess so.

https://github.com/pymc-devs/pymc/actions/runs/8169791795/job/22334588152?pr=6949#step:7:478

ricardoV94 (Member) left a review:

Looks good. I left comments about the need for docstrings, renaming a few things, and removing TODO comments.

@@ -238,3 +260,282 @@ def round_logprob(op, values, base_rv, **kwargs):
from pymc.math import logdiffexp

return logdiffexp(logcdf_upper, logcdf_lower)


class FlatSwitches(Op):
ricardoV94:

Add docstrings explaining what this Op does / where it is used for. Most importantly what is the IR representation this Op uses for what kind of original graphs. This can be done either here or in the main rewrite. If on the main rewrite, just mention here to check out the docstring in the rewrite.

MeasurableVariable.register(FlatSwitches)


def get_intervals(binary_node, valued_rvs):
ricardoV94:

Add docstrings explaining what this does, possibly also input/output type hints

return [interval_true, interval_false]


def adjust_intervals(intervals, outer_interval):
ricardoV94:

Add docstrings and possibly type hints



@node_rewriter(tracks=[switch])
def find_measurable_flat_switch_encoding(fgraph: FunctionGraph, node: Node):
ricardoV94:

Suggestion, because "flat" is the IR output, not what is being found?

Suggested change
def find_measurable_flat_switch_encoding(fgraph: FunctionGraph, node: Node):
def find_nested_switch_encoding(fgraph: FunctionGraph, node: Node):

encodings, intervals = [], []
rv_idx = ()

# TODO: Some alternative cleaner way to do this
ricardoV94:

What is dirty about this approach? If something is, add a comment explaining it; otherwise remove the TODO?



@_logprob.register(FlatSwitches)
def flat_switches_logprob(op, values, base_rv, *inputs, **kwargs):
ricardoV94:

Suggested change
def flat_switches_logprob(op, values, base_rv, *inputs, **kwargs):
def nested_switch_encoding_logprob(op, values, base_rv, *inputs, **kwargs):

ricardoV94:

Also add some comment in the docstrings about the kind of logp graphs we are generating?

encodings, pt.eq(pt.unique(encodings, axis=0).shape[0], len(encodings))
)

# TODO: We do not support the encoding graphs of discrete RVs yet
ricardoV94:

Remove this easy-to-miss TODO. Either mention it in the docstrings, or add a NotImplementedError so we don't forget to update the logp function?

logcdf_interval_bounds[1, ...], logcdf_interval_bounds[0, ...]
) # (encoding, *base_rv.shape)

# default logprob is -inf if there is no RV in branches
ricardoV94:

Explain what's happening in the first if branch as well?

Comment on lines +479 to +480
# Possible TODO:
# encodings = op.get_encodings_from_inputs(inputs)
ricardoV94:

Either implement or remove TODO since it's not a high priority anyway?

class FlatSwitches(Op):
ricardoV94:

May be more intuitive?

Suggested change
class FlatSwitches(Op):
class NestedEncodingSwitches(Op):
