
Commit 672b608

Ervin T, awjuliani, vincentpierre, midopooler, and andrewcoh authored
[feature] Add experimental PyTorch support (#4335)
* Begin porting work * Add ResNet and distributions * Dynamically construct actor and critic * Initial optimizer port * Refactoring policy and optimizer * Resolving a few bugs * Share more code between tf and torch policies * Slightly closer to running model * Training runs, but doesn’t actually work * Fix a couple additional bugs * Add conditional sigma for distribution * Fix normalization * Support discrete actions as well * Continuous and discrete now train * Mulkti-discrete now working * Visual observations now train as well * GRU in-progress and dynamic cnns * Fix for memories * Remove unused arg * Combine actor and critic classes. Initial export. * Support tf and pytorch alongside one another * Prepare model for onnx export * Use LSTM and fix a few merge errors * Fix bug in probs calculation * Optimize np -> tensor operations * Time action sample function * Small performance improvement during inference * ONNX exporting * Fix some issues with pdf * Fix bug in pdf function * Fix ResNet * Remove double setting * Fix for discrete actions (#4181) * Fix discrete actions and GridWorld * Remove print statement * Convert List[np.ndarray] to np.ndarray before using torch.as_tensor (#4183) Big speedup in visual obs * Develop add fire exp framework (#4213) * Experiment branch for comparing torch * Updates and merging ervin changes * improvements on experiment_torch.py * Better printing of results * preliminary gpu experiment * Testing gpu * Prepare to see a lot of commits, because I like my IDE and I am testing on a server and I am using git to sync the two * Prepare to see a lot of commits, because I like my IDE and I am testing on a server and I am using git to sync the two * _ * _ * _ * _ * _ * _ * _ * _ * Attempt at gpu on tf. Does not work * _ * _ * _ * _ * _ * _ * _ * _ * _ * _ * _ * Fixing learn.py * reformating experiment_torch.py * Pytorch port of SAC (#4219) * Update add-fire to latest master, including Policy refactor (#4263) * Update Dockerfile * Separate send environment data from reset (#4128) * Fixed a typo on ML-Agents-Overview.md (#4130) Fixed redundant "to" word from the sentence since it is probably a typo in document. 
* Updated the badge’s link to point to the newest doc version * Replaced all of the doc to release_3_doc * Fix 3DBall and 3DBallHard SAC regressions (#4132) * Move memory validation to settings * Update docs * Add settings test * Update to release_3 in installation.md (#4144) * rename to SideChannelManager +backcompat (#4137) * Remove comment about logo with --help (#4148) * [bugfix] Make FoodCollector heuristic playable (#4147) * Make FoodCollector heuristic playable * Update changelog * script to check for old release links and references (#4153) * Remove package validation suite from Project (#4146) * RayPerceptionSensor: handle empty and invalid tags (#4155) * handle empty and invalid tags * don't compare null or empty tags * changelog * avoid console spam when editing tag name * [docs] Fix a typo in a link in the package docs (#4161) * update FAQ to include disabling graphics (#4159) * update FAQ and exception message * remove colab link * [MLA-1009] observable performance tests (#4031) * WIP perf tests * WIP perf test * add marker tests too * move to devproject * yamato first pass * chmod * fix trigger, fix meta files * fix utr command * fix artifact paths * Update com.unity.ml-agents-performance.yml * test properties, reduce some noise * timer around RequestDecision * actually set ObservableAttributeHandling * undo asmdef changes * Yamato inference tests (#4066) * better errors for missing constants * run inference in yamato after training * add extension * debug subprocess args * fix exe path * search for executable * fix dumb bug * -batchmode * fail if inference fails * install tf2onnx on yamato * allow onnx for overrides (expect to fail now) * enable logs * fix commandline arg * catch exception from SetModel and exit * cleanup error message * model artifacts, logs as artifacts, fix pip * don't run onnx * cleanup and comment * update extension handling * Update CONTRIBUTING.md (#4170) we removed contributions welcome. * FoodCollectorAgent - don't convert bool to int (#4169) * [docs] Corrected num_layers default value (#4167) * [docs] Fix table formatting (#4168) * [CI] Better hyperparameters for Pyramids-SAC, WalkerStatic-SAC, and Reacher-PPO (#4154) * [bug-fix] Initialize-from being incorrectly loaded as "None" rather than None (#4175) * Modified the documentation of the Heuristic method (default action = previous action) (#4174) * Modifying the documentation to explain that Heuristic method default action will be the previous action decided by the heuristic. Changing this behavior would be a breking change. * Rephrase the working of the documentation of the default action of the Heuristic method * Forgot an import * [MLA-240] Physic-based pose generation (#4166) * hierarchy POC * WIP * abstract class * clean up init * separate files * cleanup, unit test * add Articulation util * sensor WIP * ArticulationBody sensor, starting docs * docstrings, cleanup * hierarchy tests, transform operators * unit tests * use Pose struct instead * delete QTTransform * rename * renames and compile fixes * remove ArticulationBodySensor* for now * revert CrawlerAgent changes * rename * Add TargetController/OrientationCubeController Components & Bugfix (#4157) * added Target and OCube controllers. 
updated crawler envs * update walker prefab * add refs to prefab * Update Crawler.prefab * update platform, ragdoll, ocube prefabs * reformat file * reformat files * fix behavior name * add final retrained crawler and walker nn files * collect hip ocube rot in world space * update crawler observations and update prefabs * change to 20M steps * update crwl prefab to 142 observ * update obsvs to 241. add expvel reward * change walkspeed to 3 * add new crawler and walker nn files * adjust rewards * enable other pairs * add RewardManager * cleanup about to do final training * cleanup add nn files for increased facing rew reduced height rew * try no facing rew * add vel only policy, try dy target * inc torq on cube * added dynamic cube nn. gonna try 40M steps * add 40M step test, more cleanup * change back to 20M steps * Update WalkerStatic.unity * add no vel pen nn file * .005 head height rew * remove extra walker in scene * Update WalkerWithTargetPair.prefab * Update WalkerStatic.unity * more cleanup add new nn file with less head height reward * added Target and OCube controllers. updated crawler envs * update walker prefab * add refs to prefab * Update Crawler.prefab * update platform, ragdoll, ocube prefabs * reformat file * reformat files * fix behavior name * add final retrained crawler and walker nn files * collect hip ocube rot in world space * update crawler observations and update prefabs * change to 20M steps * update crwl prefab to 142 observ * update obsvs to 241. add expvel reward * change walkspeed to 3 * add new crawler and walker nn files * adjust rewards * enable other pairs * add RewardManager * cleanup about to do final training * cleanup add nn files for increased facing rew reduced height rew * try no facing rew * add vel only policy, try dy target * inc torq on cube * added dynamic cube nn. gonna try 40M steps * add 40M step test, more cleanup * change back to 20M steps * Update WalkerStatic.unity * add no vel pen nn file * .005 head height rew * remove extra walker in scene * Update WalkerWithTargetPair.prefab * Update WalkerStatic.unity * more cleanup add new nn file with less head height reward * cleanup * remove comment * more cleanup * correct format * Update ProjectVersion.txt * change to Log() * cleanup * use the starting y position instead of a hard coded height * test old fromtorot * add 236 model * testing new 236 nn files * add final walker nn files * cleanup * crawler cleanup * update crawler observ size * add final crawler nn files * fixed formatting ssues * [refactor] Remove BrainParameters from Python code (#4138) * Fix extension package tests (#4189) * Move Heuristic fixes to changelog bug section (#4177) * Fix a typo in Python-API.md (#4179) Fix behavior_spec to behavior_specs * Docs: note about required Windows Python x86-64 (#4060) * note about required Windows Python x86-64 Co-authored-by: Arthur Juliani <[email protected]> Co-authored-by: andrewcoh <[email protected]> * documentation touchups (#4099) * doc updates getting started page now uses consistent run-id re-order create-new docs to have less back/forth between unity and text editor * add link explaining decisions where we tell the reader to modify its parameter * Fix 3DBall PPO hard regression (#4133) * enforce warnings-as-errors, fix warning (#4191) * (case 1255312) Conditionally use different namespace for ScriptedImporters (#4188) * (case 1255312) Conditionally use different namespace for ScriptedImporters. * Add semi-colon. * Update barracuda dependency. * Update changelog. 
* Fix PR number. * Refactor of Curriculum and parameter sampling (#4160) * Introduced the Constant Parameter Sampler that will be useful later as samplers and floats can be used interchangeably * Refactored the settings.py to refect the new format of the config.yaml * First working version * Added the unit tests * Update to Upgrade for Updates * fixing the tests * Upgraded the config files * Fixes * Additional error catching * addressing some comments * Making the code nicer with cattr * Added and registered an unstructure hook for PrameterRandomization * Updating C# Walljump * Adding comments * Add test for settings export (#4164) * Add test for settings export * Update ml-agents/mlagents/trainers/tests/test_settings.py Co-authored-by: Vincent-Pierre BERGES <[email protected]> Co-authored-by: Vincent-Pierre BERGES <[email protected]> * Including environment parameters for the test for settings export * First documentation update * Fixing a link * Updating changelog and migrating * adding some more tests for the conversion script * fixing bugs and using samplers in the walljump curriculum * Changing the format of the curriculum file as per discussion * Addressing comments * Update ml-agents/mlagents/trainers/settings.py Co-authored-by: Ervin T. <[email protected]> * Update docs/Migrating.md Co-authored-by: Chris Elion <[email protected]> * addressing comments Co-authored-by: Ervin T <[email protected]> Co-authored-by: Chris Elion <[email protected]> * [bug-fix] Make StatsReporter thread-safe (#4201) * Update changelog for release 4 (#4202) * don't allow --num-envs >1 with no --env (#4203) * don't allow --num-envs >1 with no --env * changelog * PR feedback * Add warning if behavior name not found in trainer config (#4204) Co-authored-by: Ervin T. <[email protected]> Co-authored-by: Chris Elion <[email protected]> * better logging for NaN rewards (#4205) * [MLA-1145] don't allow --num-envs >1 with no --env (#4209) * don't allow --num-envs >1 with no --env (#4203) * Convert checkpoints to .NN (#4127) This change adds an export to .nn for each checkpoint generated by RLTrainer and adds a NNCheckpointManager to track the generated checkpoints and final model in training_status.json. Co-authored-by: Jonathan Harper <[email protected]> * Update version for release 4 (master) (#4207) * Update version for release 4 * newline in json file * actually fix newline Co-authored-by: Chris Elion <[email protected]> * Update version for release 4 (release branch) (#4210) * Update versions for release 4 * Link validation file should ignore itself * Remove 'unreleased' section from changelog * Change to 0.18.0 for python versions * also update extensions package version Co-authored-by: Chris Elion <[email protected]> * [MLA-1141] Rigidbody and ArticulationBody sensors (#4192) * Update release table (#4221) * Add macOS Catalina notice to FAQ (#4222) * Add macOS Catalina notice to FAQ * Change wording and line breaks. * update com.unity.ml-agents.extensions to Apache 2.0 license (#4223) * update com.unity.ml-agents.extensions to Apache 2.0 license (#4223) (#4225) * Throw if Academy.EnvironmentStep() is called recursively (#4227) * speed up infinite loops * changelog * fix job deps (#4230) * use old yamato test config (#4231) * Run all package test types (#4232) * Revert "use old yamato test config" (#4233) * Revert "use old yamato test config (#4231)" This reverts commit e5e21dc. 
* Apply changes from #4232 * update document (#4237) small fix to documentation formatting * Update changelog for .nn checkpoints (#4240) Co-authored-by: sankalp04 <[email protected]> * Don't drop multiple stats from the same step (#4236) * add pyupgrade to pre-commit and run (#4239) * [MLA-427] make pyupgrade convert f-strings too (#4244) * make pyupgrade convert f-strings too * Run coverage checks with python3 (#4245) * Run code coverage for extensions package (#4243) * run code coverage for extensions package * reasonable coverage pct * fix artifactory url (#4246) * Refactor TFPolicy and Policy * don't try/except for control flow (#4251) * Update two docstring references to NNPolicy * Longer demos for ragdoll envs (#4247) * [docs] buffer_size parameter clarification (#4252) * [docs] buffer_size parameter clarification It was not fully clear that it has a different behavior for PPO and SAC. The docs update should improve the understanding. * [docs] updated buffer_size parameter clarification Co-authored-by: Vincent-Pierre BERGES <[email protected]> Co-authored-by: Vincent-Pierre BERGES <[email protected]> * Remove un-needed check * Remove irrelevant tests * Address feedback * [MLA-1172] Reduce calls to training_behaviors (#4259) * Remove unnecessary line (#4260) * [MLA-1138] joint observations (#4224) * Update to latest master * Refactor TFPolicy and Policy (#4254) * Refactor TFPolicy and Policy * Move TF-specific files to tf/ folder * Move EncoderType and ScheduleType to settings.py * Move Torch files to separate folder * Update imports to keep Torch working * [bugfix] summary writer no longer crashes if Hyperparameters could not be written (#4265) * Bug fix, returnning an empty string in case of error breaks the summary writter * addressing comments * [refactor] Move TF-specific files to tf/ folder (#4266) * Break up models.py into separate files * Use network_settings for configuring networks * [refactor] Make classes except Optimizer framework agnostic (#4268) * Fixing tensorboard command line params (#4262) * Update Using-Tensorboard.md "--logdir=results" is broken in newer versions of tensor board; "logdir results" without the equal sign works. 
See tensorflow/tensorboard#686 * Removing equal sign from tensorboard command line params in docs Co-authored-by: Nancy Iskander <[email protected]> Co-authored-by: Pulkit Midha <[email protected]> Co-authored-by: andrewcoh <[email protected]> Co-authored-by: Furkan Çelik <[email protected]> Co-authored-by: Yuan Gao <[email protected]> Co-authored-by: Chris Elion <[email protected]> Co-authored-by: Anupam Bhatnagar <[email protected]> Co-authored-by: Jeffrey Shih <[email protected]> Co-authored-by: Christian Coenen <[email protected]> Co-authored-by: Florian Pöhler <[email protected]> Co-authored-by: Vincent-Pierre BERGES <[email protected]> Co-authored-by: Hunter-Unity <[email protected]> Co-authored-by: yongjun823 <[email protected]> Co-authored-by: Stefano Cecere <[email protected]> Co-authored-by: Arthur Juliani <[email protected]> Co-authored-by: Tom Thompson <[email protected]> Co-authored-by: Chris Goy <[email protected]> Co-authored-by: sankalp04 <[email protected]> Co-authored-by: Jonathan Harper <[email protected]> Co-authored-by: Ruo-Ping (Rachel) Dong <[email protected]> Co-authored-by: Nancy Iskander <[email protected]> Co-authored-by: Nancy Iskander <[email protected]> * [refactor] Refactor normalizers and encoders (#4275) * Refactor normalizers and encoders * Unify Critic and ValueNetwork * Rename ActionVectorEncoder * Update docstring of create_encoders * Add docstring to UnnormalizedInputEncoder * fix onnx save path and output_name * add Saver class (only TF working) * fix pytorch checkpointing. add tensors in Normalizer as parameter * remove print * move tf and add torch model serialization * remove * remove unused * add sac checkpoint * small improvements * small improvements * remove print * move checkpoint_path logic to saver * [refactor] Refactor Actor and Critic classes (#4287) * fix onnx input * fix formatting and test * [bug-fix] Fix non-LSTM SeparateActorCritic (#4306) * small improvements * small improvement * [bug-fix] Fix error with discrete probs (#4309) * [tests] Add tests for core PyTorch files (#4292) * [feature] Fix TF tests, add --torch CLI option, allow run TF without torch installed (#4305) * Test fixes on add-fire (#4317) * fix tests * Add components directory and init (#4320) * [add-fire] Halve Gaussian entropy (#4319) * Halve entropy * Fix utils test * [add-fire] Add learning rate and beta/epsilon decay to PyTorch (#4318) * Added Reward Providers for Torch (#4280) * Added Reward Providers for Torch * Use NetworkBody to encode state in the reward providers * Integrating the reward prodiders with ppo and torch * work in progress, integration with PPO. Not training properly Pyramids at the moment * Integration in PPO * Removing duplicate file * Gail and Curiosity working * addressing comments * Enfore float32 for tests * enfore np.float32 in buffer * Fix discrete export (#4322) Fix discrete export * [add-fire] Fix CategoricalDistInstance test and replace `range` with `arange` (#4327) * Develop add fire layers (#4321) * Layer initialization + swish as a layer * integrating with the existing layers * fixing tests * setting the seed for a test * Using swish and fixing tests * fixing typo * [add-fire] Merge post-0.19.0 master into add-fire (#4328) * Revert "[add-fire] Merge post-0.19.0 master into add-fire (#4328)" (#4330) This reverts commit 9913e71. 
* More comments and Made ResNetBlock (#4329) * update saver interface and add tests * update * Fixed the reporting of the discriminator loss (#4348) * Fixed the reporting of the discriminator loss * Update ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py * fixing pre-commit test * Fix ONNX import for continuous * fix export input names * Behavioral Cloning Pytorch (#4293) * fix export input name * [add-fire] Add LSTM to SAC, LSTM fixes and initializations (#4324) * add comments * fix bc tests * change brain_name to behavior_name * reverting Project settings * [add-fire] Fix masked mean for 2d tensors (#4364) * Removing the experiment script from add fire (#4373) * Removing the experiment script * Removing the script * [add-fire] Add tests and fix issues with Policy (#4372) * Pytorch ghost trainer (#4370) * add test_simple_rl tests to torch * revert tests * Fix of the test for multi visual input * Make reset block submodule * fix export input_name * [add-fire] Memory class abstraction (#4375) * make visual input channel first for export * Don't use torch.split in LSTM * Add fire to test_simple_rl.py (#4378) Co-authored-by: Vincent-Pierre BERGES <[email protected]> Co-authored-by: Ervin T <[email protected]> * reverting unity_to_external_pb2_grpc.py * remove duplicate of curr documentation * Revert "remove duplicate of curr documentation" This reverts commit 3d7b809. * remove duplicated curriculum doc (#4386) * Fixed discrete models * Always export one Action tensor (#4388) * [add-fire] Revert unneeded changes back to master (#4389) * add comment * fix test * add fire clean up docstrings in create policies (#4391) * [add-fire] Update changelog (#4397) Co-authored-by: Arthur Juliani <[email protected]> Co-authored-by: Vincent-Pierre BERGES <[email protected]> Co-authored-by: Pulkit Midha <[email protected]> Co-authored-by: andrewcoh <[email protected]> Co-authored-by: Furkan Çelik <[email protected]> Co-authored-by: Yuan Gao <[email protected]> Co-authored-by: Chris Elion <[email protected]> Co-authored-by: Anupam Bhatnagar <[email protected]> Co-authored-by: Jeffrey Shih <[email protected]> Co-authored-by: Christian Coenen <[email protected]> Co-authored-by: Florian Pöhler <[email protected]> Co-authored-by: Hunter-Unity <[email protected]> Co-authored-by: yongjun823 <[email protected]> Co-authored-by: Stefano Cecere <[email protected]> Co-authored-by: Tom Thompson <[email protected]> Co-authored-by: Chris Goy <[email protected]> Co-authored-by: sankalp04 <[email protected]> Co-authored-by: Jonathan Harper <[email protected]> Co-authored-by: Ruo-Ping (Rachel) Dong <[email protected]> Co-authored-by: Nancy Iskander <[email protected]> Co-authored-by: Nancy Iskander <[email protected]> Co-authored-by: Ruo-Ping Dong <[email protected]> Co-authored-by: Andrew Cohen <[email protected]>
1 parent 7f0b9e2 commit 672b608


48 files changed: +3928 −157 lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 4 additions & 0 deletions

@@ -32,6 +32,10 @@ and this project adheres to
 - The interaction between EnvManager and TrainerController was changed; EnvManager.advance() was split into two stages,
   and TrainerController now uses the results from the first stage to handle new behavior names. This change speeds up
   Python training by approximately 5-10%. (#4259)
+- Experimental PyTorch support has been added. Use `--torch` when running `mlagents-learn`, or add
+  `framework: pytorch` to your trainer configuration (under the behavior name) to enable it.
+  Note that PyTorch 1.6.0 or greater should be installed to use this feature; see
+  [the PyTorch website](https://pytorch.org/) for installation instructions. (#4335)
 
 ### Minor Changes
 #### com.unity.ml-agents (C#)
ml-agents/mlagents/trainers/buffer.py

Lines changed: 1 addition & 1 deletion

@@ -48,7 +48,7 @@ def extend(self, data: np.ndarray) -> None:
         Adds a list of np.arrays to the end of the list of np.arrays.
         :param data: The np.array list to append.
         """
-        self += list(np.array(data))
+        self += list(np.array(data, dtype=np.float32))
 
     def set(self, data):
         """

ml-agents/mlagents/trainers/cli_utils.py

Lines changed: 7 additions & 0 deletions

@@ -168,6 +168,13 @@ def _create_parser() -> argparse.ArgumentParser:
         action=DetectDefaultStoreTrue,
         help="Forces training using CPU only",
     )
+    argparser.add_argument(
+        "--torch",
+        default=False,
+        action=DetectDefaultStoreTrue,
+        help="(Experimental) Use the PyTorch framework instead of TensorFlow. Install PyTorch "
+        "before using this option",
+    )
 
     eng_conf = argparser.add_argument_group(title="Engine Configuration")
     eng_conf.add_argument(
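For context, this standalone sketch approximates the parsing behavior of the new flag with vanilla argparse; `store_true` stands in for ml-agents' DetectDefaultStoreTrue helper (which additionally records whether the value came from the command line), so this is an approximation rather than the actual implementation.

# Minimal approximation of the new --torch flag using plain argparse.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--torch",
    default=False,
    action="store_true",  # stand-in for DetectDefaultStoreTrue
    help="(Experimental) Use the PyTorch framework instead of TensorFlow.",
)

print(parser.parse_args([]).torch)           # False -> TensorFlow (the default)
print(parser.parse_args(["--torch"]).torch)  # True  -> PyTorch backend selected
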

ml-agents/mlagents/trainers/ghost/trainer.py

Lines changed: 7 additions & 5 deletions

@@ -304,7 +304,10 @@ def save_model(self) -> None:
         self.trainer.save_model()
 
     def create_policy(
-        self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
+        self,
+        parsed_behavior_id: BehaviorIdentifiers,
+        behavior_spec: BehaviorSpec,
+        create_graph: bool = False,
     ) -> Policy:
         """
         Creates policy with the wrapped trainer's create_policy function
@@ -313,10 +316,10 @@ def create_policy(
         team are grouped. All policies associated with this team are added to the
         wrapped trainer to be trained.
         """
-        policy = self.trainer.create_policy(parsed_behavior_id, behavior_spec)
-        policy.create_tf_graph()
+        policy = self.trainer.create_policy(
+            parsed_behavior_id, behavior_spec, create_graph=True
+        )
         self.trainer.saver.initialize_or_load(policy)
-        policy.init_load_weights()
         team_id = parsed_behavior_id.team_id
         self.controller.subscribe_team_id(team_id, self)
 
@@ -326,7 +329,6 @@ def create_policy(
             parsed_behavior_id, behavior_spec
         )
         self.trainer.add_policy(parsed_behavior_id, internal_trainer_policy)
-        internal_trainer_policy.init_load_weights()
         self.current_policy_snapshot[
             parsed_behavior_id.brain_name
         ] = internal_trainer_policy.get_weights()
ml-agents/mlagents/trainers/optimizer/torch_optimizer.py

Lines changed: 94 additions & 0 deletions

@@ -0,0 +1,94 @@
+from typing import Dict, Optional, Tuple, List
+import torch
+import numpy as np
+
+from mlagents.trainers.buffer import AgentBuffer
+from mlagents.trainers.trajectory import SplitObservations
+from mlagents.trainers.torch.components.bc.module import BCModule
+from mlagents.trainers.torch.components.reward_providers import create_reward_provider
+
+from mlagents.trainers.policy.torch_policy import TorchPolicy
+from mlagents.trainers.optimizer import Optimizer
+from mlagents.trainers.settings import TrainerSettings
+from mlagents.trainers.torch.utils import ModelUtils
+
+
+class TorchOptimizer(Optimizer):  # pylint: disable=W0223
+    def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
+        super().__init__()
+        self.policy = policy
+        self.trainer_settings = trainer_settings
+        self.update_dict: Dict[str, torch.Tensor] = {}
+        self.value_heads: Dict[str, torch.Tensor] = {}
+        self.memory_in: torch.Tensor = None
+        self.memory_out: torch.Tensor = None
+        self.m_size: int = 0
+        self.global_step = torch.tensor(0)
+        self.bc_module: Optional[BCModule] = None
+        self.create_reward_signals(trainer_settings.reward_signals)
+        if trainer_settings.behavioral_cloning is not None:
+            self.bc_module = BCModule(
+                self.policy,
+                trainer_settings.behavioral_cloning,
+                policy_learning_rate=trainer_settings.hyperparameters.learning_rate,
+                default_batch_size=trainer_settings.hyperparameters.batch_size,
+                default_num_epoch=3,
+            )
+
+    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
+        pass
+
+    def create_reward_signals(self, reward_signal_configs):
+        """
+        Create reward signals
+        :param reward_signal_configs: Reward signal config.
+        """
+        for reward_signal, settings in reward_signal_configs.items():
+            # Name reward signals by string in case we have duplicates later
+            self.reward_signals[reward_signal.value] = create_reward_provider(
+                reward_signal, self.policy.behavior_spec, settings
+            )
+
+    def get_trajectory_value_estimates(
+        self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool
+    ) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]:
+        vector_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])]
+        if self.policy.use_vis_obs:
+            visual_obs = []
+            for idx, _ in enumerate(
+                self.policy.actor_critic.network_body.visual_encoders
+            ):
+                visual_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
+                visual_obs.append(visual_ob)
+        else:
+            visual_obs = []
+
+        memory = torch.zeros([1, 1, self.policy.m_size])
+
+        vec_vis_obs = SplitObservations.from_observations(next_obs)
+        next_vec_obs = [
+            ModelUtils.list_to_tensor(vec_vis_obs.vector_observations).unsqueeze(0)
+        ]
+        next_vis_obs = [
+            ModelUtils.list_to_tensor(_vis_ob).unsqueeze(0)
+            for _vis_ob in vec_vis_obs.visual_observations
+        ]
+
+        value_estimates, next_memory = self.policy.actor_critic.critic_pass(
+            vector_obs, visual_obs, memory, sequence_length=batch.num_experiences
+        )
+
+        next_value_estimate, _ = self.policy.actor_critic.critic_pass(
+            next_vec_obs, next_vis_obs, next_memory, sequence_length=1
+        )
+
+        for name, estimate in value_estimates.items():
+            value_estimates[name] = estimate.detach().cpu().numpy()
+            next_value_estimate[name] = next_value_estimate[name].detach().cpu().numpy()
+
+        if done:
+            for k in next_value_estimate:
+                if not self.reward_signals[k].ignore_done:
+                    next_value_estimate[k] = 0.0
+
+        return value_estimates, next_value_estimate
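
To make the tail of get_trajectory_value_estimates above easier to follow, here is a self-contained sketch (toy tensors, not the ml-agents API) of how the value estimates are moved to NumPy and how the bootstrap value is zeroed for terminal trajectories unless a reward signal ignores episode termination. The reward-signal names and the ignore_done flags are assumptions for illustration only.

# Toy illustration of the bootstrap handling at the end of get_trajectory_value_estimates.
import torch

value_estimates = {"extrinsic": torch.tensor([0.5, 0.7]), "gail": torch.tensor([0.1, 0.2])}
next_value_estimate = {"extrinsic": torch.tensor(0.9), "gail": torch.tensor(0.3)}
ignore_done = {"extrinsic": False, "gail": True}  # assumed per-signal flags
done = True  # the trajectory ended in a terminal state

# Detach from the graph and move to NumPy, as the optimizer does before returning.
value_estimates = {k: v.detach().cpu().numpy() for k, v in value_estimates.items()}
next_value_estimate = {k: v.detach().cpu().numpy().item() for k, v in next_value_estimate.items()}

# Terminal trajectories get a zero bootstrap value for signals that respect "done".
if done:
    for k in next_value_estimate:
        if not ignore_done[k]:
            next_value_estimate[k] = 0.0

print(next_value_estimate)  # {'extrinsic': 0.0, 'gail': 0.3}
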

ml-agents/mlagents/trainers/policy/tf_policy.py

Lines changed: 2 additions & 0 deletions

@@ -152,6 +152,8 @@ def create_tf_graph(self) -> None:
         # We do an initialize to make the Policy usable out of the box. If an optimizer is needed,
         # it will re-load the full graph
         self.initialize()
+        # Create assignment ops for Ghost Trainer
+        self.init_load_weights()
 
     def _create_encoder(
         self,
