Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 53 additions & 37 deletions causalpy/pymc_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,13 @@ def idata(self):

return self.model.idata

def print_coefficients(self) -> None:
def print_coefficients(self, round_to=None) -> None:
"""
Prints the model coefficients

:param round_to:
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.

Example
--------
>>> import causalpy as cp
Expand All @@ -80,13 +83,13 @@ def print_coefficients(self) -> None:
... "progressbar": False
... }),
... )
>>> result.print_coefficients() # doctest: +NUMBER
>>> result.print_coefficients(round_to=1) # doctest: +NUMBER
Model coefficients:
Intercept 1.0, 94% HDI [1.0, 1.1]
post_treatment[T.True] 0.9, 94% HDI [0.9, 1.0]
group 0.1, 94% HDI [0.0, 0.2]
Intercept 1, 94% HDI [1, 1]
post_treatment[T.True] 1, 94% HDI [0.9, 1]
group 0.2, 94% HDI [0.09, 0.2]
group:post_treatment[T.True] 0.5, 94% HDI [0.4, 0.6]
sigma 0.0, 94% HDI [0.0, 0.1]
sigma 0.08, 94% HDI [0.07, 0.1]
"""
print("Model coefficients:")
coeffs = az.extract(self.idata.posterior, var_names="beta")
Expand All @@ -95,13 +98,13 @@ def print_coefficients(self) -> None:
for name in self.labels:
coeff_samples = coeffs.sel(coeffs=name)
print(
f"{name: <30}{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]" # noqa: E501
f"{name: <30}{round_num(coeff_samples.mean().data, round_to)}, 94% HDI [{round_num(coeff_samples.quantile(0.03).data, round_to)}, {round_num(coeff_samples.quantile(1-0.03).data, round_to)}]" # noqa: E501
)
# add coeff for measurement std
coeff_samples = az.extract(self.model.idata.posterior, var_names="sigma")
name = "sigma"
print(
f"{name: <30}{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]" # noqa: E501
f"{name: <30}{round_num(coeff_samples.mean().data, round_to)}, 94% HDI [{round_num(coeff_samples.quantile(0.03).data, round_to)}, {round_num(coeff_samples.quantile(1-0.03).data, round_to)}]" # noqa: E501
)


Expand Down Expand Up @@ -138,7 +141,7 @@ class PrePostFit(ExperimentalDesign):
... }
... ),
... )
>>> result.summary() # doctest: +NUMBER
>>> result.summary(round_to=1) # doctest: +NUMBER
==================================Pre-Post Fit==================================
Formula: actual ~ 0 + a + g
Model coefficients:
Expand Down Expand Up @@ -231,7 +234,7 @@ def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs):
Plot the results

:param round_to:
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))

Expand Down Expand Up @@ -331,15 +334,18 @@ def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs):

return fig, ax

def summary(self) -> None:
def summary(self, round_to=None) -> None:
"""
Print text output summarising the results

:param round_to:
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""

print(f"{self.expt_type:=^80}")
print(f"Formula: {self.formula}")
# TODO: extra experiment specific outputs here
self.print_coefficients()
self.print_coefficients(round_to)


class InterruptedTimeSeries(PrePostFit):
Expand Down Expand Up @@ -420,7 +426,7 @@ def plot(self, plot_predictors=False, **kwargs):
"""Plot the results

:param round_to:
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""
fig, ax = super().plot(counterfactual_label="Synthetic control", **kwargs)
if plot_predictors:
Expand Down Expand Up @@ -589,7 +595,7 @@ def plot(self, round_to=None):
"""Plot the results.

:param round_to:
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""
fig, ax = plt.subplots()

Expand Down Expand Up @@ -728,17 +734,19 @@ def _causal_impact_summary_stat(self, round_to=None) -> str:
causal_impact = f"{round_num(self.causal_impact.mean(), round_to)}, "
return f"Causal impact = {causal_impact + ci}"

def summary(self) -> None:
def summary(self, round_to=None) -> None:
"""
Print text output summarising the results
Print text output summarising the results.

:param round_to:
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""

print(f"{self.expt_type:=^80}")
print(f"Formula: {self.formula}")
print("\nResults:")
# TODO: extra experiment specific outputs here
print(self._causal_impact_summary_stat())
self.print_coefficients()
print(round_num(self._causal_impact_summary_stat(), round_to))
self.print_coefficients(round_to)


class RegressionDiscontinuity(ExperimentalDesign):
Expand Down Expand Up @@ -894,7 +902,7 @@ def plot(self, round_to=None):
Plot the results

:param round_to:
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""
fig, ax = plt.subplots()
# Plot raw data
Expand Down Expand Up @@ -943,9 +951,12 @@ def plot(self, round_to=None):
)
return fig, ax

