Commit 91cbebd

Port GARCH11 to v4 (#6119)
1 parent c53cd2f commit 91cbebd

File tree

3 files changed: +247 additions, -90 deletions

pymc/distributions/timeseries.py

Lines changed: 127 additions & 42 deletions
@@ -21,7 +21,6 @@
 
 from aeppl.abstract import _get_measurable_outputs
 from aeppl.logprob import _logprob
-from aesara import scan
 from aesara.graph import FunctionGraph, rewrite_graph
 from aesara.graph.basic import Node, clone_replace
 from aesara.raise_op import Assert
@@ -230,7 +229,7 @@ def random_walk_moment(op, rv, init_dist, innovation_dist, steps):
 
 @_logprob.register(RandomWalkRV)
 def random_walk_logp(op, values, *inputs, **kwargs):
-    # ALthough Aeppl can derive the logprob of random walks, it does not collapse
+    # Although Aeppl can derive the logprob of random walks, it does not collapse
     # what PyMC considers the core dimension of steps. We do it manually here.
     (value,) = values
     # Recreate RV and obtain inner graph
@@ -309,7 +308,6 @@ def get_dists(cls, *, mu, sigma, init_dist, **kwargs):
 class AutoRegressiveRV(SymbolicRandomVariable):
     """A placeholder used to specify a log-likelihood for an AR sub-graph."""
 
-    _print_name = ("AR", "\\operatorname{AR}")
     default_output = 1
     ar_order: int
     constant_term: bool
@@ -616,17 +614,29 @@ def ar_moment(op, rv, rhos, sigma, init_dist, steps, noise_rng):
     return at.full_like(rv, moment(init_dist)[..., -1, None])
 
 
-class GARCH11(distribution.Continuous):
+class GARCH11RV(SymbolicRandomVariable):
+    """A placeholder used to specify a GARCH11 graph."""
+
+    default_output = 1
+    _print_name = ("GARCH11", "\\operatorname{GARCH11}")
+
+    def update(self, node: Node):
+        """Return the update mapping for the noise RV."""
+        # Since noise is a shared variable it shows up as the last node input
+        return {node.inputs[-1]: node.outputs[0]}
+
+
+class GARCH11(Distribution):
     r"""
     GARCH(1,1) with Normal innovations. The model is specified by
 
     .. math::
-        y_t = \sigma_t * z_t
+        y_t \sim N(0, \sigma_t^2)
 
     .. math::
         \sigma_t^2 = \omega + \alpha_1 * y_{t-1}^2 + \beta_1 * \sigma_{t-1}^2
 
-    with z_t iid and Normal with mean zero and unit standard deviation.
+    where \sigma_t^2 (the error variance) follows an ARMA(1, 1) model.
 
     Parameters
     ----------
@@ -640,54 +650,129 @@ class GARCH11(distribution.Continuous):
     initial_vol >= 0, initial volatility, sigma_0
     """
 
-    def __new__(cls, *args, **kwargs):
-        raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.")
+    rv_type = GARCH11RV
+
+    def __new__(cls, *args, steps=None, **kwargs):
+        steps = get_steps(
+            steps=steps,
+            shape=None,  # Shape will be checked in `cls.dist`
+            dims=kwargs.get("dims", None),
+            observed=kwargs.get("observed", None),
+            step_shape_offset=1,
+        )
+        return super().__new__(cls, *args, steps=steps, **kwargs)
 
     @classmethod
-    def dist(cls, *args, **kwargs):
-        raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.")
+    def dist(cls, omega, alpha_1, beta_1, initial_vol, *, steps=None, **kwargs):
+        steps = get_steps(steps=steps, shape=kwargs.get("shape", None), step_shape_offset=1)
+        if steps is None:
+            raise ValueError("Must specify steps or shape parameter")
+        steps = at.as_tensor_variable(intX(steps), ndim=0)
 
-    def __init__(self, omega, alpha_1, beta_1, initial_vol, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+        omega = at.as_tensor_variable(omega)
+        alpha_1 = at.as_tensor_variable(alpha_1)
+        beta_1 = at.as_tensor_variable(beta_1)
+        initial_vol = at.as_tensor_variable(initial_vol)
 
-        self.omega = omega = at.as_tensor_variable(omega)
-        self.alpha_1 = alpha_1 = at.as_tensor_variable(alpha_1)
-        self.beta_1 = beta_1 = at.as_tensor_variable(beta_1)
-        self.initial_vol = at.as_tensor_variable(initial_vol)
-        self.mean = at.as_tensor_variable(0.0)
+        init_dist = Normal.dist(0, initial_vol)
+        # Tell Aeppl to ignore init_dist, as it will be accounted for in the logp term
+        init_dist = ignore_logprob(init_dist)
+
+        return super().dist([omega, alpha_1, beta_1, initial_vol, init_dist, steps], **kwargs)
 
-    def get_volatility(self, x):
-        x = x[:-1]
+    @classmethod
+    def rv_op(cls, omega, alpha_1, beta_1, initial_vol, init_dist, steps, size=None):
+        if size is not None:
+            batch_size = size
+        else:
+            # In this case the size of the init_dist depends on the parameters shape
+            batch_size = at.broadcast_shape(omega, alpha_1, beta_1, initial_vol)
+        init_dist = change_dist_size(init_dist, batch_size)
+        # initial_vol = initial_vol * at.ones(batch_size)
 
-        def volatility_update(x, vol, w, a, b):
-            return at.sqrt(w + a * at.square(x) + b * at.square(vol))
+        # Create OpFromGraph representing random draws from GARCH11 process
+        # Variables with underscore suffix are dummy inputs into the OpFromGraph
+        init_ = init_dist.type()
+        initial_vol_ = initial_vol.type()
+        omega_ = omega.type()
+        alpha_1_ = alpha_1.type()
+        beta_1_ = beta_1.type()
+        steps_ = steps.type()
 
-        vol, _ = scan(
-            fn=volatility_update,
-            sequences=[x],
-            outputs_info=[self.initial_vol],
-            non_sequences=[self.omega, self.alpha_1, self.beta_1],
+        noise_rng = aesara.shared(np.random.default_rng())
+
+        def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng):
+            new_sigma = at.sqrt(
+                omega + alpha_1 * at.square(prev_y) + beta_1 * at.square(prev_sigma)
+            )
+            next_rng, new_y = Normal.dist(mu=0, sigma=new_sigma, rng=rng).owner.outputs
+            return (new_y, new_sigma), {rng: next_rng}
+
+        (y_t, _), innov_updates_ = aesara.scan(
+            fn=step,
+            outputs_info=[init_, initial_vol_ * at.ones(batch_size)],
+            non_sequences=[omega_, alpha_1_, beta_1_, noise_rng],
+            n_steps=steps_,
+            strict=True,
         )
-        return at.concatenate([[self.initial_vol], vol])
+        (noise_next_rng,) = tuple(innov_updates_.values())
 
-    def logp(self, x):
-        """
-        Calculate log-probability of GARCH(1, 1) distribution at specified value.
+        garch11_ = at.concatenate([init_[None, ...], y_t], axis=0).dimshuffle(
+            tuple(range(1, y_t.ndim)) + (0,)
+        )
 
-        Parameters
-        ----------
-        x: numeric
-            Value for which log-probability is calculated.
+        garch11_op = GARCH11RV(
+            inputs=[omega_, alpha_1_, beta_1_, initial_vol_, init_, steps_],
+            outputs=[noise_next_rng, garch11_],
+            ndim_supp=1,
+        )
 
-        Returns
-        -------
-        TensorVariable
-        """
-        vol = self.get_volatility(x)
-        return at.sum(Normal.dist(0.0, sigma=vol).logp(x))
+        garch11 = garch11_op(omega, alpha_1, beta_1, initial_vol, init_dist, steps)
+        return garch11
 
-    def _distr_parameters_for_repr(self):
-        return ["omega", "alpha_1", "beta_1"]
+
+@_change_dist_size.register(GARCH11RV)
+def change_garch11_size(op, dist, new_size, expand=False):
+
+    if expand:
+        old_size = dist.shape[:-1]
+        new_size = tuple(new_size) + tuple(old_size)
+
+    return GARCH11.rv_op(
+        *dist.owner.inputs[:-1],
+        size=new_size,
+    )
+
+
+@_logprob.register(GARCH11RV)
+def garch11_logp(
+    op, values, omega, alpha_1, beta_1, initial_vol, init_dist, steps, noise_rng, **kwargs
+):
+    (value,) = values
+    # Move the time axis to the first dimension
+    value_dimswapped = value.dimshuffle((value.ndim - 1,) + tuple(range(0, value.ndim - 1)))
+    initial_vol = initial_vol * at.ones_like(value_dimswapped[0])
+
+    def volatility_update(x, vol, w, a, b):
+        return at.sqrt(w + a * at.square(x) + b * at.square(vol))
+
+    vol, _ = aesara.scan(
+        fn=volatility_update,
+        sequences=[value_dimswapped[:-1]],
+        outputs_info=[initial_vol],
+        non_sequences=[omega, alpha_1, beta_1],
+        strict=True,
+    )
+    sigma_t = at.concatenate([[initial_vol], vol])
+    # Compute and collapse logp across time dimension
+    innov_logp = at.sum(logp(Normal.dist(0, sigma_t), value_dimswapped), axis=0)
+    return innov_logp
+
+
+@_moment.register(GARCH11RV)
+def garch11_moment(op, rv, omega, alpha_1, beta_1, initial_vol, init_dist, steps, noise_rng):
+    # GARCH(1,1) mean is zero
+    return at.zeros_like(rv)
 
 
 class EulerMaruyama(distribution.Continuous):
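For reference, the GARCH(1,1) recursion that the scan step (forward sampling) and volatility_update (logp) functions above both encode can be written out in plain NumPy. This is a minimal illustrative sketch, not code from the commit; the function name garch11_simulate and its signature are assumptions made for the example:

    import numpy as np

    def garch11_simulate(omega, alpha_1, beta_1, initial_vol, steps, rng=None):
        # Simulate one GARCH(1,1) series of length steps + 1, mirroring rv_op:
        # y_0 ~ Normal(0, initial_vol), and for t >= 1
        #   sigma_t = sqrt(omega + alpha_1 * y_{t-1}**2 + beta_1 * sigma_{t-1}**2)
        #   y_t ~ Normal(0, sigma_t)
        rng = np.random.default_rng() if rng is None else rng
        y = np.empty(steps + 1)
        sigma = initial_vol
        y[0] = rng.normal(0.0, initial_vol)
        for t in range(1, steps + 1):
            sigma = np.sqrt(omega + alpha_1 * y[t - 1] ** 2 + beta_1 * sigma**2)
            y[t] = rng.normal(0.0, sigma)
        return y

Because of the final dimshuffle in rv_op, the ported distribution places the time dimension on the last axis of the returned variable, matching the other v4 timeseries distributions.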

pymc/tests/distributions/test_distribution.py

Lines changed: 0 additions & 1 deletion
@@ -109,7 +109,6 @@ def test_all_distributions_have_moments():
 
     # Distributions that have not been refactored for V4 yet
     not_implemented = {
-        dist_module.timeseries.GARCH11,
         dist_module.timeseries.MvGaussianRandomWalk,
         dist_module.timeseries.MvStudentTRandomWalk,
         dist_module.timeseries.EulerMaruyama,
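With GARCH11 removed from the not_implemented set, the standard moment, draw, and logp machinery is expected to cover it. A hypothetical smoke test, not part of this commit, could look roughly like:

    import pymc as pm

    # Assumed example values; steps=99 gives a series of length steps + 1 = 100
    garch = pm.GARCH11.dist(omega=0.1, alpha_1=0.2, beta_1=0.7, initial_vol=1.0, steps=99)
    draw = pm.draw(garch)                   # NumPy array of shape (100,)
    logp_val = pm.logp(garch, draw).eval()  # scalar log-probability at the draw

Inside a model, the same distribution should be declarable roughly as pm.GARCH11("y", omega=..., alpha_1=..., beta_1=..., initial_vol=..., observed=data), with steps inferred from the observed data through get_steps in __new__.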
