Skip to content

Commit f577c2c

Browse files
maresb, aseyboldt, michaelosthege
authored
Improve some type hints to bump mypy pin (#6294)
* Add a few missing type imports
* Trade assert with assignment to keep mypy happy
* Add a few type annotations
* Add missing return type for __call__
* Switch comment type declaration to raw
* Get operators.py to pass
* Fix pymc.backends.report
* Fix a bunch of typing issues
* Import __future__.annotations to fix "| None"
* Update pymc/step_methods/hmc/integration.py
* Add __future__.annotations to hmc.py
* Remove unused Any import
* Don't cast float to np.array
* Replace 0 with 0.0 for float zeros
* Update pymc/step_methods/hmc/nuts.py

Closes #6282

Co-authored-by: Adrian Seyboldt <[email protected]>
Co-authored-by: Michael Osthege <[email protected]>
1 parent 5d7283e commit f577c2c

16 files changed

+121
-52
lines changed

conda-envs/environment-dev.yml

+1-1
Original file line number | Diff line number | Diff line change
@@ -38,5 +38,5 @@ dependencies:
3838
- watermark
3939
- polyagamma
4040
- sphinx-remove-toctrees
41-
- mypy=0.982
41+
- mypy=0.990
4242
- types-cachetools

conda-envs/environment-test.yml

+1-1
Original file line number | Diff line number | Diff line change
@@ -27,5 +27,5 @@ dependencies:
2727
- pre-commit>=2.8.0
2828
- pytest-cov>=2.5
2929
- pytest>=3.0
30-
- mypy=0.982
30+
- mypy=0.990
3131
- types-cachetools

conda-envs/windows-environment-dev.yml

+1-1
Original file line number | Diff line number | Diff line change
@@ -35,5 +35,5 @@ dependencies:
3535
- sphinx>=1.5
3636
- watermark
3737
- sphinx-remove-toctrees
38-
- mypy=0.982
38+
- mypy=0.990
3939
- types-cachetools

conda-envs/windows-environment-test.yml

+1-1
Original file line number | Diff line number | Diff line change
@@ -28,5 +28,5 @@ dependencies:
2828
- pre-commit>=2.8.0
2929
- pytest-cov>=2.5
3030
- pytest>=3.0
31-
- mypy=0.982
31+
- mypy=0.990
3232
- types-cachetools

pymc/backends/report.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
import dataclasses
1616
import logging
1717

18-
from typing import Optional
18+
from typing import Dict, List, Optional
1919

2020
import arviz
2121

@@ -32,7 +32,7 @@
3232
class SamplerReport:
3333
"""Bundle warnings, convergence stats and metadata of a sampling run."""
3434

35-
def __init__(self):
35+
def __init__(self) -> None:
3636
self._chain_warnings: Dict[int, List[SamplerWarning]] = {}
3737
self._global_warnings: List[SamplerWarning] = []
3838
self._n_tune = None

pymc/blocking.py

+7-5
Original file line number | Diff line number | Diff line change
@@ -17,10 +17,10 @@
1717
1818
Classes for working with subsets of parameters.
1919
"""
20-
import collections
20+
from __future__ import annotations
2121

2222
from functools import partial
23-
from typing import Callable, Dict, Generic, Optional, TypeVar
23+
from typing import Callable, Dict, Generic, NamedTuple, TypeVar
2424

2525
import numpy as np
2626

@@ -32,7 +32,9 @@
3232

3333
# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for
3434
# each of the raveled variables.
35-
RaveledVars = collections.namedtuple("RaveledVars", "data, point_map_info")
35+
class RaveledVars(NamedTuple):
36+
data: np.ndarray
37+
point_map_info: tuple[tuple[str, tuple[int, ...], np.dtype], ...]
3638

3739

3840
class Compose(Generic[T]):
@@ -69,7 +71,7 @@ def map(var_dict: PointType) -> RaveledVars:
6971
@staticmethod
7072
def rmap(
7173
array: RaveledVars,
72-
start_point: Optional[PointType] = None,
74+
start_point: PointType | None = None,
7375
) -> PointType:
7476
"""Map 1D concatenated array to a dictionary of variables in their original spaces.
7577
@@ -100,7 +102,7 @@ def rmap(
100102

101103
@classmethod
102104
def mapf(
103-
cls, f: Callable[[PointType], T], start_point: Optional[PointType] = None
105+
cls, f: Callable[[PointType], T], start_point: PointType | None = None
104106
) -> Callable[[RaveledVars], T]:
105107
"""Create a callable that first maps back to ``dict`` inputs and then applies a function.
106108

pymc/gp/util.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@
3131
# Avoid circular dependency when importing modelcontext
3232
from pymc.distributions.distribution import Distribution
3333

34-
assert Distribution # keep both pylint and black happy
34+
_ = Distribution # keep both pylint and black happy
3535
from pymc.model import modelcontext
3636

3737
JITTER_DEFAULT = 1e-6

pymc/step_methods/hmc/base_hmc.py

+21-8
Original file line number | Diff line number | Diff line change
@@ -12,12 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import logging
1618
import time
1719

1820
from abc import abstractmethod
19-
from collections import namedtuple
20-
from typing import Optional
21+
from typing import Any, NamedTuple
2122

2223
import numpy as np
2324

@@ -29,20 +30,32 @@
2930
from pymc.step_methods import step_sizes
3031
from pymc.step_methods.arraystep import GradientSharedStep
3132
from pymc.step_methods.hmc import integration
33+
from pymc.step_methods.hmc.integration import IntegrationError, State
3234
from pymc.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt, quad_potential
3335
from pymc.tuning import guess_scaling
3436
from pymc.util import get_value_vars_from_user_vars
3537

3638
logger = logging.getLogger("pymc")
3739

38-
HMCStepData = namedtuple("HMCStepData", "end, accept_stat, divergence_info, stats")
3940

40-
DivergenceInfo = namedtuple("DivergenceInfo", "message, exec_info, state, state_div")
41+
class DivergenceInfo(NamedTuple):
42+
message: str
43+
exec_info: IntegrationError | None
44+
state: State
45+
state_div: State | None
46+
47+
48+
class HMCStepData(NamedTuple):
49+
end: State
50+
accept_stat: int
51+
divergence_info: DivergenceInfo | None
52+
stats: dict[str, Any]
4153

4254

4355
class BaseHMC(GradientSharedStep):
4456
"""Superclass to implement Hamiltonian/hybrid monte carlo."""
4557

58+
integrator: integration.CpuLeapfrogIntegrator
4659
default_blocked = True
4760

4861
def __init__(
@@ -138,13 +151,13 @@ def __init__(
138151
self._num_divs_sample = 0
139152

140153
@abstractmethod
141-
def _hamiltonian_step(self, start, p0, step_size):
154+
def _hamiltonian_step(self, start, p0, step_size) -> HMCStepData:
142155
"""Compute one Hamiltonian trajectory and return the next state.
143156
144157
Subclasses must overwrite this abstract method and return an `HMCStepData` object.
145158
"""
146159

147-
def astep(self, q0):
160+
def astep(self, q0: RaveledVars) -> tuple[RaveledVars, list[dict[str, Any]]]:
148161
"""Perform a single HMC iteration."""
149162
perf_start = time.perf_counter()
150163
process_start = time.process_time()
@@ -154,6 +167,7 @@ def astep(self, q0):
154167

155168
start = self.integrator.compute_state(q0, p0)
156169

170+
warning: SamplerWarning | None = None
157171
if not np.isfinite(start.energy):
158172
model = self._model
159173
check_test_point_dict = model.point_logps()
@@ -188,7 +202,6 @@ def astep(self, q0):
188202

189203
self.step_adapt.update(hmc_step.accept_stat, adapt_step)
190204
self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune)
191-
warning: Optional[SamplerWarning] = None
192205
if hmc_step.divergence_info:
193206
info = hmc_step.divergence_info
194207
point = None
@@ -221,7 +234,7 @@ def astep(self, q0):
221234

222235
self.iter_count += 1
223236

224-
stats = {
237+
stats: dict[str, Any] = {
225238
"tune": self.tune,
226239
"diverging": bool(hmc_step.divergence_info),
227240
"perf_counter_diff": perf_end - perf_start,

pymc/step_methods/hmc/hmc.py

+6-2
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
17+
from typing import Any
18+
1519
import numpy as np
1620

1721
from pymc.stats.convergence import SamplerWarning
@@ -119,7 +123,7 @@ def __init__(self, vars=None, path_length=2.0, max_steps=1024, **kwargs):
119123
self.path_length = path_length
120124
self.max_steps = max_steps
121125

122-
def _hamiltonian_step(self, start, p0, step_size):
126+
def _hamiltonian_step(self, start, p0, step_size: float) -> HMCStepData:
123127
n_steps = max(1, int(self.path_length / step_size))
124128
n_steps = min(self.max_steps, n_steps)
125129

@@ -156,7 +160,7 @@ def _hamiltonian_step(self, start, p0, step_size):
156160
end = state
157161
accepted = True
158162

159-
stats = {
163+
stats: dict[str, Any] = {
160164
"path_length": self.path_length,
161165
"n_steps": n_steps,
162166
"accept": accept_stat,

pymc/step_methods/hmc/integration.py

+14-5
Original file line number | Diff line number | Diff line change
@@ -12,23 +12,32 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from collections import namedtuple
15+
from typing import NamedTuple
1616

1717
import numpy as np
1818

1919
from scipy import linalg
2020

2121
from pymc.blocking import RaveledVars
22+
from pymc.step_methods.hmc.quadpotential import QuadPotential
2223

23-
State = namedtuple("State", "q, p, v, q_grad, energy, model_logp, index_in_trajectory")
24+
25+
class State(NamedTuple):
26+
q: RaveledVars
27+
p: RaveledVars
28+
v: np.ndarray
29+
q_grad: np.ndarray
30+
energy: float
31+
model_logp: float
32+
index_in_trajectory: int
2433

2534

2635
class IntegrationError(RuntimeError):
2736
pass
2837

2938

3039
class CpuLeapfrogIntegrator:
31-
def __init__(self, potential, logp_dlogp_func):
40+
def __init__(self, potential: QuadPotential, logp_dlogp_func):
3241
"""Leapfrog integrator using CPU."""
3342
self._potential = potential
3443
self._logp_dlogp_func = logp_dlogp_func
@@ -39,14 +48,14 @@ def __init__(self, potential, logp_dlogp_func):
3948
"don't match." % (self._potential.dtype, self._dtype)
4049
)
4150

42-
def compute_state(self, q, p):
51+
def compute_state(self, q: RaveledVars, p: RaveledVars):
4352
"""Compute Hamiltonian functions using a position and momentum."""
4453
if q.data.dtype != self._dtype or p.data.dtype != self._dtype:
4554
raise ValueError("Invalid dtype. Must be %s" % self._dtype)
4655

4756
logp, dlogp = self._logp_dlogp_func(q)
4857

49-
v = self._potential.velocity(p.data)
58+
v = self._potential.velocity(p.data, out=None)
5059
kinetic = self._potential.energy(p.data, velocity=v)
5160
energy = kinetic - logp
5261
return State(q, p, v, dlogp, energy, logp, 0)

pymc/step_methods/hmc/nuts.py

+22-9
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
from collections import namedtuple
1618

1719
import numpy as np
@@ -20,8 +22,9 @@
2022
from pymc.math import logbern
2123
from pymc.stats.convergence import SamplerWarning
2224
from pymc.step_methods.arraystep import Competence
25+
from pymc.step_methods.hmc import integration
2326
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
24-
from pymc.step_methods.hmc.integration import IntegrationError
27+
from pymc.step_methods.hmc.integration import IntegrationError, State
2528
from pymc.vartypes import continuous_types
2629

2730
__all__ = ["NUTS"]
@@ -227,7 +230,14 @@ def competence(var, has_grad):
227230

228231

229232
class _Tree:
230-
def __init__(self, ndim, integrator, start, step_size, Emax):
233+
def __init__(
234+
self,
235+
ndim: int,
236+
integrator: integration.CpuLeapfrogIntegrator,
237+
start: State,
238+
step_size: float,
239+
Emax: float,
240+
):
231241
"""Binary tree from the NUTS algorithm.
232242
233243
Parameters
@@ -247,17 +257,17 @@ def __init__(self, ndim, integrator, start, step_size, Emax):
247257
self.start = start
248258
self.step_size = step_size
249259
self.Emax = Emax
250-
self.start_energy = np.array(start.energy)
260+
self.start_energy = start.energy
251261

252262
self.left = self.right = start
253263
self.proposal = Proposal(start.q.data, start.q_grad, start.energy, start.model_logp, 0)
254264
self.depth = 0
255-
self.log_size = 0
265+
self.log_size = 0.0
256266
self.log_accept_sum = -np.inf
257267
self.mean_tree_accept = 0.0
258268
self.n_proposals = 0
259269
self.p_sum = start.p.data.copy()
260-
self.max_energy_change = 0
270+
self.max_energy_change = 0.0
261271

262272
def extend(self, direction):
263273
"""Double the treesize by extending the tree in the given direction.
@@ -315,16 +325,19 @@ def extend(self, direction):
315325

316326
return diverging, turning
317327

318-
def _single_step(self, left, epsilon):
328+
def _single_step(self, left: State, epsilon: float):
319329
"""Perform a leapfrog step and handle error cases."""
330+
right: State | None
331+
error: IntegrationError | None
332+
error_msg: str | None
320333
try:
321-
# `State` type
322334
right = self.integrator.step(epsilon, left)
323335
except IntegrationError as err:
324336
error_msg = str(err)
325337
error = err
326338
right = None
327339
else:
340+
assert right is not None # since there was no IntegrationError
328341
# h - H0
329342
energy_change = right.energy - self.start_energy
330343
if np.isnan(energy_change):
@@ -354,8 +367,8 @@ def _single_step(self, left, epsilon):
354367
finally:
355368
self.n_proposals += 1
356369
tree = Subtree(None, None, None, None, -np.inf)
357-
divergance_info = DivergenceInfo(error_msg, error, left, right)
358-
return tree, divergance_info, False
370+
divergence_info = DivergenceInfo(error_msg, error, left, right)
371+
return tree, divergence_info, False
359372

360373
def _build_subtree(self, left, depth, epsilon):
361374
if depth == 0:

pymc/step_methods/hmc/quadpotential.py

+15-1
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import warnings
1618

19+
from typing import overload
20+
1721
import aesara
1822
import numpy as np
1923
import scipy.linalg
@@ -94,7 +98,17 @@ def __str__(self):
9498

9599

96100
class QuadPotential:
97-
def velocity(self, x, out=None):
101+
dtype: np.dtype
102+
103+
@overload
104+
def velocity(self, x: np.ndarray, out: None) -> np.ndarray:
105+
...
106+
107+
@overload
108+
def velocity(self, x: np.ndarray, out: np.ndarray) -> None:
109+
...
110+
111+
def velocity(self, x: np.ndarray, out: np.ndarray | None = None) -> np.ndarray | None:
98112
"""Compute the current velocity at a position in parameter space."""
99113
raise NotImplementedError("Abstract method")
100114

0 commit comments

Comments
 (0)