FlatSwitch Op for logprob derivation of arbitrary censoring #6949
Conversation
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #6949 +/- ##
==========================================
- Coverage 92.26% 87.21% -5.06%
==========================================
Files 100 100
Lines 16880 17009 +129
==========================================
- Hits 15574 14834 -740
- Misses 1306 2175 +869
Force-pushed from 86ae712 to 89d4635 (Compare)
Force-pushed from 44e5423 to 96e0bb0 (Compare)
@shreyas3156 could you resolve the conflicts? I'll finally review this one :)
Force-pushed from ebcee22 to db20cef (Compare)
One of the pre-existing tests is failing; not sure if it's due to the changes, but I would guess so: https://github.com/pymc-devs/pymc/actions/runs/8169791795/job/22334588152?pr=6949#step:7:478
Looks good. I left comments about the need for docstrings, renaming some things, and removing TODO comments.
@@ -238,3 +260,282 @@ def round_logprob(op, values, base_rv, **kwargs):
    from pymc.math import logdiffexp

    return logdiffexp(logcdf_upper, logcdf_lower)


class FlatSwitches(Op):
Add docstrings explaining what this Op does and what it is used for. Most importantly, document the IR representation this Op uses and for what kind of original graphs. This can be done either here or in the main rewrite; if in the main rewrite, just mention here to check out the docstring in the rewrite.
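A hedged sketch of the kind of docstring being asked for; the wording and the IR description below are illustrative, not the PR's actual text:

```python
from pytensor.graph.op import Op


class FlatSwitches(Op):
    """Flattened IR representation of nested switch-encoding graphs.

    A graph such as ``switch(x < a, e1, switch(x > b, e2, x))`` is rewritten
    into a single FlatSwitches node that carries the base RV together with a
    flat list of (interval, encoding) pairs. The logprob dispatcher can then
    derive the censored log-probability interval by interval, instead of
    recursing through the nested switches.

    See the docstring of the main rewrite for how this representation is built.
    """
```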
MeasurableVariable.register(FlatSwitches)


def get_intervals(binary_node, valued_rvs):
Add docstrings explaining what this does, and possibly also input/output type hints.
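A hedged sketch of what the requested signature with type hints might look like; the container types and the meaning of the return value are guesses based on the surrounding diff:

```python
from collections.abc import Sequence

from pytensor.graph.basic import Apply, Variable


def get_intervals(
    binary_node: Apply, valued_rvs: Sequence[Variable]
) -> list[tuple[Variable, Variable]]:
    """Split the interval implied by a binary comparison node.

    Returns the (lower, upper) interval bounds implied by the true branch
    and by the false branch of the comparison, in that order.
    """
    raise NotImplementedError  # placeholder; see the PR diff for the real body
```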
    return [interval_true, interval_false]


def adjust_intervals(intervals, outer_interval):
Add docstrings and possibly type hints
@node_rewriter(tracks=[switch])
def find_measurable_flat_switch_encoding(fgraph: FunctionGraph, node: Node):
Suggestion, because "flat" is the IR output, not what is being found?
Suggested change:
-def find_measurable_flat_switch_encoding(fgraph: FunctionGraph, node: Node):
+def find_nested_switch_encoding(fgraph: FunctionGraph, node: Node):
    encodings, intervals = [], []
    rv_idx = ()

    # TODO: Some alternative cleaner way to do this
What is dirty about this approach? If something is, add a comment; otherwise remove the TODO?
@_logprob.register(FlatSwitches)
def flat_switches_logprob(op, values, base_rv, *inputs, **kwargs):
Suggested change:
-def flat_switches_logprob(op, values, base_rv, *inputs, **kwargs):
+def nested_switch_encoding_logprob(op, values, base_rv, *inputs, **kwargs):
Also add some comment in the docstrings about the kind of logp graphs we are generating?
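Not the PR's code, but a small SciPy illustration of the kind of logp such a graph encodes for a doubly censored standard normal, assuming the encodings are the bounds themselves:

```python
import numpy as np
from scipy import stats

base = stats.norm()
lower, upper = -1.0, 1.0  # interval bounds whose encodings are -1 and 1


def censored_logp(value: float) -> float:
    if value == lower:  # mass of the whole left tail collapses onto the encoding
        return base.logcdf(lower)
    if value == upper:  # mass of the whole right tail collapses onto the encoding
        return base.logsf(upper)
    if lower < value < upper:  # interior values keep the base density
        return base.logpdf(value)
    return -np.inf  # values matching no interval are impossible


print(censored_logp(-1.0), censored_logp(0.3), censored_logp(2.0))
```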
        encodings, pt.eq(pt.unique(encodings, axis=0).shape[0], len(encodings))
    )

    # TODO: We do not support the encoding graphs of discrete RVs yet
Remove this easy-to-miss TODO. Either mention it in the docstrings, or add a NotImplementedError that will make sure we don't forget to update the logp function?
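A hedged sketch of the suggested guard; the helper name, dtype check, and message are illustrative, not the PR's code:

```python
import pytensor.tensor as pt
from pytensor.tensor.variable import TensorVariable


def check_continuous_base_rv(base_rv: TensorVariable) -> None:
    # Fail loudly instead of silently returning a wrong logp for discrete RVs,
    # so the restriction cannot be forgotten when the logp function is updated.
    if base_rv.dtype.startswith(("int", "uint", "bool")):
        raise NotImplementedError(
            "Logprob derivation for encodings of discrete RVs is not supported yet"
        )


check_continuous_base_rv(pt.random.normal(0, 1, size=3))  # passes for a continuous RV
```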
        logcdf_interval_bounds[1, ...], logcdf_interval_bounds[0, ...]
    )  # (encoding, *base_rv.shape)

    # default logprob is -inf if there is no RV in branches
Explain what's happening in the first if branch as well?
    # Possible TODO:
    # encodings = op.get_encodings_from_inputs(inputs)
Either implement it or remove the TODO, since it's not a high priority anyway?
class FlatSwitches(Op):
May be more intuitive?
Suggested change:
-class FlatSwitches(Op):
+class NestedEncodingSwitches(Op):
What is this PR about?
This PR defines a FlatSwitch Op that aims to extract the intervals and their respective encodings required to infer the logprob of arbitrary censored distributions. It achieves this in the following steps:

- It recursively traverses nested pt.switch() graphs, collecting each branch's interval and its encoding.
- It then checks that neither the switch condition nor any measurable branch is broadcast (broadcastable conditions and branches are not allowed), and that all the measurable components have the same source of measurability. The logic for these checks is based on #6834.

Once the intervals and their respective encodings are known, they can be used to calculate the log-probability, as in the sketch below.
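A hedged sketch of the intended usage, assuming a doubly censored standard normal expressed through nested pt.switch; the variable names and bounds are illustrative, and the logp call relies on the rewrite added in this PR:

```python
import pytensor.tensor as pt
import pymc as pm

x = pm.Normal.dist(0, 1)
# censor below at -1 (encoded as -1) and above at 1 (encoded as 1)
censored = pt.switch(x < -1, -1, pt.switch(x > 1, 1, x))

# with this PR, the rewrite is expected to detect the nested switches and derive
# P(X <= -1) for the lower encoding, P(X >= 1) for the upper encoding,
# and the base density for interior values
logp = pm.logp(censored, 0.5)
print(logp.eval())  # roughly the log density of a standard normal at 0.5
```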
TO-DO:

- Make the rewrite work not only on pt.switch(x > 0, x, a) but also on pt.switch(pt.exp(x) > 0, pt.exp(x), b), where a and b are some encodings.
- Pass base_rv, the intervals list, and the corresponding encodings as inputs to the node so that they can be unpacked in the logprob calculation.

Checklist
@ricardoV94 @larryshamalama