Skip to content

Commit d713d93

Browse files
ianyfan and dan-pandori committed
Add extra checks for horizon type
Co-authored-by: Daniel Pandori <[email protected]>
1 parent 0e950b8 commit d713d93

File tree

3 files changed

+27
-6
lines changed

3 files changed

+27
-6
lines changed

src/imitation/algorithms/mce_irl.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,14 @@ def mce_partition_fh(
5454
(V, Q, \pi) corresponding to the soft values, Q-values and MCE policy.
5555
V is a 2d array, indexed V[t,s]. Q is a 3d array, indexed Q[t,s,a].
5656
\pi is a 3d array, indexed \pi[t,s,a].
57+
58+
Raises:
59+
ValueError: if the horizon is not finite (or an integer).
5760
"""
5861
# shorthand
59-
horizon = int(env.horizon)
62+
if not isinstance(env.horizon, int):
63+
raise ValueError("Only finite (integer) horizons are supported.")
64+
horizon = env.horizon
6065
n_states = env.state_dim
6166
n_actions = env.action_dim
6267
T = env.transition_matrix
@@ -110,9 +115,14 @@ def mce_occupancy_measures(
110115
``(env.horizon, env.n_states)`` and records the probability of being in a
111116
given state at a given timestep. ``Dcum`` is of shape ``(env.n_states,)``
112117
and records the expected discounted number of times each state is visited.
118+
119+
Raises:
120+
ValueError: if the horizon is not finite (or an integer).
113121
"""
114122
# shorthand
115-
horizon = int(env.horizon)
123+
if not isinstance(env.horizon, int):
124+
raise ValueError("Only finite (integer) horizons are supported.")
125+
horizon = env.horizon
116126
n_states = env.state_dim
117127
n_actions = env.action_dim
118128
T = env.transition_matrix
@@ -308,6 +318,9 @@ def __init__(
308318
log_interval: how often to log current loss stats (using `logging`).
309319
None to disable.
310320
custom_logger: Where to log to; if None (default), creates a new logger.
321+
322+
Raises:
323+
ValueError: if the env horizon is not finite (or an integer).
311324
"""
312325
self.discount = discount
313326
self.env = env
@@ -329,7 +342,9 @@ def __init__(
329342
# Initialize policy to be uniform random. We don't use this for MCE IRL
330343
# training, but it gives us something to return at all times with `policy`
331344
# property, similar to other algorithms.
332-
ones = np.ones((int(self.env.horizon), self.env.state_dim, self.env.action_dim))
345+
if not isinstance(self.env.horizon, int):
346+
raise ValueError("Only finite (integer) horizons are supported.")
347+
ones = np.ones((self.env.horizon, self.env.state_dim, self.env.action_dim))
333348
uniform_pi = ones / self.env.action_dim
334349
self._policy = TabularPolicy(
335350
state_space=self.env.state_space,
@@ -380,6 +395,7 @@ def _set_demo_from_obs(
380395
)
381396

382397
# Normalize occupancy measure estimates
398+
assert isinstance(self.env.horizon, int)
383399
self.demo_state_om *= (self.env.horizon + 1) / self.demo_state_om.sum()
384400

385401
def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None:

src/imitation/algorithms/preference_comparisons.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -641,8 +641,10 @@ def __call__(
641641

642642
# we need two fragments for each comparison
643643
for _ in range(2 * num_pairs):
644-
p = np.array(weights) / sum(weights)
645-
traj = self.rng.choice(trajectories, p=p) # type: ignore[arg-type]
644+
traj = self.rng.choice(
645+
trajectories, # type: ignore[arg-type]
646+
p=np.array(weights) / sum(weights),
647+
)
646648
n = len(traj)
647649
start = self.rng.integers(0, n - fragment_length, endpoint=True)
648650
end = start + fragment_length

tests/algorithms/test_mce_irl.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,4 +417,7 @@ def test_mce_irl_reasonable_mdp(
417417
stats = rollout.rollout_stats(trajs)
418418
if discount > 0.0: # skip check when discount==0.0 (random policy)
419419
eps = 1e-6 # avoid test failing due to rounding error
420-
assert stats["return_mean"] >= (mdp.horizon - 1) * 2 * 0.8 - eps
420+
assert (
421+
isinstance(mdp.horizon, int)
422+
and stats["return_mean"] >= (mdp.horizon - 1) * 2 * 0.8 - eps
423+
)

0 commit comments

Comments (0)