
Commit 73fa8bd

Random Network Distillation for Torch (#4473)
* initial commit
* works with Pyramids
* added unit tests and a separate config file
* Adding first batch of documentation
* adding in the docs that rnd is only for PyTorch
* adding newline at the end of the config files
* adding some docs
* Code comments
* no normalization of the reward
* Fixing the tests
* [skip ci]
* [skip ci] Make sure RND will only work for Torch by editing the config file
* [skip ci] Additional information in the Documentation
* Remove the _has_updated_once flag
1 parent 437d04e commit 73fa8bd

File tree

10 files changed: +236 -1 lines changed


com.unity.ml-agents/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -10,6 +10,9 @@ and this project adheres to
 ### Major Changes
 #### com.unity.ml-agents (C#)
 #### ml-agents / ml-agents-envs / gym-unity (Python)
+- Added the Random Network Distillation (RND) intrinsic reward signal to the PyTorch
+  trainers. To use RND, add a `rnd` section to the `reward_signals` section of your
+  yaml configuration file. [More information here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-Configuration-File.md#rnd-intrinsic-reward)

 ### Minor Changes
 #### com.unity.ml-agents (C#)
```
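A complete configuration that pairs the new `rnd` signal with the extrinsic reward ships in this same commit as config/ppo/PyramidsRND.yaml, shown next.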

config/ppo/PyramidsRND.yaml

Lines changed: 32 additions & 0 deletions
```yaml
behaviors:
  Pyramids:
    trainer_type: ppo
    hyperparameters:
      batch_size: 128
      buffer_size: 2048
      learning_rate: 0.0003
      beta: 0.01
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: false
      hidden_units: 512
      num_layers: 2
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
      rnd:
        gamma: 0.99
        strength: 0.01
        encoding_size: 64
        learning_rate: 0.0001
    keep_checkpoints: 5
    max_steps: 3000000
    time_horizon: 128
    summary_freq: 30000
    framework: pytorch
    threaded: true
```
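To train the Pyramids environment with this configuration, the usual CLI entry point should apply, e.g. `mlagents-learn config/ppo/PyramidsRND.yaml --run-id=PyramidsRND` (the run id is arbitrary).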

docs/ML-Agents-Overview.md

Lines changed: 23 additions & 1 deletion
```diff
@@ -13,6 +13,7 @@
 - [A Quick Note on Reward Signals](#a-quick-note-on-reward-signals)
 - [Deep Reinforcement Learning](#deep-reinforcement-learning)
 - [Curiosity for Sparse-reward Environments](#curiosity-for-sparse-reward-environments)
+- [RND for Sparse-reward Environments](#rnd-for-sparse-reward-environments)
 - [Imitation Learning](#imitation-learning)
 - [GAIL (Generative Adversarial Imitation Learning)](#gail-generative-adversarial-imitation-learning)
 - [Behavioral Cloning (BC)](#behavioral-cloning-bc)
@@ -359,7 +360,7 @@ The total reward that the agent will learn to maximize can be a mix of extrinsic
 and intrinsic reward signals.

 The ML-Agents Toolkit allows reward signals to be defined in a modular way, and
-we provide three reward signals that can be mixed and matched to help shape
+we provide four reward signals that can be mixed and matched to help shape
 your agent's behavior:

 - `extrinsic`: represents the rewards defined in your environment, and is
@@ -369,6 +370,9 @@ your agent's behavior:
 - `curiosity`: represents an intrinsic reward signal that encourages exploration
   in sparse-reward environments that is defined by the Curiosity module (see
   below).
+- `rnd`: represents an intrinsic reward signal that encourages exploration
+  in sparse-reward environments that is defined by the RND module (see
+  below). (Not available for TensorFlow trainers.)

 ### Deep Reinforcement Learning

@@ -417,6 +421,24 @@ model is, the larger the reward will be.
 For more information, see our dedicated
 [blog post on the Curiosity module](https://blogs.unity3d.com/2018/06/26/solving-sparse-reward-tasks-with-curiosity/).

+#### RND for Sparse-reward Environments
+
+Similarly to Curiosity, Random Network Distillation (RND) is useful in sparse or rare
+reward environments as it helps the Agent explore. The RND Module is implemented following
+the paper [Exploration by Random Network Distillation](https://arxiv.org/abs/1810.12894).
+RND uses two networks:
+
+- The first is a network with fixed random weights that takes observations as inputs and
+  generates an encoding.
+- The second is a network with a similar architecture that is trained to predict the
+  outputs of the first network, using the observations the Agent collects as training data.
+
+The loss (the squared difference between the predicted and actual encoded observations)
+of the trained model is used as the intrinsic reward. The more an Agent visits a state, the
+more accurate the predictions and the lower the rewards, which encourages the Agent to
+explore new states with higher prediction errors.
+
+__Note:__ RND is not available for TensorFlow trainers (only PyTorch trainers).

 ### Imitation Learning

 It is often more intuitive to simply demonstrate the behavior we want an agent
```
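As a complement to the prose above, here is a minimal sketch of the RND mechanism in plain PyTorch. It is an illustration only, not the trainer's RNDRewardProvider (which appears later in this commit and encodes observations with the ML-Agents NetworkBody); the class, layer sizes, and variable names are made up for the example.

```python
import torch
import torch.nn as nn


class RNDSketch(nn.Module):
    """Minimal illustration of Random Network Distillation:
    a frozen random target network and a trained predictor network."""

    def __init__(self, obs_size: int, encoding_size: int = 64):
        super().__init__()

        def make_encoder() -> nn.Module:
            return nn.Sequential(
                nn.Linear(obs_size, 128), nn.ReLU(),
                nn.Linear(128, encoding_size),
            )

        self.target = make_encoder()     # fixed random weights
        self.predictor = make_encoder()  # trained to imitate the target
        for p in self.target.parameters():
            p.requires_grad = False

    def intrinsic_reward(self, obs: torch.Tensor) -> torch.Tensor:
        # Squared prediction error per observation = exploration bonus.
        with torch.no_grad():
            target = self.target(obs)
        prediction = self.predictor(obs)
        return ((prediction - target) ** 2).sum(dim=1)


# Usage: the bonus shrinks for observations the predictor has already fit.
rnd = RNDSketch(obs_size=10)
optimizer = torch.optim.Adam(rnd.predictor.parameters(), lr=1e-4)
obs = torch.randn(32, 10)  # a batch of fake observations
for _ in range(100):
    loss = rnd.intrinsic_reward(obs).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# After these updates, rnd.intrinsic_reward(obs) is small for this batch,
# while unseen observations still produce a large bonus.
```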

docs/Training-Configuration-File.md

Lines changed: 13 additions & 0 deletions
```diff
@@ -10,6 +10,7 @@
 - [Extrinsic Rewards](#extrinsic-rewards)
 - [Curiosity Intrinsic Reward](#curiosity-intrinsic-reward)
 - [GAIL Intrinsic Reward](#gail-intrinsic-reward)
+- [RND Intrinsic Reward](#rnd-intrinsic-reward)
 - [Reward Signal Settings for SAC](#reward-signal-settings-for-sac)
 - [Behavioral Cloning](#behavioral-cloning)
 - [Memory-enhanced Agents using Recurrent Neural Networks](#memory-enhanced-agents-using-recurrent-neural-networks)
@@ -118,6 +119,18 @@ settings:
 | `gail -> use_actions` | (default = `false`) Determines whether the discriminator should discriminate based on both observations and actions, or just observations. Set to True if you want the agent to mimic the actions from the demonstrations, and False if you'd rather have the agent visit the same states as in the demonstrations but with possibly different actions. Setting to False is more likely to be stable, especially with imperfect demonstrations, but may learn slower. |
 | `gail -> use_vail` | (default = `false`) Enables a variational bottleneck within the GAIL discriminator. This forces the discriminator to learn a more general representation and reduces its tendency to be "too good" at discriminating, making learning more stable. However, it does increase training time. Enable this if you notice your imitation learning is unstable, or unable to learn the task at hand. |

+### RND Intrinsic Reward
+
+Random Network Distillation (RND) is only available for the PyTorch trainers.
+To enable RND, provide these settings:
+
+| **Setting** | **Description** |
+| :---------- | :-------------- |
+| `rnd -> strength` | (default = `1.0`) Magnitude of the reward generated by the intrinsic RND module. This should be scaled in order to ensure it is large enough to not be overwhelmed by extrinsic reward signals in the environment. Likewise it should not be too large to overwhelm the extrinsic reward signal. <br><br>Typical range: `0.001` - `0.01` |
+| `rnd -> gamma` | (default = `0.99`) Discount factor for future rewards. <br><br>Typical range: `0.8` - `0.995` |
+| `rnd -> encoding_size` | (default = `64`) Size of the encoding used by the intrinsic RND model. <br><br>Typical range: `64` - `256` |
+| `rnd -> learning_rate` | (default = `1e-4`) Learning rate used to update the RND module. This should be large enough for the RND module to quickly learn the state representation, but small enough to allow for stable learning. <br><br>Typical range: `1e-5` - `1e-3` |

 ## Behavioral Cloning
```

ml-agents/mlagents/trainers/model_saver/tf_model_saver.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -74,6 +74,9 @@ def export(self, output_filepath: str, behavior_name: str) -> None:
         # only on worker-0 if there are multiple workers
         if self.policy and self.policy.rank is not None and self.policy.rank != 0:
             return
+        if self.graph is None:
+            logger.info("No model to export")
+            return
         export_policy_model(
             self.model_path, output_filepath, behavior_name, self.graph, self.sess
         )
```
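A small TensorFlow-side change rides along here: export() now logs and returns early when no TF graph was ever created, rather than handing None to export_policy_model().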

ml-agents/mlagents/trainers/settings.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -167,12 +167,14 @@ class RewardSignalType(Enum):
     EXTRINSIC: str = "extrinsic"
     GAIL: str = "gail"
     CURIOSITY: str = "curiosity"
+    RND: str = "rnd"

     def to_settings(self) -> type:
         _mapping = {
             RewardSignalType.EXTRINSIC: RewardSignalSettings,
             RewardSignalType.GAIL: GAILSettings,
             RewardSignalType.CURIOSITY: CuriositySettings,
+            RewardSignalType.RND: RNDSettings,
         }
         return _mapping[self]

@@ -214,6 +216,12 @@ class CuriositySettings(RewardSignalSettings):
     learning_rate: float = 3e-4


+@attr.s(auto_attribs=True)
+class RNDSettings(RewardSignalSettings):
+    encoding_size: int = 64
+    learning_rate: float = 1e-4
+
+
 # SAMPLERS #############################################################################
 class ParameterRandomizationType(Enum):
     UNIFORM: str = "uniform"
```
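A quick sketch of how this wiring is used, assuming only the public names shown in this diff (the enum value resolves the `rnd` yaml key, and to_settings() hands back the matching settings dataclass):

```python
from mlagents.trainers.settings import RewardSignalType, RNDSettings

# The enum value maps the `rnd` yaml key to its settings dataclass ...
assert RewardSignalType("rnd") is RewardSignalType.RND
assert RewardSignalType.RND.to_settings() is RNDSettings

# ... which carries the RND-specific hyperparameters with their defaults.
settings = RNDSettings()  # encoding_size=64, learning_rate=1e-4
tuned = RNDSettings(encoding_size=128, learning_rate=3e-5)
print(settings.encoding_size, tuned.learning_rate)
```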
(new file — file name not shown in this capture; the imports indicate a new test module under ml-agents/mlagents/trainers/tests/torch/test_reward_providers/…)

Lines changed: 69 additions & 0 deletions

```python
import numpy as np
import pytest
from mlagents.torch_utils import torch
from mlagents.trainers.torch.components.reward_providers import (
    RNDRewardProvider,
    create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents.trainers.settings import RNDSettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
    create_agent_buffer,
)

SEED = [42]


@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
    ],
)
def test_construction(behavior_spec: BehaviorSpec) -> None:
    curiosity_settings = RNDSettings(32, 0.01)
    curiosity_settings.strength = 0.1
    curiosity_rp = RNDRewardProvider(behavior_spec, curiosity_settings)
    assert curiosity_rp.strength == 0.1
    assert curiosity_rp.name == "RND"


@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,), (64, 66, 3), (84, 86, 1)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,), (64, 66, 1)], ActionType.DISCRETE, (2, 3)),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
    ],
)
def test_factory(behavior_spec: BehaviorSpec) -> None:
    curiosity_settings = RNDSettings(32, 0.01)
    curiosity_rp = create_reward_provider(
        RewardSignalType.RND, behavior_spec, curiosity_settings
    )
    assert curiosity_rp.name == "RND"


@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
    ],
)
def test_reward_decreases(behavior_spec: BehaviorSpec, seed: int) -> None:
    np.random.seed(seed)
    torch.manual_seed(seed)
    rnd_settings = RNDSettings(32, 0.01)
    rnd_rp = RNDRewardProvider(behavior_spec, rnd_settings)
    buffer = create_agent_buffer(behavior_spec, 5)
    rnd_rp.update(buffer)
    reward_old = rnd_rp.evaluate(buffer)[0]
    for _ in range(100):
        rnd_rp.update(buffer)
    reward_new = rnd_rp.evaluate(buffer)[0]
    assert reward_new < reward_old
```
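The last test pins down the defining RND property: repeatedly calling update() on the same buffer trains the predictor toward the fixed random target, so evaluate() on those same observations must return a smaller intrinsic reward than before.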

ml-agents/mlagents/trainers/torch/components/reward_providers/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -10,6 +10,9 @@
 from mlagents.trainers.torch.components.reward_providers.gail_reward_provider import (  # noqa F401
     GAILRewardProvider,
 )
+from mlagents.trainers.torch.components.reward_providers.rnd_reward_provider import (  # noqa F401
+    RNDRewardProvider,
+)
 from mlagents.trainers.torch.components.reward_providers.reward_provider_factory import (  # noqa F401
     create_reward_provider,
 )
```

ml-agents/mlagents/trainers/torch/components/reward_providers/reward_provider_factory.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -15,13 +15,17 @@
 from mlagents.trainers.torch.components.reward_providers.gail_reward_provider import (
     GAILRewardProvider,
 )
+from mlagents.trainers.torch.components.reward_providers.rnd_reward_provider import (
+    RNDRewardProvider,
+)

 from mlagents_envs.base_env import BehaviorSpec

 NAME_TO_CLASS: Dict[RewardSignalType, Type[BaseRewardProvider]] = {
     RewardSignalType.EXTRINSIC: ExtrinsicRewardProvider,
     RewardSignalType.CURIOSITY: CuriosityRewardProvider,
     RewardSignalType.GAIL: GAILRewardProvider,
+    RewardSignalType.RND: RNDRewardProvider,
 }
```
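Registering the provider in NAME_TO_CLASS is what lets an `rnd:` entry under `reward_signals` resolve to RNDRewardProvider when create_reward_provider builds providers from the parsed settings (exercised by test_factory in the new test module above).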

(new file — file name not shown in this capture; this is the rnd_reward_provider module imported by the factory and package __init__ above, i.e. ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py)

Lines changed: 78 additions & 0 deletions

```python
import numpy as np
from typing import Dict
from mlagents.torch_utils import torch

from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (
    BaseRewardProvider,
)
from mlagents.trainers.settings import RNDSettings

from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.networks import NetworkBody
from mlagents.trainers.settings import NetworkSettings, EncoderType


class RNDRewardProvider(BaseRewardProvider):
    """
    Implementation of Random Network Distillation : https://arxiv.org/pdf/1810.12894.pdf
    """

    def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
        super().__init__(specs, settings)
        self._ignore_done = True
        self._random_network = RNDNetwork(specs, settings)
        self._training_network = RNDNetwork(specs, settings)
        self.optimizer = torch.optim.Adam(
            self._training_network.parameters(), lr=settings.learning_rate
        )

    def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
        with torch.no_grad():
            target = self._random_network(mini_batch)
            prediction = self._training_network(mini_batch)
            rewards = torch.sum((prediction - target) ** 2, dim=1)
        return rewards.detach().cpu().numpy()

    def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
        with torch.no_grad():
            target = self._random_network(mini_batch)
        prediction = self._training_network(mini_batch)
        loss = torch.mean(torch.sum((prediction - target) ** 2, dim=1))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {"Losses/RND Loss": loss.detach().cpu().numpy()}


class RNDNetwork(torch.nn.Module):
    EPSILON = 1e-10

    def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
        super().__init__()
        self._policy_specs = specs
        state_encoder_settings = NetworkSettings(
            normalize=True,
            hidden_units=settings.encoding_size,
            num_layers=3,
            vis_encode_type=EncoderType.SIMPLE,
            memory=None,
        )
        self._encoder = NetworkBody(specs.observation_shapes, state_encoder_settings)

    def forward(self, mini_batch: AgentBuffer) -> torch.Tensor:
        n_vis = len(self._encoder.visual_processors)
        hidden, _ = self._encoder.forward(
            vec_inputs=[
                ModelUtils.list_to_tensor(mini_batch["vector_obs"], dtype=torch.float)
            ],
            vis_inputs=[
                ModelUtils.list_to_tensor(
                    mini_batch["visual_obs%d" % i], dtype=torch.float
                )
                for i in range(n_vis)
            ],
        )
        self._encoder.update_normalization(torch.tensor(mini_batch["vector_obs"]))
        return hidden
```
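Worth noting in the code above: both networks are instances of the same RNDNetwork (a NetworkBody encoder sized by encoding_size), but only the training network's parameters are handed to the Adam optimizer, so the target network keeps its initial random weights. update() minimizes the same per-observation squared error that evaluate() returns as the reward, which is why the bonus decays for frequently visited observations.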
