
Commit 78a218b

Author: Ervin T
[change] Separate action outputs into OutputDistributions object (#3514)
1 parent: 751ab0b

File tree: 8 files changed, +489 −147 lines changed
Lines changed: 310 additions & 0 deletions
@@ -0,0 +1,310 @@
import abc
from typing import NamedTuple, List, Tuple
import numpy as np

from mlagents.tf_utils import tf
from mlagents.trainers.models import ModelUtils

EPSILON = 1e-6  # Small value to avoid divide by zero


class OutputDistribution(abc.ABC):
    @abc.abstractproperty
    def log_probs(self) -> tf.Tensor:
        """
        Returns a Tensor that when evaluated, produces the per-action log probabilities of this distribution.
        The shape of this Tensor should be equivalent to (batch_size x the number of actions) produced in sample.
        """
        pass

    @abc.abstractproperty
    def total_log_probs(self) -> tf.Tensor:
        """
        Returns a Tensor that when evaluated, produces the total log probability for a single sample.
        The shape of this Tensor should be equivalent to (batch_size x 1) produced in sample.
        """
        pass

    @abc.abstractproperty
    def sample(self) -> tf.Tensor:
        """
        Returns a Tensor that when evaluated, produces a sample of this OutputDistribution.
        """
        pass

    @abc.abstractproperty
    def entropy(self) -> tf.Tensor:
        """
        Returns a Tensor that when evaluated, produces the entropy of this distribution.
        """
        pass


class DiscreteOutputDistribution(OutputDistribution):
    @abc.abstractproperty
    def sample_onehot(self) -> tf.Tensor:
        """
        Returns a one-hot version of the output.
        """
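

# The two concrete distributions below build their TF graph nodes at
# construction time: GaussianDistribution emits an (optionally tanh-squashed)
# continuous sample with per-dimension log probabilities, while
# MultiCategoricalDistribution emits masked multi-branch discrete samples
# along with a one-hot encoding of the chosen actions.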


class GaussianDistribution(OutputDistribution):
    """
    A Gaussian output distribution for continuous actions.
    """

    class MuSigmaTensors(NamedTuple):
        mu: tf.Tensor
        log_sigma: tf.Tensor
        sigma: tf.Tensor

    def __init__(
        self,
        logits: tf.Tensor,
        act_size: List[int],
        reparameterize: bool = False,
        tanh_squash: bool = False,
        log_sigma_min: float = -20,
        log_sigma_max: float = 2,
    ):
        """
        A Gaussian output distribution for continuous actions.
        :param logits: Hidden layer to use as the input to the Gaussian distribution.
        :param act_size: List containing the number of continuous actions.
        :param reparameterize: Whether or not to use the reparameterization trick (block gradients through
        log probability calculation.)
        :param tanh_squash: Squash the output using tanh, constraining it between -1 and 1.
        From: Haarnoja et. al, https://arxiv.org/abs/1801.01290
        :param log_sigma_min: Minimum log standard deviation to clip by.
        :param log_sigma_max: Maximum log standard deviation to clip by.
        """
        encoded = self._create_mu_log_sigma(
            logits, act_size, log_sigma_min, log_sigma_max
        )
        self._sampled_policy = self._create_sampled_policy(encoded)
        if not reparameterize:
            _sampled_policy_probs = tf.stop_gradient(self._sampled_policy)
        else:
            _sampled_policy_probs = self._sampled_policy
        self._all_probs = self._create_log_probs(_sampled_policy_probs, encoded)
        if tanh_squash:
            self._sampled_policy = tf.tanh(self._sampled_policy)
            self._all_probs = self._do_squash_correction_for_tanh(
                self._all_probs, self._sampled_policy
            )
        self._total_prob = tf.reduce_sum(self._all_probs, axis=1, keepdims=True)
        self._entropy = self._create_entropy(encoded)

    def _create_mu_log_sigma(
        self,
        logits: tf.Tensor,
        act_size: List[int],
        log_sigma_min: float,
        log_sigma_max: float,
    ) -> "GaussianDistribution.MuSigmaTensors":

        mu = tf.layers.dense(
            logits,
            act_size[0],
            activation=None,
            name="mu",
            kernel_initializer=ModelUtils.scaled_init(0.01),
            reuse=tf.AUTO_REUSE,
        )

        # Policy-dependent log_sigma_sq
        log_sigma = tf.layers.dense(
            logits,
            act_size[0],
            activation=None,
            name="log_std",
            kernel_initializer=ModelUtils.scaled_init(0.01),
        )
        log_sigma = tf.clip_by_value(log_sigma, log_sigma_min, log_sigma_max)
        sigma = tf.exp(log_sigma)
        return self.MuSigmaTensors(mu, log_sigma, sigma)

    def _create_sampled_policy(
        self, encoded: "GaussianDistribution.MuSigmaTensors"
    ) -> tf.Tensor:
        epsilon = tf.random_normal(tf.shape(encoded.mu))
        sampled_policy = encoded.mu + encoded.sigma * epsilon

        return sampled_policy

    def _create_log_probs(
        self, sampled_policy: tf.Tensor, encoded: "GaussianDistribution.MuSigmaTensors"
    ) -> tf.Tensor:
        _gauss_pre = -0.5 * (
            ((sampled_policy - encoded.mu) / (encoded.sigma + EPSILON)) ** 2
            + 2 * encoded.log_sigma
            + np.log(2 * np.pi)
        )
        return _gauss_pre

    def _create_entropy(
        self, encoded: "GaussianDistribution.MuSigmaTensors"
    ) -> tf.Tensor:
        # Closed-form entropy of a diagonal Gaussian, per dimension:
        # 0.5 * (log(2 * pi * e) + 2 * log_sigma), averaged across dimensions.
        single_dim_entropy = 0.5 * tf.reduce_mean(
            tf.log(2 * np.pi * np.e) + 2 * encoded.log_sigma
        )
        # Make entropy the right shape
        return tf.ones_like(tf.reshape(encoded.mu[:, 0], [-1])) * single_dim_entropy

    def _do_squash_correction_for_tanh(self, probs, squashed_policy):
        """
        Adjust probabilities for squashed sample before output
        """
        probs -= tf.log(1 - squashed_policy ** 2 + EPSILON)
        return probs

    @property
    def total_log_probs(self) -> tf.Tensor:
        return self._total_prob

    @property
    def log_probs(self) -> tf.Tensor:
        return self._all_probs

    @property
    def sample(self) -> tf.Tensor:
        return self._sampled_policy

    @property
    def entropy(self) -> tf.Tensor:
        return self._entropy
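

# Notes on the Gaussian math above: _create_log_probs evaluates the diagonal
# Gaussian log-density per action dimension,
#   -0.5 * (((a - mu) / sigma) ** 2 + 2 * log_sigma + log(2 * pi)),
# and, when tanh_squash is enabled, _do_squash_correction_for_tanh applies the
# change-of-variables term from Haarnoja et al. (https://arxiv.org/abs/1801.01290),
# subtracting log(1 - tanh(a) ** 2 + EPSILON) from each per-dimension log
# probability before they are summed into total_log_probs.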


class MultiCategoricalDistribution(DiscreteOutputDistribution):
    """
    A categorical distribution for multi-branched discrete actions. Also supports action masking.
    """

    def __init__(self, logits: tf.Tensor, act_size: List[int], action_masks: tf.Tensor):
        """
        A categorical distribution for multi-branched discrete actions.
        :param logits: Hidden layer to use as the input to the categorical distribution.
        :param act_size: List containing the number of discrete actions per branch.
        :param action_masks: Tensor representing action masks. Should be of length sum(act_size), and 0 for masked
        and 1 for unmasked.
        """
        unmasked_log_probs = self._create_policy_branches(logits, act_size)
        self._sampled_policy, self._all_probs, action_index = self._get_masked_actions_probs(
            unmasked_log_probs, act_size, action_masks
        )
        self._sampled_onehot = self._action_onehot(self._sampled_policy, act_size)
        self._entropy = self._create_entropy(
            self._sampled_onehot, self._all_probs, action_index, act_size
        )
        self._total_prob = self._get_log_probs(
            self._sampled_onehot, self._all_probs, action_index, act_size
        )

    def _create_policy_branches(
        self, logits: tf.Tensor, act_size: List[int]
    ) -> List[tf.Tensor]:
        policy_branches = []
        for size in act_size:
            policy_branches.append(
                tf.layers.dense(
                    logits,
                    size,
                    activation=None,
                    use_bias=False,
                    kernel_initializer=ModelUtils.scaled_init(0.01),
                )
            )
        unmasked_log_probs = tf.concat(policy_branches, axis=1)
        return unmasked_log_probs

    def _get_masked_actions_probs(
        self,
        unmasked_log_probs: tf.Tensor,
        act_size: List[int],
        action_masks: tf.Tensor,
    ) -> Tuple[tf.Tensor, tf.Tensor, np.ndarray]:
        output, _, all_log_probs = ModelUtils.create_discrete_action_masking_layer(
            unmasked_log_probs, action_masks, act_size
        )

        action_idx = [0] + list(np.cumsum(act_size))
        return output, all_log_probs, action_idx

    def _action_onehot(self, sample: tf.Tensor, act_size: List[int]) -> tf.Tensor:
        action_oh = tf.concat(
            [tf.one_hot(sample[:, i], act_size[i]) for i in range(len(act_size))],
            axis=1,
        )
        return action_oh

    def _get_log_probs(
        self,
        sample_onehot: tf.Tensor,
        all_log_probs: tf.Tensor,
        action_idx: List[int],
        act_size: List[int],
    ) -> tf.Tensor:
        log_probs = tf.reduce_sum(
            (
                tf.stack(
                    [
                        -tf.nn.softmax_cross_entropy_with_logits_v2(
                            labels=sample_onehot[:, action_idx[i] : action_idx[i + 1]],
                            logits=all_log_probs[:, action_idx[i] : action_idx[i + 1]],
                        )
                        for i in range(len(act_size))
                    ],
                    axis=1,
                )
            ),
            axis=1,
            keepdims=True,
        )
        return log_probs
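
    # _get_log_probs (above) and _create_entropy (below) both lean on
    # tf.nn.softmax_cross_entropy_with_logits_v2 applied to each branch's slice
    # of the concatenated log-probability tensor: with the one-hot sample as
    # labels, the negated cross entropy recovers the log probability of the
    # sampled action for that branch, and with softmax(logits) as labels, the
    # cross entropy equals the entropy of that branch's categorical
    # distribution. Per-branch terms are stacked and summed across branches.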

    def _create_entropy(
        self,
        sample_onehot: tf.Tensor,
        all_log_probs: tf.Tensor,
        action_idx: List[int],
        act_size: List[int],
    ) -> tf.Tensor:
        entropy = tf.reduce_sum(
            (
                tf.stack(
                    [
                        tf.nn.softmax_cross_entropy_with_logits_v2(
                            labels=tf.nn.softmax(
                                all_log_probs[:, action_idx[i] : action_idx[i + 1]]
                            ),
                            logits=all_log_probs[:, action_idx[i] : action_idx[i + 1]],
                        )
                        for i in range(len(act_size))
                    ],
                    axis=1,
                )
            ),
            axis=1,
        )

        return entropy

    @property
    def log_probs(self) -> tf.Tensor:
        return self._all_probs

    @property
    def total_log_probs(self) -> tf.Tensor:
        return self._total_prob

    @property
    def sample(self) -> tf.Tensor:
        return self._sampled_policy

    @property
    def sample_onehot(self) -> tf.Tensor:
        return self._sampled_onehot

    @property
    def entropy(self) -> tf.Tensor:
        return self._entropy
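
A minimal usage sketch for the two distributions above, assuming a TF1-style graph and session via mlagents.tf_utils, that GaussianDistribution and MultiCategoricalDistribution are in scope, and an illustrative 128-unit hidden layer as the "logits" input; the placeholder names, sizes, and feed values below are hypothetical.

import numpy as np
from mlagents.tf_utils import tf

# Hidden output of a policy network body; 128 units is an arbitrary choice.
hidden = tf.placeholder(shape=[None, 128], dtype=tf.float32, name="hidden")

# Continuous control: 3 action dimensions, reparameterized and tanh-squashed.
gaussian = GaussianDistribution(
    hidden, act_size=[3], reparameterize=True, tanh_squash=True
)

# Discrete control: two branches of 3 and 2 actions; an all-ones mask of
# length sum(act_size) == 5 leaves every action available.
masks = tf.placeholder(shape=[None, 5], dtype=tf.float32, name="action_masks")
categorical = MultiCategoricalDistribution(hidden, act_size=[3, 2], action_masks=masks)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {
        hidden: np.zeros((4, 128), dtype=np.float32),
        masks: np.ones((4, 5), dtype=np.float32),
    }
    # Squashed continuous sample, its summed log probability, and entropy.
    actions, log_p, ent = sess.run(
        [gaussian.sample, gaussian.total_log_probs, gaussian.entropy], feed
    )
    # Per-branch discrete action indices and their one-hot encoding.
    branch_ids, onehot = sess.run(
        [categorical.sample, categorical.sample_onehot], feed
    )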
