Enable mypy analysis for all files. #593
Changes from all commits
First changed file:
@@ -7,7 +7,7 @@
 import enum
 import itertools
 from collections.abc import Mapping
-from typing import Dict, Iterable, List, Optional, cast
+from typing import Any, Dict, Iterable, List, Optional, cast

 import numpy as np
 from gym.spaces.utils import flatten
@@ -226,7 +226,7 @@ def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None:
                None: np.concatenate(list(self.transitions.values()), axis=0),
            }

-    def train(self):
+    def train(self) -> None:
        """Fits the density model to demonstration data `self.transitions`."""
        # if requested, we'll scale demonstration transitions so that they have
        # zero mean and unit variance (i.e. all components are equally important)
@@ -343,7 +343,7 @@ def __call__(
        rew_array = np.asarray(rew_list, dtype="float32")
        return rew_array

-    def train_policy(self, n_timesteps: int = int(1e6), **kwargs):
+    def train_policy(self, n_timesteps: int = int(1e6), **kwargs: Any) -> None:
Review comment: We could try to be more specific than this, but ParamSpec is not (yet) a very well supported feature by most type checkers, and it might not be worth the effort of typing it.

Review comment: I've not used

Review comment: I meant for Any, not for None haha. The problem with this is that any subtype has to take any kwarg in the signature, but you only really want it to take the kwargs that
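As a hedged illustration of the ParamSpec idea mentioned above (not code from this PR; the helper name is made up), a wrapper can forward another callable's full signature so mypy checks keyword arguments against the wrapped function instead of accepting `Any`:

```python
from typing import Callable, TypeVar

from typing_extensions import ParamSpec  # typing.ParamSpec on Python >= 3.10

P = ParamSpec("P")
T = TypeVar("T")


def forwards_signature(inner: Callable[P, T]) -> Callable[P, T]:
    """Wraps ``inner`` while preserving its parameter types for mypy."""

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        # Unknown or mistyped kwargs are rejected at type-check time,
        # instead of being swallowed by ``**kwargs: Any``.
        return inner(*args, **kwargs)

    return wrapper
```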
        """Train the imitation policy for a given number of timesteps.

        Args:
Second changed file:
@@ -7,7 +7,7 @@
 """
 import collections
 import warnings
-from typing import Any, Iterable, List, Mapping, Optional, Tuple, Type, Union
+from typing import Any, Iterable, List, Mapping, NoReturn, Optional, Tuple, Type, Union

 import gym
 import numpy as np
@@ -43,9 +43,14 @@ def mce_partition_fh(
        (V, Q, \pi) corresponding to the soft values, Q-values and MCE policy.
        V is a 2d array, indexed V[t,s]. Q is a 3d array, indexed Q[t,s,a].
        \pi is a 3d array, indexed \pi[t,s,a].
+
+    Raises:
+        ValueError: if ``env.horizon`` is None (infinite horizon).
    """
    # shorthand
    horizon = env.horizon
+    if horizon is None:
+        raise ValueError("Only finite-horizon environments are supported.")
    n_states = env.state_dim
    n_actions = env.action_dim
    T = env.transition_matrix
@@ -99,9 +104,14 @@ def mce_occupancy_measures(
        ``(env.horizon, env.n_states)`` and records the probability of being in a
        given state at a given timestep. ``Dcum`` is of shape ``(env.n_states,)``
        and records the expected discounted number of times each state is visited.
+
+    Raises:
+        ValueError: if ``env.horizon`` is None (infinite horizon).
    """
    # shorthand
    horizon = env.horizon
+    if horizon is None:
+        raise ValueError("Only finite-horizon environments are supported.")
    n_states = env.state_dim
    n_actions = env.action_dim
    T = env.transition_matrix
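These guards are not only runtime validation: assuming `env.horizon` is annotated as `Optional[int]`, raising inside the `if horizon is None` branch also narrows the variable to `int`, so mypy accepts the arithmetic that follows. A minimal standalone sketch (the function name is made up, not imitation's API):

```python
from typing import Optional

import numpy as np


def allocate_values(horizon: Optional[int], n_states: int) -> np.ndarray:
    if horizon is None:
        raise ValueError("Only finite-horizon environments are supported.")
    # After the guard mypy treats ``horizon`` as ``int``, so using it as an
    # array dimension type-checks.
    return np.zeros((horizon, n_states))
```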
@@ -150,7 +160,7 @@ def __init__(
        action_space: gym.Space,
        pi: np.ndarray,
        rng: np.random.Generator,
-    ):
+    ) -> None:
        """Builds TabularPolicy.

        Args:
@@ -182,7 +192,7 @@ def forward(
        self,
        observation: th.Tensor,
        deterministic: bool = False,
-    ):
+    ) -> NoReturn:
Review comment: Interesting. Is this type not incompatible with the superclass? How does that work?

Review comment: I think it works because

Review comment: Type theory is quite cool. I remember attending a talk in Cambridge recently on some people working on re-constructing the foundations of mathematics using a flavor of type theory.
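For context, a hedged sketch of why this override type-checks (the class names below are invented for illustration): return types are covariant in overrides, and `NoReturn` is the bottom type, i.e. a subtype of every type, so mypy accepts it regardless of what the base class's `forward` is declared to return.

```python
from typing import NoReturn

import torch as th


class BasePolicy:
    def forward(self, observation: th.Tensor) -> th.Tensor:
        return observation


class NonDifferentiablePolicy(BasePolicy):
    def forward(self, observation: th.Tensor) -> NoReturn:
        # NoReturn is a subtype of th.Tensor (and of every other type), so
        # this override is compatible with the base signature even though the
        # method never returns.
        raise NotImplementedError("Should never be called.")
```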
        raise NotImplementedError("Should never be called.")  # pragma: no cover

    def predict(
@@ -269,7 +279,7 @@ def __init__(
        log_interval: Optional[int] = 100,
        *,
        custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
-    ):
+    ) -> None:
        r"""Creates MCE IRL.

        Args:
@@ -297,6 +307,9 @@ def __init__(
            log_interval: how often to log current loss stats (using `logging`).
                None to disable.
            custom_logger: Where to log to; if None (default), creates a new logger.
+
+        Raises:
+            ValueError: if the env horizon is not finite (or an integer).
        """
        self.discount = discount
        self.env = env
@@ -318,6 +331,8 @@ def __init__(
        # Initialize policy to be uniform random. We don't use this for MCE IRL
        # training, but it gives us something to return at all times with `policy`
        # property, similar to other algorithms.
+        if self.env.horizon is None:
+            raise ValueError("Only finite-horizon environments are supported.")
        ones = np.ones((self.env.horizon, self.env.state_dim, self.env.action_dim))
        uniform_pi = ones / self.env.action_dim
        self._policy = TabularPolicy(
@@ -369,6 +384,7 @@ def _set_demo_from_obs(
        )

        # Normalize occupancy measure estimates
+        assert self.env.horizon is not None
        self.demo_state_om *= (self.env.horizon + 1) / self.demo_state_om.sum()

    def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None:
@@ -381,9 +397,9 @@ def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None:
        # Demonstrations are either trajectories or transitions;
        # we must compute occupancy measure from this.
        if isinstance(demonstrations, Iterable):
-            first_item, demonstrations = util.get_first_iter_element(demonstrations)
+            first_item, demonstrations_it = util.get_first_iter_element(demonstrations)

            if isinstance(first_item, types.Trajectory):
-                self._set_demo_from_trajectories(demonstrations)
+                self._set_demo_from_trajectories(demonstrations_it)
                return

        # Demonstrations are from some kind of transitions-like object. This does

(Rocamonde marked this conversation as resolved.)
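A hedged guess at the motivation for the new name, using simplified stand-in types rather than imitation's: under mypy, re-binding a parameter to a value of a different type is an "incompatible types in assignment" error, so the iterator returned by the helper gets its own variable.

```python
from typing import Iterator, Sequence, Tuple


def first_and_rest(items: Sequence[int]) -> Tuple[int, Iterator[int]]:
    it = iter(items)
    return next(it), it


def consume(demonstrations: Sequence[int]) -> None:
    # Writing ``first_item, demonstrations = first_and_rest(demonstrations)``
    # would re-bind ``demonstrations`` (declared Sequence[int]) to an
    # Iterator[int], which mypy rejects; a fresh name sidesteps that.
    first_item, demonstrations_it = first_and_rest(demonstrations)
    print(first_item, list(demonstrations_it))
```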
@@ -427,7 +443,7 @@ def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None:
                f"Unsupported demonstration type {type(demonstrations)}",
            )

-    def _train_step(self, obs_mat: th.Tensor):
+    def _train_step(self, obs_mat: th.Tensor) -> Tuple[np.ndarray, np.ndarray]:
        self.optimizer.zero_grad()

        # get reward predicted for each state by current model, & compute
@@ -487,9 +503,11 @@ def train(self, max_iter: int = 1000) -> np.ndarray:
            predicted_r_np, visitations = self._train_step(torch_obs_mat)

            # these are just for termination conditions & debug logging
-            grad_norm = util.tensor_iter_norm(
-                p.grad for p in self.reward_net.parameters()
-            ).item()
+            grads = []
+            for p in self.reward_net.parameters():
+                assert p.grad is not None  # for type checker
+                grads.append(p.grad)
+            grad_norm = util.tensor_iter_norm(grads).item()
            linf_delta = np.max(np.abs(self.demo_state_om - visitations))

            if self.log_interval is not None and 0 == (t % self.log_interval):

(Rocamonde marked this conversation as resolved.)
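The assert here serves the same narrowing purpose as the horizon guards: `Parameter.grad` is typed `Optional[Tensor]`, so it has to be narrowed before the gradients can be collected as plain tensors. A hedged standalone sketch (`collect_grads` is a made-up helper, not part of imitation):

```python
from typing import List

import torch as th


def collect_grads(module: th.nn.Module) -> List[th.Tensor]:
    grads: List[th.Tensor] = []
    for p in module.parameters():
        # ``p.grad`` is Optional[th.Tensor]; the assert narrows it to
        # th.Tensor so appending to a List[th.Tensor] type-checks.
        assert p.grad is not None
        grads.append(p.grad)
    return grads
```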