
Enable mypy analysis for all files. #593


Merged: 25 commits, Nov 7, 2022
Commits
b5db98c
[WIP] remaining mypy annotations
ianyfan Oct 26, 2022
a1aa210
makedir -> mkdir in test_bc
ianyfan Oct 26, 2022
d247f73
Fix remaining issues
ianyfan Oct 26, 2022
708ebe1
Remove TYPE_CHECKING guard
ianyfan Oct 26, 2022
57be740
Add extra checks for horizon type
ianyfan Oct 28, 2022
b40046a
Cleanup various annotations
ianyfan Oct 28, 2022
3e9ac54
fix mistake
ianyfan Oct 28, 2022
f3805f8
Revert type annotation in density
ianyfan Oct 28, 2022
12605f0
Bump seals to version with int horizon and update type checks
AdamGleave Oct 28, 2022
40d6d40
Revert some out-of-scope type changes
AdamGleave Oct 28, 2022
e990d74
Check for infinite horizon value error
AdamGleave Oct 28, 2022
b89b7f0
Update TODO to point to GH issue
AdamGleave Oct 28, 2022
dc76fb8
Check gradients are non-None
AdamGleave Oct 28, 2022
758936a
Bugfix: add p.grad not p
AdamGleave Nov 2, 2022
081a3cf
Reflow comment
AdamGleave Nov 2, 2022
b134fda
Make seals dependency >= to allow new versions
AdamGleave Nov 4, 2022
14ecd9c
Merge remote-tracking branch 'origin/master' into ianyfan/typing
AdamGleave Nov 4, 2022
cf59783
Add overload for type
AdamGleave Nov 4, 2022
a376e93
Comment why we're checking for non-None grads
AdamGleave Nov 4, 2022
6416750
Fix overloads
AdamGleave Nov 4, 2022
42f3aec
Remove duplicate TypeVar definition
AdamGleave Nov 4, 2022
cb5371a
Fix lint
AdamGleave Nov 4, 2022
26cd166
Fix type error introduced in upstream merge
AdamGleave Nov 4, 2022
a4b080f
Pin pyglet to solve glPush bug
AdamGleave Nov 4, 2022
0cd35d7
Merge remote-tracking branch 'origin/master' into ianyfan/typing
Rocamonde Nov 7, 2022
9 changes: 1 addition & 8 deletions .circleci/config.yml
@@ -37,13 +37,6 @@ executors:
# If you change these, also change ci/code_checks.sh
SRC_FILES: src/ tests/ experiments/ examples/ docs/conf.py setup.py ci/
NUM_CPUS: 2
EXCLUDE_MYPY: |
(?x)(
src/imitation/algorithms/preference_comparisons.py$
| src/imitation/rewards/serialize.py$
| src/imitation/algorithms/mce_irl.py$
| tests/algorithms/test_bc.py$
)

commands:
dependencies-linux:
@@ -280,7 +273,7 @@ jobs:

- run:
name: mypy
command: mypy --version && mypy ${SRC_FILES[@]} --exclude "${EXCLUDE_MYPY}" --follow-imports=silent --show-error-codes
command: mypy --version && mypy ${SRC_FILES[@]} --follow-imports=silent --show-error-codes

unit-test-linux:
executor: unit-test-linux
2 changes: 1 addition & 1 deletion setup.py
@@ -203,7 +203,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
"torch>=1.4.0",
"tqdm",
"scikit-learn>=0.21.2",
"seals==0.1.4",
"seals>=0.1.5",
STABLE_BASELINES3,
# TODO(adam) switch to upstream release if they make it
# See https://github.com/IDSIA/sacred/issues/879
16 changes: 13 additions & 3 deletions src/imitation/algorithms/base.py
@@ -1,7 +1,17 @@
"""Module of base classes and helper methods for imitation learning algorithms."""

import abc
from typing import Any, Generic, Iterable, Mapping, Optional, TypeVar, Union, cast
from typing import (
Any,
Generic,
Iterable,
Iterator,
Mapping,
Optional,
TypeVar,
Union,
cast,
)

import numpy as np
import torch as th
@@ -59,7 +69,7 @@ def __init__(
self._horizon = None

@property
def logger(self):
def logger(self) -> imit_logger.HierarchicalLogger:
return self._logger

@logger.setter
@@ -191,7 +201,7 @@ def __init__(
self.data_loader = data_loader
self.expected_batch_size = expected_batch_size

def __iter__(self):
def __iter__(self) -> Iterator[TransitionMapping]:
"""Yields data from `self.data_loader`, checking `self.expected_batch_size`.

Yields:
4 changes: 2 additions & 2 deletions src/imitation/algorithms/bc.py
@@ -43,7 +43,7 @@ class BatchIteratorWithEpochEndCallback:
n_batches: Optional[int]
on_epoch_end: Optional[Callable[[int], None]]

def __post_init__(self):
def __post_init__(self) -> None:
epochs_and_batches_specified = (
self.n_epochs is not None and self.n_batches is not None
)
@@ -56,7 +56,7 @@ def __post_init__(self):
)

def __iter__(self) -> Iterator[algo_base.TransitionMapping]:
def batch_iterator():
def batch_iterator() -> Iterator[algo_base.TransitionMapping]:

# Note: the islice here ensures we do not exceed self.n_epochs
for epoch_num in itertools.islice(itertools.count(), self.n_epochs):
17 changes: 10 additions & 7 deletions src/imitation/algorithms/dagger.py
@@ -10,7 +10,7 @@
import logging
import os
import pathlib
from typing import Callable, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import torch as th
@@ -44,7 +44,7 @@ def __call__(self, round_num: int) -> float:
class LinearBetaSchedule(BetaSchedule):
"""Linearly-decreasing schedule for beta."""

def __init__(self, rampdown_rounds: int):
def __init__(self, rampdown_rounds: int) -> None:
"""Builds LinearBetaSchedule.

Args:
@@ -136,7 +136,7 @@ def __init__(
beta: float,
save_dir: types.AnyPath,
rng: np.random.Generator,
):
) -> None:
"""Builds InteractiveTrajectoryCollector.

Args:
@@ -162,7 +162,7 @@ def __init__(
self._last_user_actions = None
self.rng = rng

def seed(self, seed=Optional[int]) -> List[Union[None, int]]:
def seed(self, seed: Optional[int] = None) -> List[Optional[int]]:
"""Set the seed for the DAgger random number generator and wrapped VecEnv.

The DAgger RNG is used along with `self.beta` to determine whether the expert
@@ -360,7 +360,7 @@ def policy(self) -> policies.BasePolicy:
def batch_size(self) -> int:
return self.bc_trainer.batch_size

def _load_all_demos(self):
def _load_all_demos(self) -> Tuple[types.Transitions, List[int]]:
num_demos_by_round = []
for round_num in range(self._last_loaded_round + 1, self.round_num + 1):
round_dir = self._demo_dir_path_for_round(round_num)
@@ -371,7 +371,7 @@ def _load_all_demos(self):
demo_transitions = rollout.flatten_trajectories(self._all_demos)
return demo_transitions, num_demos_by_round

def _get_demo_paths(self, round_dir):
def _get_demo_paths(self, round_dir: pathlib.Path) -> List[pathlib.Path]:
return [round_dir / p for p in os.listdir(round_dir) if p.endswith(".npz")]

def _demo_dir_path_for_round(self, round_num: Optional[int] = None) -> pathlib.Path:
@@ -411,7 +411,10 @@ def _try_load_demos(self) -> None:
self.bc_trainer.set_demonstrations(data_loader)
self._last_loaded_round = self.round_num

def extend_and_update(self, bc_train_kwargs: Optional[Mapping] = None) -> int:
def extend_and_update(
self,
bc_train_kwargs: Optional[Mapping[str, Any]] = None,
) -> int:
"""Extend internal batch of data and train BC.

Specifically, this method will load new transitions (if necessary), train
6 changes: 3 additions & 3 deletions src/imitation/algorithms/density.py
@@ -7,7 +7,7 @@
import enum
import itertools
from collections.abc import Mapping
from typing import Dict, Iterable, List, Optional, cast
from typing import Any, Dict, Iterable, List, Optional, cast

import numpy as np
from gym.spaces.utils import flatten
@@ -226,7 +226,7 @@ def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None:
None: np.concatenate(list(self.transitions.values()), axis=0),
}

def train(self):
def train(self) -> None:
"""Fits the density model to demonstration data `self.transitions`."""
# if requested, we'll scale demonstration transitions so that they have
# zero mean and unit variance (i.e. all components are equally important)
@@ -343,7 +343,7 @@ def __call__(
rew_array = np.asarray(rew_list, dtype="float32")
return rew_array

def train_policy(self, n_timesteps: int = int(1e6), **kwargs):
def train_policy(self, n_timesteps: int = int(1e6), **kwargs: Any) -> None:
Member:
We could try to be more specific than this, but ParamSpec is not (yet) very well supported by most type checkers, and it might not be worth the effort of typing it.

Member:
I've not used ParamSpec much so may be missing something, but how would it help us here? It seems designed in cases where the P = ParamSpec("P") variable is reused, like decorators or other higher-order functions. But if we're returning None I don't think it makes a difference.

Member:
I meant for Any, not for None haha. The problem with this is that any subtype has to take any kwarg in the signature, but you only really want it to take the kwargs that .learn() can take. But not sure how you can use ParamSpec for this now I think about it. I think it was designed as a generic that is filled in implicitly at the callsite, not for extracting the type signature of an explicit method (which tbh should be a valid use case).
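For readers who have not used ParamSpec, a minimal sketch (not from this PR; all names are hypothetical) of the trade-off discussed above: ParamSpec carries a callable's full parameter list through a higher-order wrapper, so forwarded kwargs are checked, whereas **kwargs: Any lets anything through.

from typing import Any, Callable, TypeVar

from typing_extensions import ParamSpec  # in the stdlib `typing` from Python 3.10

P = ParamSpec("P")
R = TypeVar("R")


def forwards_kwargs_of(func: Callable[P, R]) -> Callable[P, R]:
    # The wrapper inherits `func`'s parameter types, so mypy checks call sites.
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        return func(*args, **kwargs)

    return wrapper


def learn(total_timesteps: int, log_interval: int = 100) -> None:
    """Hypothetical stand-in for an RL algorithm's learn() method."""


checked = forwards_kwargs_of(learn)
checked(total_timesteps=1000)  # OK
# checked(no_such_kwarg=1)     # mypy error: unexpected keyword argument


def train_policy_any(n_timesteps: int = int(1e6), **kwargs: Any) -> None:
    """With `**kwargs: Any`, as above, any keyword argument passes mypy."""

As the thread notes, ParamSpec is bound at the call site of the higher-order function, so it does not directly capture the signature of an explicit method such as .learn(); Any is the pragmatic fallback here.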

"""Train the imitation policy for a given number of timesteps.

Args:
38 changes: 28 additions & 10 deletions src/imitation/algorithms/mce_irl.py
@@ -7,7 +7,7 @@
"""
import collections
import warnings
from typing import Any, Iterable, List, Mapping, Optional, Tuple, Type, Union
from typing import Any, Iterable, List, Mapping, NoReturn, Optional, Tuple, Type, Union

import gym
import numpy as np
@@ -43,9 +43,14 @@ def mce_partition_fh(
(V, Q, \pi) corresponding to the soft values, Q-values and MCE policy.
V is a 2d array, indexed V[t,s]. Q is a 3d array, indexed Q[t,s,a].
\pi is a 3d array, indexed \pi[t,s,a].

Raises:
ValueError: if ``env.horizon`` is None (infinite horizon).
"""
# shorthand
horizon = env.horizon
if horizon is None:
raise ValueError("Only finite-horizon environments are supported.")
n_states = env.state_dim
n_actions = env.action_dim
T = env.transition_matrix
@@ -99,9 +104,14 @@ def mce_occupancy_measures(
``(env.horizon, env.n_states)`` and records the probability of being in a
given state at a given timestep. ``Dcum`` is of shape ``(env.n_states,)``
and records the expected discounted number of times each state is visited.

Raises:
ValueError: if ``env.horizon`` is None (infinite horizon).
"""
# shorthand
horizon = env.horizon
if horizon is None:
raise ValueError("Only finite-horizon environments are supported.")
n_states = env.state_dim
n_actions = env.action_dim
T = env.transition_matrix
@@ -150,7 +160,7 @@ def __init__(
action_space: gym.Space,
pi: np.ndarray,
rng: np.random.Generator,
):
) -> None:
"""Builds TabularPolicy.

Args:
@@ -182,7 +192,7 @@ def forward(
self,
observation: th.Tensor,
deterministic: bool = False,
):
) -> NoReturn:
Member:
Interesting. Is this type not incompatible with the superclass? How does that work?

Member:
I think it works because NoReturn is the bottom type: i.e. a type that is a subtype of all other types. There's a good discussion at python/mypy#4116

Member:
Type theory is quite cool. I remember attending a talk in Cambridge recently on some people working on re-constructing the foundations of mathematics using a flavor of type theory.
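A minimal sketch (not from this PR; Base and StubPolicy are hypothetical) of why mypy accepts the override:

from typing import NoReturn


class Base:
    def forward(self, observation: int) -> int:
        return observation


class StubPolicy(Base):
    # An override may narrow the return type, and NoReturn is the bottom type
    # (a subtype of every type), so NoReturn is accepted where int is expected.
    def forward(self, observation: int) -> NoReturn:
        raise NotImplementedError("Should never be called.")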

raise NotImplementedError("Should never be called.") # pragma: no cover

def predict(
@@ -269,7 +279,7 @@ def __init__(
log_interval: Optional[int] = 100,
*,
custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
):
) -> None:
r"""Creates MCE IRL.

Args:
@@ -297,6 +307,9 @@ def __init__(
log_interval: how often to log current loss stats (using `logging`).
None to disable.
custom_logger: Where to log to; if None (default), creates a new logger.

Raises:
ValueError: if the env horizon is not finite (or an integer).
"""
self.discount = discount
self.env = env
@@ -318,6 +331,8 @@
# Initialize policy to be uniform random. We don't use this for MCE IRL
# training, but it gives us something to return at all times with `policy`
# property, similar to other algorithms.
if self.env.horizon is None:
raise ValueError("Only finite-horizon environments are supported.")
ones = np.ones((self.env.horizon, self.env.state_dim, self.env.action_dim))
uniform_pi = ones / self.env.action_dim
self._policy = TabularPolicy(
@@ -369,6 +384,7 @@ def _set_demo_from_obs(
)

# Normalize occupancy measure estimates
assert self.env.horizon is not None
self.demo_state_om *= (self.env.horizon + 1) / self.demo_state_om.sum()

def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None:
@@ -381,9 +397,9 @@ def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None:
# Demonstrations are either trajectories or transitions;
# we must compute occupancy measure from this.
if isinstance(demonstrations, Iterable):
first_item, demonstrations = util.get_first_iter_element(demonstrations)
first_item, demonstrations_it = util.get_first_iter_element(demonstrations)
if isinstance(first_item, types.Trajectory):
self._set_demo_from_trajectories(demonstrations)
self._set_demo_from_trajectories(demonstrations_it)
return

# Demonstrations are from some kind of transitions-like object. This does
@@ -427,7 +443,7 @@ def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None:
f"Unsupported demonstration type {type(demonstrations)}",
)

def _train_step(self, obs_mat: th.Tensor):
def _train_step(self, obs_mat: th.Tensor) -> Tuple[np.ndarray, np.ndarray]:
self.optimizer.zero_grad()

# get reward predicted for each state by current model, & compute
@@ -487,9 +503,11 @@ def train(self, max_iter: int = 1000) -> np.ndarray:
predicted_r_np, visitations = self._train_step(torch_obs_mat)

# these are just for termination conditions & debug logging
grad_norm = util.tensor_iter_norm(
p.grad for p in self.reward_net.parameters()
).item()
grads = []
for p in self.reward_net.parameters():
assert p.grad is not None # for type checker
grads.append(p.grad)
grad_norm = util.tensor_iter_norm(grads).item()
linf_delta = np.max(np.abs(self.demo_state_om - visitations))

if self.log_interval is not None and 0 == (t % self.log_interval):
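A minimal sketch (not from this PR; checked_grad is a hypothetical helper) of why the `assert p.grad is not None` lines above are needed: in torch's type stubs, Parameter.grad is Optional[Tensor], so mypy requires narrowing it before the gradient is used.

import torch as th


def checked_grad(p: th.nn.Parameter) -> th.Tensor:
    # p.grad is statically Optional[Tensor]; it is None until a backward pass
    # has populated it, so narrow it before returning.
    assert p.grad is not None
    return p.grad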