diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2a1861765c..10fec8dd60 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -66,7 +66,7 @@ repos:
                 .*_pb2_grpc.py|
                 .*/tests/.*
             )$
-        require_serial: true
+        args: [--score=n]
 
 # "Local" hooks, see https://pre-commit.com/#repository-local-hooks
 -   repo: local
diff --git a/ml-agents/mlagents/trainers/ppo/multi_gpu_policy.py b/ml-agents/mlagents/trainers/ppo/multi_gpu_policy.py
index 98f2e3eb5b..7791579aad 100644
--- a/ml-agents/mlagents/trainers/ppo/multi_gpu_policy.py
+++ b/ml-agents/mlagents/trainers/ppo/multi_gpu_policy.py
@@ -1,11 +1,14 @@
 import logging
+from typing import Any, Dict, List, Optional
 
 import tensorflow as tf
 from tensorflow.python.client import device_lib
+from mlagents.envs.brain import BrainParameters
 from mlagents.envs.timers import timed
 from mlagents.trainers.models import EncoderType, LearningRateSchedule
 from mlagents.trainers.ppo.policy import PPOPolicy
 from mlagents.trainers.ppo.models import PPOModel
+from mlagents.trainers.components.reward_signals import RewardSignal
 from mlagents.trainers.components.reward_signals.reward_signal_factory import (
     create_reward_signal,
 )
@@ -17,6 +20,23 @@
 
 
 class MultiGpuPPOPolicy(PPOPolicy):
+    def __init__(
+        self,
+        seed: int,
+        brain: BrainParameters,
+        trainer_params: Dict[str, Any],
+        is_training: bool,
+        load: bool,
+    ):
+        self.towers: List[PPOModel] = []
+        self.devices: List[str] = []
+        self.model: Optional[PPOModel] = None
+        self.total_policy_loss: Optional[tf.Tensor] = None
+        self.reward_signal_towers: List[Dict[str, RewardSignal]] = []
+        self.reward_signals: Dict[str, RewardSignal] = {}
+
+        super().__init__(seed, brain, trainer_params, is_training, load)
+
     def create_model(
         self, brain, trainer_params, reward_signal_configs, is_training, load, seed
     ):
@@ -28,7 +48,7 @@ def create_model(
         :param seed: Random seed.
         """
         self.devices = get_devices()
-        self.towers = []
+
         with self.graph.as_default():
             with tf.variable_scope("", reuse=tf.AUTO_REUSE):
                 for device in self.devices:
@@ -105,7 +125,6 @@ def create_reward_signals(self, reward_signal_configs):
         Create reward signals
         :param reward_signal_configs: Reward signal config.
        """
-        self.reward_signal_towers = []
         with self.graph.as_default():
             with tf.variable_scope(TOWER_SCOPE_NAME, reuse=tf.AUTO_REUSE):
                 for device_id, device in enumerate(self.devices):
@@ -190,7 +209,7 @@ def average_gradients(self, tower_grads):
         return average_grads
 
 
-def get_devices():
+def get_devices() -> List[str]:
     """
     Get all available GPU devices
     """
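
Note: the final hunk only changes the return annotation of `get_devices()`; the function body sits outside the diff context. As a minimal sketch, assuming the helper simply filters `device_lib.list_local_devices()` for GPUs (an illustration, not the file's actual body):

```python
from typing import List

from tensorflow.python.client import device_lib


def get_devices() -> List[str]:
    """Get all available GPU devices."""
    # Assumed body for illustration -- the diff above only shows the signature change.
    # device_lib.list_local_devices() enumerates every device visible to TensorFlow;
    # keep only the GPU entries and return their names, e.g. "/device:GPU:0".
    local_device_protos = device_lib.list_local_devices()
    return [d.name for d in local_device_protos if d.device_type == "GPU"]
```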