def summary(self) -> None:
def summary(self, round_to: None) -> None:
"""
Print text output summarising the results

:param round_to:
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""

print(f"{self.expt_type:=^80}")
Expand All @@ -954,9 +965,9 @@ def summary(self) -> None:
print(f"Threshold on running variable: {self.treatment_threshold}")
print("\nResults:")
print(
f"Discontinuity at threshold = {self.discontinuity_at_threshold.mean():.2f}"
f"Discontinuity at threshold = {round_num(self.discontinuity_at_threshold.mean(), round_to)}"
)
self.print_coefficients()
self.print_coefficients(round_to)


class RegressionKink(ExperimentalDesign):
Expand Down Expand Up @@ -1111,7 +1122,7 @@ def plot(self, round_to=None):
Plot the results

:param round_to:
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""
fig, ax = plt.subplots()
# Plot raw data
Expand Down Expand Up @@ -1160,9 +1171,12 @@ def plot(self, round_to=None):
)
return fig, ax

def summary(self) -> None:
def summary(self, round_to=None) -> None:
"""
Print text output summarising the results

:param round_to:
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""

print(
Expand All @@ -1173,10 +1187,10 @@ def summary(self) -> None:
Kink point on running variable: {self.kink_point}

Results:
Change in slope at kink point = {self.gradient_change.mean():.2f}
Change in slope at kink point = {round_num(self.gradient_change.mean(), round_to)}
"""
)
self.print_coefficients()
self.print_coefficients(round_to)


class PrePostNEGD(ExperimentalDesign):
Expand Down Expand Up @@ -1213,17 +1227,17 @@ class PrePostNEGD(ExperimentalDesign):
... }
... )
... )
>>> result.summary() # doctest: +NUMBER
>>> result.summary(round_to=1) # doctest: +NUMBER
==================Pretest/posttest Nonequivalent Group Design===================
Formula: post ~ 1 + C(group) + pre
<BLANKLINE>
Results:
Causal impact = 1.8, $CI_{94%}$[1.7, 2.1]
Causal impact = 2, $CI_{94%}$[2, 2]
Model coefficients:
Intercept -0.4, 94% HDI [-1.1, 0.2]
C(group)[T.1] 1.8, 94% HDI [1.6, 2.0]
pre 1.0, 94% HDI [0.9, 1.1]
sigma 0.5, 94% HDI [0.4, 0.5]
Intercept -0.5, 94% HDI [-1, 0.2]
C(group)[T.1] 2, 94% HDI [2, 2]
pre 1, 94% HDI [1, 1]
sigma 0.5, 94% HDI [0.5, 0.6]
"""

def __init__(
Expand Down Expand Up @@ -1304,7 +1318,7 @@ def plot(self, round_to=None):
"""Plot the results

:param round_to:
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""
fig, ax = plt.subplots(
2, 1, figsize=(7, 9), gridspec_kw={"height_ratios": [3, 1]}
Expand Down Expand Up @@ -1362,20 +1376,23 @@ def _causal_impact_summary_stat(self, round_to) -> str:
r"$CI_{94%}$"
+ f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
)
causal_impact = f"{self.causal_impact.mean():.2f}, "
causal_impact = f"{round_num(self.causal_impact.mean(), round_to)}, "
return f"Causal impact = {causal_impact + ci}"

def summary(self, round_to=None) -> None:
"""
Print text output summarising the results

:param round_to:
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""

print(f"{self.expt_type:=^80}")
print(f"Formula: {self.formula}")
print("\nResults:")
# TODO: extra experiment specific outputs here
print(self._causal_impact_summary_stat(round_to))
self.print_coefficients()
self.print_coefficients(round_to)

def _get_treatment_effect_coeff(self) -> str:
"""Find the beta regression coefficient corresponding to the
Expand Down Expand Up @@ -1452,7 +1469,6 @@ class InstrumentalVariable(ExperimentalDesign):
... formula=formula,
... model=InstrumentalVariableRegression(sample_kwargs=sample_kwargs),
... )

"""

def __init__(
Expand Down
8 changes: 4 additions & 4 deletions causalpy/skl_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs):
"""Plot experiment results

:param round_to:
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))

Expand Down Expand Up @@ -270,7 +270,7 @@ def plot(self, plot_predictors=False, round_to=None, **kwargs):
"""Plot the results

:param round_to:
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""
fig, ax = super().plot(
counterfactual_label="Synthetic control", round_to=round_to, **kwargs
Expand Down Expand Up @@ -415,7 +415,7 @@ def plot(self, round_to=None):
"""Plot results

:param round_to:
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""
fig, ax = plt.subplots()

Expand Down Expand Up @@ -629,7 +629,7 @@ def plot(self, round_to=None):
"""Plot results

:param round_to:
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""
fig, ax = plt.subplots()
# Plot raw data
Expand Down