@@ -54,9 +54,14 @@ def mce_partition_fh(
54
54
(V, Q, \pi) corresponding to the soft values, Q-values and MCE policy.
55
55
V is a 2d array, indexed V[t,s]. Q is a 3d array, indexed Q[t,s,a].
56
56
\pi is a 3d array, indexed \pi[t,s,a].
57
+
58
+ Raises:
59
+ ValueError: if ``env.horizon`` is not an ``int`` (only finite integer horizons are supported).
57
60
"""
58
61
# shorthand
59
- horizon = int (env .horizon )
62
+ if not isinstance (env .horizon , int ):
63
+ raise ValueError ("Only finite (integer) horizons are supported." )
64
+ horizon = env .horizon
60
65
n_states = env .state_dim
61
66
n_actions = env .action_dim
62
67
T = env .transition_matrix
@@ -110,9 +115,14 @@ def mce_occupancy_measures(
110
115
``(env.horizon, env.n_states)`` and records the probability of being in a
111
116
given state at a given timestep. ``Dcum`` is of shape ``(env.n_states,)``
112
117
and records the expected discounted number of times each state is visited.
118
+
119
+ Raises:
120
+ ValueError: if ``env.horizon`` is not an ``int`` (only finite integer horizons are supported).
113
121
"""
114
122
# shorthand
115
- horizon = int (env .horizon )
123
+ if not isinstance (env .horizon , int ):
124
+ raise ValueError ("Only finite (integer) horizons are supported." )
125
+ horizon = env .horizon
116
126
n_states = env .state_dim
117
127
n_actions = env .action_dim
118
128
T = env .transition_matrix
@@ -308,6 +318,9 @@ def __init__(
308
318
log_interval: how often to log current loss stats (using `logging`).
309
319
None to disable.
310
320
custom_logger: Where to log to; if None (default), creates a new logger.
321
+
322
+ Raises:
323
+ ValueError: if ``env.horizon`` is not an ``int`` (only finite integer horizons are supported).
311
324
"""
312
325
self .discount = discount
313
326
self .env = env
@@ -329,7 +342,9 @@ def __init__(
329
342
# Initialize policy to be uniform random. We don't use this for MCE IRL
330
343
# training, but it gives us something to return at all times with `policy`
331
344
# property, similar to other algorithms.
332
- ones = np .ones ((int (self .env .horizon ), self .env .state_dim , self .env .action_dim ))
345
+ if not isinstance (self .env .horizon , int ):
346
+ raise ValueError ("Only finite (integer) horizons are supported." )
347
+ ones = np .ones ((self .env .horizon , self .env .state_dim , self .env .action_dim ))
333
348
uniform_pi = ones / self .env .action_dim
334
349
self ._policy = TabularPolicy (
335
350
state_space = self .env .state_space ,
@@ -380,6 +395,7 @@ def _set_demo_from_obs(
380
395
)
381
396
382
397
# Normalize occupancy measure estimates
398
+ assert isinstance (self .env .horizon , int )
383
399
self .demo_state_om *= (self .env .horizon + 1 ) / self .demo_state_om .sum ()
384
400
385
401
def set_demonstrations (self , demonstrations : MCEDemonstrations ) -> None :
0 commit comments