Skip to content

Commit 6dbe156

Browse files
committed
Improve test
1 parent d14cb3a commit 6dbe156

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

pymc/distributions/timeseries.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,7 @@ def garch11_logp(
749749
op, values, omega, alpha_1, beta_1, initial_vol, init_dist, steps, noise_rng, **kwargs
750750
):
751751
(value,) = values
752+
# Move the time axis to the first dimension
752753
value_dimswapped = value.dimshuffle((value.ndim - 1,) + tuple(range(0, value.ndim - 1)))
753754
initial_vol = initial_vol * at.ones_like(value_dimswapped[0])
754755

@@ -760,11 +761,11 @@ def volatility_update(x, vol, w, a, b):
760761
sequences=[value_dimswapped[:-1]],
761762
outputs_info=[initial_vol],
762763
non_sequences=[omega, alpha_1, beta_1],
763-
strict = True,
764+
strict=True,
764765
)
765766
sigma_t = at.concatenate([[initial_vol], vol])
766767
# Compute and collapse logp across time dimension
767-
innov_logp = at.sum(logp(Normal.dist(0, sigma_t), value_dimswapped), axis=-1)
768+
innov_logp = at.sum(logp(Normal.dist(0, sigma_t), value_dimswapped), axis=0)
768769
return innov_logp
769770

770771

pymc/tests/distributions/test_timeseries.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -547,19 +547,27 @@ def test_batched_size(self, explicit_shape, batched_param):
547547
initial_vol=2.5,
548548
)
549549
kwargs0 = init_kwargs.copy()
550-
kwargs0[arg_name] = init_kwargs[arg_name] * param_val
550+
kwargs0[batched_param] = init_kwargs[batched_param] * param_val
551+
if explicit_shape:
552+
kwargs0["shape"] = (batch_size, steps)
553+
else:
554+
kwargs0["steps"] = steps - 1
551555
with Model() as t0:
552-
y = GARCH11("y", shape=(batch_size, steps), **kwargs0)
556+
y = GARCH11("y", **kwargs0)
553557

554558
y_eval = draw(y, draws=2)
555559
assert y_eval[0].shape == (batch_size, steps)
556560
assert not np.any(np.isclose(y_eval[0], y_eval[1]))
557561

558562
kwargs1 = init_kwargs.copy()
563+
if explicit_shape:
564+
kwargs1["shape"] = steps
565+
else:
566+
kwargs1["steps"] = steps - 1
559567
with Model() as t1:
560568
for i in range(batch_size):
561-
kwargs1[arg_name] = init_kwargs[arg_name] * param_val[i]
562-
GARCH11(f"y_{i}", shape=steps, **kwargs1)
569+
kwargs1[batched_param] = init_kwargs[batched_param] * param_val[i]
570+
GARCH11(f"y_{i}", **kwargs1)
563571

564572
np.testing.assert_allclose(
565573
t0.compile_logp()(t0.initial_point()),
@@ -584,7 +592,7 @@ def test_moment(self, size, expected):
584592
steps=7,
585593
size=size,
586594
)
587-
assert_moment_is_expected(model, expected, check_finite_logp=False)
595+
assert_moment_is_expected(model, expected, check_finite_logp=True)
588596

589597
def test_change_dist_size(self):
590598
base_dist = pm.GARCH11.dist(

0 commit comments

Comments
 (0)