diff --git a/functions/colab.js b/functions/colab.js
index 193d5b3f..9fd88551 100644
--- a/functions/colab.js
+++ b/functions/colab.js
@@ -37,6 +37,13 @@ exports.handler = async function (event, _) {
     )
   }
 
+  if (title === 'Template Reinforcement Learning') {
+    specific_commands.push(
+      '!pip install swig\n',
+      '!pip install gymnasium[box2d]'
+    )
+  }
+
   const md_cell = [
     `# ${title} by PyTorch-Ignite Code-Generator\n\n`,
     'Please, run the cell below to execute your code.'
diff --git a/src/templates/template-reinforcement-learning/README.md b/src/templates/template-reinforcement-learning/README.md
new file mode 100644
index 00000000..ab78fd32
--- /dev/null
+++ b/src/templates/template-reinforcement-learning/README.md
@@ -0,0 +1,35 @@
+[![Code-Generator](https://badgen.net/badge/Template%20by/Code-Generator/ee4c2c?labelColor=eaa700)](https://github.com/pytorch-ignite/code-generator)
+
+# Reinforcement Learning Template
+
+This is the Reinforcement Learning template by Code-Generator. It trains an A2C agent with TorchRL on the Gymnasium `CarRacing-v2` environment.
+
+## Getting Started
+
+Install the dependencies with `pip`:
+
+```sh
+pip install -r requirements.txt --progress-bar off -U
+```
+
+### Code structure
+
+```
+|
+|- README.md
+|
+|- a2c.py : main script to run
+|- a2c_model_env.py : factories for the environment, actor-critic model, collector, loss and optimizer
+|- utils.py : module with various helper functions
+|- requirements.txt : dependencies to install with pip
+|
+|- config_a2c.yaml : global configuration YAML file
+```
+
+## Training
+
+### 1 GPU Training
+
+```sh
+python a2c.py config_a2c.yaml
+```
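Note: `CarRacing-v2` needs Box2D, which is why the Colab cell added above installs `swig` and `gymnasium[box2d]`. A local run will most likely need the same extras; a minimal setup sketch (exact package versions may differ in your environment):

```sh
# mirror the Colab install commands from functions/colab.js above
pip install swig
pip install "gymnasium[box2d]"
# install the template's own dependencies, then launch training
pip install -r requirements.txt --progress-bar off -U
python a2c.py config_a2c.yaml
```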
diff --git a/src/templates/template-reinforcement-learning/a2c.py b/src/templates/template-reinforcement-learning/a2c.py
new file mode 100644
index 00000000..13710548
--- /dev/null
+++ b/src/templates/template-reinforcement-learning/a2c.py
@@ -0,0 +1,109 @@
+from pprint import pformat
+
+import ignite.distributed as idist
+import torch
+from ignite.engine import Engine, Events
+from ignite.handlers import LRScheduler
+
+from ignite.utils import manual_seed
+
+from utils import *
+
+from a2c_model_env import make_a2c_models, make_collector, make_loss, make_optim, make_test_env
+
+
+def main():
+    config = setup_config()
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    config.device = f"{device}"
+
+    rank = idist.get_rank()
+    manual_seed(config.seed + rank)
+    config.output_dir = setup_output_dir(config, rank)
+    if rank == 0:
+        save_config(config, config.output_dir)
+
+    actor, critic = make_a2c_models(config)
+    actor = actor.to(device)
+    critic = critic.to(device)
+
+    collector = make_collector(config, policy=actor)
+    loss_module, adv_module = make_loss(config, actor_network=actor, value_network=critic)
+    optim = make_optim(config, actor_network=actor, value_network=critic)
+
+    # one network update per collected batch of frames
+    batch_size = config.frames_per_batch * config.num_envs
+    total_network_updates = config.total_frames // batch_size
+
+    scheduler = None
+    if config.lr_scheduler:
+        scheduler = torch.optim.lr_scheduler.LinearLR(optim, total_iters=total_network_updates)
+        scheduler = LRScheduler(scheduler)
+
+    test_env = make_test_env(config)
+
+    def run_single_timestep(engine, _):
+        frames_in_batch = engine.state.data.numel()
+        trainer.state.collected_frames += frames_in_batch * config.frame_skip
+        data_view = engine.state.data.reshape(-1)
+
+        with torch.no_grad():
+            batch = adv_module(data_view)
+
+        # Normalize the advantage
+        adv = batch.get("advantage")
+        # mean of the advantage values
+        loc = adv.mean().item()
+        # standard deviation of the advantage values
+        scale = adv.std().clamp_min(1e-6).item()
+        # normalized advantage values
+        adv = (adv - loc) / scale
+        batch.set("advantage", adv)
+
+        # Forward pass of the A2C loss
+        batch = batch.to(device)
+        loss = loss_module(batch)
+        loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
+
+        # Backward pass + learning step
+        loss_sum.backward()
+        grad_norm = torch.nn.utils.clip_grad_norm_(list(actor.parameters()) + list(critic.parameters()), max_norm=0.5)
+        engine.state.metrics = {
+            "loss_sum": loss_sum.item(),
+        }
+        optim.step()
+        optim.zero_grad()
+
+    trainer = Engine(run_single_timestep)
+
+    logger = setup_logging(config)
+    logger.info("Configuration: \n%s", pformat(vars(config)))
+    trainer.logger = logger
+
+    if config.lr_scheduler:
+        trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
+
+    trainer.add_event_handler(
+        Events.ITERATION_COMPLETED(every=config.log_every_episodes),
+        log_metrics,
+        tag="train",
+    )
+
+    @trainer.on(Events.ITERATION_STARTED)
+    def update_data():
+        # fetch the next batch of collected frames
+        trainer.state.data = next(iter(collector))
+        trainer.state.collected_frames = 0
+
+    @trainer.on(Events.ITERATION_COMPLETED)
+    def update_collector_weights():
+        # sync the collector's copy of the policy with the freshly updated weights
+        collector.update_policy_weights_()
+
+    trainer.run(epoch_length=int(config.total_frames / config.frames_per_batch), max_epochs=1)
+
+
+if __name__ == "__main__":
+    main()
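For orientation, the training loop in `a2c.py` maps one collected batch to one Ignite iteration: `ITERATION_STARTED` pulls a batch from the collector, `run_single_timestep` performs a single A2C update on it, and `ITERATION_COMPLETED` pushes the updated policy weights back to the collector. A stripped-down sketch of that wiring, with `batch_source` and `train_step` as placeholder names standing in for the real collector and update step:

```python
from ignite.engine import Engine, Events


def build_trainer(batch_source, train_step):
    """Drive an Ignite Engine from a batch iterator instead of a DataLoader."""

    def process(engine, _):
        # one Ignite iteration == one update on one collected batch
        return train_step(engine.state.data)

    trainer = Engine(process)

    @trainer.on(Events.EPOCH_STARTED)
    def init_stream():
        trainer.state.batches = iter(batch_source)

    @trainer.on(Events.ITERATION_STARTED)
    def fetch_batch():
        trainer.state.data = next(trainer.state.batches)

    return trainer


# usage: no dataloader is passed to run(); epoch_length bounds the iteration count
# trainer = build_trainer(collector, my_update_fn)
# trainer.run(max_epochs=1, epoch_length=total_frames // frames_per_batch)
```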
diff --git a/src/templates/template-reinforcement-learning/a2c_model_env.py b/src/templates/template-reinforcement-learning/a2c_model_env.py
new file mode 100644
index 00000000..afcdadeb
--- /dev/null
+++ b/src/templates/template-reinforcement-learning/a2c_model_env.py
@@ -0,0 +1,249 @@
+import gymnasium as gym
+
+import torch
+import torch.nn
+import torch.optim
+
+from tensordict.nn import TensorDictModule
+
+from torchrl.collectors import SyncDataCollector
+from torchrl.data import CompositeSpec
+from torchrl.data.tensor_specs import DiscreteBox
+from torchrl.envs import (
+    EnvCreator,
+    ExplorationType,
+    ObservationNorm,
+    ParallelEnv,
+    StepCounter,
+    ToTensorImage,
+    TransformedEnv,
+)
+from torchrl.envs.libs.gym import GymWrapper
+from torchrl.modules import ActorValueOperator, ConvNet, MLP, OneHotCategorical, ProbabilisticActor, ValueOperator
+from torchrl.objectives import A2CLoss
+from torchrl.objectives.value import GAE
+
+from utils import *
+
+
+def make_base_env(config):
+    env_kwargs = {"id": "CarRacing-v2", "continuous": False, "render_mode": "rgb_array"}
+
+    env = gym.make(**env_kwargs)
+
+    if config.render:
+
+        def trigger_recording(episode):
+            return episode % config.save_every_episodes == 0
+
+        env = gym.wrappers.RecordVideo(env, config.recordings_path, episode_trigger=trigger_recording, video_length=0)
+
+    env_kwargs2 = {
+        "device": config.device,
+        "from_pixels": True,
+        "pixels_only": True,
+    }
+
+    env = GymWrapper(env, **env_kwargs2)
+    return env
+
+
+def get_stats(config):
+    env = make_transformed_env_pixels(make_base_env(config), config)
+    return get_norm_state_dict(env)
+
+
+def make_transformed_env_pixels(base_env, config):
+    env = TransformedEnv(base_env)
+    env.append_transform(ToTensorImage())
+    env.append_transform(StepCounter())
+
+    return env
+
+
+def make_parallel_env(config, state_dict):
+    num_envs = config.num_envs
+    env = make_transformed_env_pixels(ParallelEnv(num_envs, EnvCreator(lambda: make_base_env(config))), config)
+    for t in env.transform:
+        if isinstance(t, ObservationNorm):
+            t.init_stats(3, cat_dim=1, reduce_dim=[0, 1])
+    env.load_state_dict(state_dict, strict=False)
+    return env
+
+
+def make_a2c_models(config):
+    base_env = make_transformed_env_pixels(make_base_env(config), config)
+
+    common_module, policy_module, value_module = make_a2c_models_pixels(base_env, config)
+
+    actor_critic = ActorValueOperator(
+        common_operator=common_module,
+        policy_operator=policy_module,
+        value_operator=value_module,
+    )
+
+    actor = actor_critic.get_policy_operator()
+    critic = actor_critic.get_value_head()  # use the value head so the shared trunk parameters are not duplicated
+
+    # initialize the lazy layers with a short dummy rollout
+    with torch.no_grad():
+        td = base_env.rollout(max_steps=100, break_when_any_done=False)
+        td = actor(td)
+        td = critic(td)
+        del td
+
+    return actor, critic
+
+
+def make_a2c_models_pixels(base_env, config):
+    env = base_env
+
+    # define the input shape
+    input_shape = env.observation_spec["pixels"].shape
+
+    # define the distribution class and kwargs; with continuous=False the action space is a DiscreteBox
+    if isinstance(env.action_spec.space, DiscreteBox):
+        num_outputs = env.action_spec.space.n
+        distribution_class = OneHotCategorical
+        distribution_kwargs = {}
+
+    # Define the input keys
+    in_keys = ["pixels"]
+
+    # Define a shared module and TensorDictModule (CNN + MLP)
+    common_cnn = ConvNet(
+        activation_class=torch.nn.ReLU,
+        num_cells=[32, 64, 64],
+        kernel_sizes=[3, 1, 1],
+        strides=[2, 2, 1],
+        device=config.device,
+    )
+    common_cnn_output = common_cnn(torch.ones(input_shape).to(config.device))
+    common_mlp = MLP(
+        in_features=common_cnn_output.shape[-1],
+        activation_class=torch.nn.ReLU,
+        activate_last_layer=True,
+        out_features=512,
+        num_cells=[],
+        device=config.device,
+    )
+    common_mlp_output = common_mlp(common_cnn_output).to(config.device)
+
+    # Define the shared net as a TensorDictModule
+    common_module = TensorDictModule(
+        module=torch.nn.Sequential(common_cnn, common_mlp),
+        in_keys=in_keys,
+        out_keys=["common_features"],
+    )
+
+    # Define one head for the policy
+    policy_net = MLP(
+        in_features=common_mlp_output.shape[-1],
+        out_features=num_outputs,
+        num_cells=[256],
+        device=config.device,
+    )
+    policy_module = TensorDictModule(
+        module=policy_net,
+        in_keys=["common_features"],
+        out_keys=["logits"],
+    )
+
+    # Add probabilistic sampling of the actions
+    policy_module = ProbabilisticActor(
+        policy_module,
+        in_keys=["logits"],
+        spec=CompositeSpec(action=env.action_spec),
+        safe=True,
+        distribution_class=distribution_class,
+        distribution_kwargs=distribution_kwargs,
+        return_log_prob=True,
+        default_interaction_type=ExplorationType.RANDOM,
+    )
+
+    # Define another head for the value
+    value_net = MLP(
+        in_features=common_mlp_output.shape[-1],
+        out_features=1,
+        num_cells=[256],
+        device=config.device,
+    )
+    value_module = ValueOperator(
+        value_net,
+        in_keys=["common_features"],
+    )
+
+    return common_module, policy_module, value_module
+
+
+def make_collector(config, policy):
+    collector_class = SyncDataCollector
+    state_dict = get_stats(config)
+    collector = collector_class(
+        make_parallel_env(config, state_dict),
+        policy=policy,
+        frames_per_batch=config.frames_per_batch,
+        total_frames=config.total_frames,
+        device=config.device,
+        max_frames_per_traj=config.max_frames_per_traj,
+    )
+    return collector
+
+
+def make_advantage_module(config, value_network):
+    advantage_module = GAE(
+        gamma=config.gamma,
+        lmbda=config.gae_lambda,
+        value_network=value_network,
+        average_gae=True,
+    )
+    return advantage_module
+
+
+def make_test_env(config):
+    state_dict = get_stats(config)
+    env = make_parallel_env(config, state_dict)
+    return env
+
+
+def make_loss(config, actor_network, value_network):
+    advantage_module = make_advantage_module(config, value_network)
+    loss_module = A2CLoss(
+        actor=actor_network,
+        critic=value_network,
+        loss_critic_type=config.loss_critic_type,
+        entropy_coef=config.entropy_coef,
+        critic_coef=config.critic_coef,
+        entropy_bonus=True,
+    )
+    loss_module.make_value_estimator(gamma=config.gamma)
+    return loss_module, advantage_module
+
+
+def make_optim(config, actor_network, value_network):
+    optim = torch.optim.Adam(
+        list(actor_network.parameters()) + list(value_network.parameters()),
+        lr=config.lr,
+        weight_decay=config.weight_decay,
+    )
+    return optim
+
+
+def get_norm_state_dict(env):
+    """Gets the normalization loc and scale from the env state_dict."""
+    sd = env.state_dict()
+    sd = {key: val for key, val in sd.items() if key.endswith("loc") or key.endswith("scale")}
+    return sd
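The `gamma` and `gae_lambda` values passed to `GAE` above are the standard generalized-advantage-estimation knobs: each advantage is a discounted sum of TD errors, `A_t = sum_k (gamma * lmbda)^k * delta_{t+k}` with `delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)`. A toy, loop-based illustration of what that computes (it ignores episode termination; the template itself relies on torchrl's vectorized `GAE`):

```python
def gae_advantages(rewards, values, next_values, gamma=0.99, lmbda=0.95):
    # TD errors: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    deltas = [r + gamma * nv - v for r, v, nv in zip(rewards, values, next_values)]
    advantages = []
    running = 0.0
    # accumulate backwards: A_t = delta_t + gamma * lmbda * A_{t+1}
    for delta in reversed(deltas):
        running = delta + gamma * lmbda * running
        advantages.append(running)
    return advantages[::-1]


# e.g. gae_advantages([1.0, 0.0, 1.0], [0.5, 0.4, 0.6], [0.4, 0.6, 0.0])
```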
diff --git a/src/templates/template-reinforcement-learning/config_a2c.yaml b/src/templates/template-reinforcement-learning/config_a2c.yaml
new file mode 100644
index 00000000..aceb6e2b
--- /dev/null
+++ b/src/templates/template-reinforcement-learning/config_a2c.yaml
@@ -0,0 +1,37 @@
+# task and env
+frame_skip: 2
+num_envs: 1
+reward_scaling: 1.0
+from_pixels: true
+render_mode: rgb_array
+continuous: false
+pixels_only: true
+
+# collector
+frames_per_batch: 64
+total_frames: 1_000_000
+max_frames_per_traj: -1
+
+# logger
+log_interval: 10000
+
+# optim
+lr: 0.0005
+weight_decay: 0.0
+lr_scheduler: true
+
+# loss
+gamma: 0.99
+gae_lambda: 0.95
+critic_coef: 0.5
+entropy_coef: 0.01
+loss_critic_type: l2
+
+seed: 666
+render: true
+recordings_path: ./recordings
+max_episodes: 10000
+log_every_episodes: 50
+save_every_episodes: 50
+output_dir: ./logs
+debug: false
diff --git a/src/templates/template-reinforcement-learning/requirements.txt b/src/templates/template-reinforcement-learning/requirements.txt
new file mode 100644
index 00000000..af9d9f2f
--- /dev/null
+++ b/src/templates/template-reinforcement-learning/requirements.txt
@@ -0,0 +1,3 @@
+#::= from_template_common ::#
+torchrl
+moviepy
\ No newline at end of file
diff --git a/src/templates/template-reinforcement-learning/utils.py b/src/templates/template-reinforcement-learning/utils.py
new file mode 100644
index 00000000..aec8eaac
--- /dev/null
+++ b/src/templates/template-reinforcement-learning/utils.py
@@ -0,0 +1 @@
+#::= from_template_common ::#
diff --git a/src/templates/templates.json b/src/templates/templates.json
index 47b4257a..2690ac9d 100644
--- a/src/templates/templates.json
+++ b/src/templates/templates.json
@@ -43,5 +43,13 @@
     "utils.py",
     "test_all.py",
     "requirements.txt"
+  ],
+  "template-reinforcement-learning": [
+    "README.md",
+    "config_a2c.yaml",
+    "a2c_model_env.py",
+    "a2c.py",
+    "utils.py",
+    "requirements.txt"
   ]
 }