diff --git a/ml-agents/mlagents/trainers/components/reward_signals/__init__.py b/ml-agents/mlagents/trainers/components/reward_signals/__init__.py index 80101a8110..e35b03a038 100644 --- a/ml-agents/mlagents/trainers/components/reward_signals/__init__.py +++ b/ml-agents/mlagents/trainers/components/reward_signals/__init__.py @@ -23,7 +23,7 @@ def __init__(self, policy: TFPolicy, settings: RewardSignalSettings): """ Initializes a reward signal. At minimum, you must pass in the policy it is being applied to, the reward strength, and the gamma (discount factor.) - :param policy: The Policy object (e.g. NNPolicy) that this Reward Signal will apply to. + :param policy: The Policy object (e.g. TFPolicy) that this Reward Signal will apply to. :param settings: Settings parameters for this Reward Signal, including gamma and strength. :return: A RewardSignal object. """ diff --git a/ml-agents/mlagents/trainers/policy/nn_policy.py b/ml-agents/mlagents/trainers/policy/nn_policy.py deleted file mode 100644 index bbbec0d827..0000000000 --- a/ml-agents/mlagents/trainers/policy/nn_policy.py +++ /dev/null @@ -1,275 +0,0 @@ -from typing import Any, Dict, Optional, List -from mlagents.tf_utils import tf -from mlagents_envs.timers import timed -from mlagents_envs.base_env import DecisionSteps, BehaviorSpec -from mlagents.trainers.models import EncoderType -from mlagents.trainers.models import ModelUtils -from mlagents.trainers.policy.tf_policy import TFPolicy -from mlagents.trainers.settings import TrainerSettings -from mlagents.trainers.distributions import ( - GaussianDistribution, - MultiCategoricalDistribution, -) - -EPSILON = 1e-6 # Small value to avoid divide by zero - - -class NNPolicy(TFPolicy): - def __init__( - self, - seed: int, - behavior_spec: BehaviorSpec, - trainer_params: TrainerSettings, - is_training: bool, - model_path: str, - load: bool, - tanh_squash: bool = False, - reparameterize: bool = False, - condition_sigma_on_obs: bool = True, - create_tf_graph: bool = True, - ): - """ - Policy that uses a multilayer perceptron to map the observations to actions. Could - also use a CNN to encode visual input prior to the MLP. Supports discrete and - continuous action spaces, as well as recurrent networks. - :param seed: Random seed. - :param brain: Assigned BrainParameters object. - :param trainer_params: Defined training parameters. - :param is_training: Whether the model should be trained. - :param load: Whether a pre-trained model will be loaded or a new one created. - :param model_path: Path where the model should be saved and loaded. - :param tanh_squash: Whether to use a tanh function on the continuous output, or a clipped output. - :param reparameterize: Whether we are using the resampling trick to update the policy in continuous output. - """ - super().__init__(seed, behavior_spec, trainer_params, model_path, load) - self.grads = None - self.update_batch: Optional[tf.Operation] = None - num_layers = self.network_settings.num_layers - self.h_size = self.network_settings.hidden_units - if num_layers < 1: - num_layers = 1 - self.num_layers = num_layers - self.vis_encode_type = self.network_settings.vis_encode_type - self.tanh_squash = tanh_squash - self.reparameterize = reparameterize - self.condition_sigma_on_obs = condition_sigma_on_obs - self.trainable_variables: List[tf.Variable] = [] - - # Non-exposed parameters; these aren't exposed because they don't have a - # good explanation and usually shouldn't be touched. 
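Aside on the "non-exposed parameters" kept through this refactor: bounds such as log_std_min = -20 and log_std_max = 2 are commonly used to clamp a Gaussian policy's log standard deviation so that sigma stays in a numerically safe range. A rough standalone sketch of the effect, outside the patch itself and with illustrative names:

import numpy as np

LOG_STD_MIN, LOG_STD_MAX = -20, 2  # same bounds as self.log_std_min / self.log_std_max below

def clamped_sigma(raw_log_std: np.ndarray) -> np.ndarray:
    # Clipping in log-space keeps sigma between exp(-20) ~ 2e-9 and exp(2) ~ 7.4,
    # so Gaussian log-probs never see sigma == 0 and never blow up.
    return np.exp(np.clip(raw_log_std, LOG_STD_MIN, LOG_STD_MAX))

print(clamped_sigma(np.array([-50.0, 0.0, 10.0])))  # -> [~2.1e-09, 1.0, ~7.39]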
- self.log_std_min = -20 - self.log_std_max = 2 - if create_tf_graph: - self.create_tf_graph() - - def get_trainable_variables(self) -> List[tf.Variable]: - """ - Returns a List of the trainable variables in this policy. if create_tf_graph hasn't been called, - returns empty list. - """ - return self.trainable_variables - - def create_tf_graph(self) -> None: - """ - Builds the tensorflow graph needed for this policy. - """ - with self.graph.as_default(): - tf.set_random_seed(self.seed) - _vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) - if len(_vars) > 0: - # We assume the first thing created in the graph is the Policy. If - # already populated, don't create more tensors. - return - - self.create_input_placeholders() - encoded = self._create_encoder( - self.visual_in, - self.processed_vector_in, - self.h_size, - self.num_layers, - self.vis_encode_type, - ) - if self.use_continuous_act: - self._create_cc_actor( - encoded, - self.tanh_squash, - self.reparameterize, - self.condition_sigma_on_obs, - ) - else: - self._create_dc_actor(encoded) - self.trainable_variables = tf.get_collection( - tf.GraphKeys.TRAINABLE_VARIABLES, scope="policy" - ) - self.trainable_variables += tf.get_collection( - tf.GraphKeys.TRAINABLE_VARIABLES, scope="lstm" - ) # LSTMs need to be root scope for Barracuda export - - self.inference_dict: Dict[str, tf.Tensor] = { - "action": self.output, - "log_probs": self.all_log_probs, - "entropy": self.entropy, - } - if self.use_continuous_act: - self.inference_dict["pre_action"] = self.output_pre - if self.use_recurrent: - self.inference_dict["memory_out"] = self.memory_out - - # We do an initialize to make the Policy usable out of the box. If an optimizer is needed, - # it will re-load the full graph - self._initialize_graph() - - @timed - def evaluate( - self, decision_requests: DecisionSteps, global_agent_ids: List[str] - ) -> Dict[str, Any]: - """ - Evaluates policy for the agent experiences provided. - :param decision_requests: DecisionSteps object containing inputs. - :param global_agent_ids: The global (with worker ID) agent ids of the data in the batched_step_result. - :return: Outputs from network as defined by self.inference_dict. - """ - feed_dict = { - self.batch_size_ph: len(decision_requests), - self.sequence_length_ph: 1, - } - if self.use_recurrent: - if not self.use_continuous_act: - feed_dict[self.prev_action] = self.retrieve_previous_action( - global_agent_ids - ) - feed_dict[self.memory_in] = self.retrieve_memories(global_agent_ids) - feed_dict = self.fill_eval_dict(feed_dict, decision_requests) - run_out = self._execute_model(feed_dict, self.inference_dict) - return run_out - - def _create_encoder( - self, - visual_in: List[tf.Tensor], - vector_in: tf.Tensor, - h_size: int, - num_layers: int, - vis_encode_type: EncoderType, - ) -> tf.Tensor: - """ - Creates an encoder for visual and vector observations. - :param h_size: Size of hidden linear layers. - :param num_layers: Number of hidden linear layers. - :param vis_encode_type: Type of visual encoder to use if visual input. - :return: The hidden layer (tf.Tensor) after the encoder. 
- """ - with tf.variable_scope("policy"): - encoded = ModelUtils.create_observation_streams( - self.visual_in, - self.processed_vector_in, - 1, - h_size, - num_layers, - vis_encode_type, - )[0] - return encoded - - def _create_cc_actor( - self, - encoded: tf.Tensor, - tanh_squash: bool = False, - reparameterize: bool = False, - condition_sigma_on_obs: bool = True, - ) -> None: - """ - Creates Continuous control actor-critic model. - :param h_size: Size of hidden linear layers. - :param num_layers: Number of hidden linear layers. - :param vis_encode_type: Type of visual encoder to use if visual input. - :param tanh_squash: Whether to use a tanh function, or a clipped output. - :param reparameterize: Whether we are using the resampling trick to update the policy. - """ - if self.use_recurrent: - self.memory_in = tf.placeholder( - shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in" - ) - hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder( - encoded, self.memory_in, self.sequence_length_ph, name="lstm_policy" - ) - - self.memory_out = tf.identity(memory_policy_out, name="recurrent_out") - else: - hidden_policy = encoded - - with tf.variable_scope("policy"): - distribution = GaussianDistribution( - hidden_policy, - self.act_size, - reparameterize=reparameterize, - tanh_squash=tanh_squash, - condition_sigma=condition_sigma_on_obs, - ) - - if tanh_squash: - self.output_pre = distribution.sample - self.output = tf.identity(self.output_pre, name="action") - else: - self.output_pre = distribution.sample - # Clip and scale output to ensure actions are always within [-1, 1] range. - output_post = tf.clip_by_value(self.output_pre, -3, 3) / 3 - self.output = tf.identity(output_post, name="action") - - self.selected_actions = tf.stop_gradient(self.output) - - self.all_log_probs = tf.identity(distribution.log_probs, name="action_probs") - self.entropy = distribution.entropy - - # We keep these tensors the same name, but use new nodes to keep code parallelism with discrete control. - self.total_log_probs = distribution.total_log_probs - - def _create_dc_actor(self, encoded: tf.Tensor) -> None: - """ - Creates Discrete control actor-critic model. - :param h_size: Size of hidden linear layers. - :param num_layers: Number of hidden linear layers. - :param vis_encode_type: Type of visual encoder to use if visual input. - """ - if self.use_recurrent: - self.prev_action = tf.placeholder( - shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action" - ) - prev_action_oh = tf.concat( - [ - tf.one_hot(self.prev_action[:, i], self.act_size[i]) - for i in range(len(self.act_size)) - ], - axis=1, - ) - hidden_policy = tf.concat([encoded, prev_action_oh], axis=1) - - self.memory_in = tf.placeholder( - shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in" - ) - hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder( - hidden_policy, - self.memory_in, - self.sequence_length_ph, - name="lstm_policy", - ) - - self.memory_out = tf.identity(memory_policy_out, "recurrent_out") - else: - hidden_policy = encoded - - self.action_masks = tf.placeholder( - shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks" - ) - - with tf.variable_scope("policy"): - distribution = MultiCategoricalDistribution( - hidden_policy, self.act_size, self.action_masks - ) - # It's important that we are able to feed_dict a value into this tensor to get the - # right one-hot encoding, so we can't do identity on it. 
- self.output = distribution.sample - self.all_log_probs = tf.identity(distribution.log_probs, name="action") - self.selected_actions = tf.stop_gradient( - distribution.sample_onehot - ) # In discrete, these are onehot - self.entropy = distribution.entropy - self.total_log_probs = distribution.total_log_probs diff --git a/ml-agents/mlagents/trainers/policy/policy.py b/ml-agents/mlagents/trainers/policy/policy.py index d40aca4b81..e97e6ee7a3 100644 --- a/ml-agents/mlagents/trainers/policy/policy.py +++ b/ml-agents/mlagents/trainers/policy/policy.py @@ -1,12 +1,160 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod +from typing import Dict, List, Optional +import numpy as np from mlagents_envs.base_env import DecisionSteps +from mlagents_envs.exception import UnityException + +from mlagents.model_serialization import SerializationSettings from mlagents.trainers.action_info import ActionInfo +from mlagents_envs.base_env import BehaviorSpec +from mlagents.trainers.settings import TrainerSettings, NetworkSettings -class Policy(ABC): - @abstractmethod +class UnityPolicyException(UnityException): + """ + Related to errors with the Trainer. + """ + + pass + + +class Policy: + def __init__( + self, + seed: int, + behavior_spec: BehaviorSpec, + trainer_settings: TrainerSettings, + model_path: str, + load: bool = False, + tanh_squash: bool = False, + reparameterize: bool = False, + condition_sigma_on_obs: bool = True, + ): + self.behavior_spec = behavior_spec + self.trainer_settings = trainer_settings + self.network_settings: NetworkSettings = trainer_settings.network_settings + self.seed = seed + self.act_size = ( + list(behavior_spec.discrete_action_branches) + if behavior_spec.is_action_discrete() + else [behavior_spec.action_size] + ) + self.vec_obs_size = sum( + shape[0] for shape in behavior_spec.observation_shapes if len(shape) == 1 + ) + self.vis_obs_size = sum( + 1 for shape in behavior_spec.observation_shapes if len(shape) == 3 + ) + self.model_path = model_path + self.initialize_path = self.trainer_settings.init_path + self._keep_checkpoints = self.trainer_settings.keep_checkpoints + self.use_continuous_act = behavior_spec.is_action_continuous() + self.num_branches = self.behavior_spec.action_size + self.previous_action_dict: Dict[str, np.array] = {} + self.memory_dict: Dict[str, np.ndarray] = {} + self.normalize = trainer_settings.network_settings.normalize + self.use_recurrent = self.network_settings.memory is not None + self.load = load + self.h_size = self.network_settings.hidden_units + num_layers = self.network_settings.num_layers + if num_layers < 1: + num_layers = 1 + self.num_layers = num_layers + + self.vis_encode_type = self.network_settings.vis_encode_type + self.tanh_squash = tanh_squash + self.reparameterize = reparameterize + self.condition_sigma_on_obs = condition_sigma_on_obs + + self.m_size = 0 + self.sequence_length = 1 + if self.network_settings.memory is not None: + self.m_size = self.network_settings.memory.memory_size + self.sequence_length = self.network_settings.memory.sequence_length + + # Non-exposed parameters; these aren't exposed because they don't have a + # good explanation and usually shouldn't be touched. + self.log_std_min = -20 + self.log_std_max = 2 + + def make_empty_memory(self, num_agents): + """ + Creates empty memory for use with RNNs + :param num_agents: Number of agents. + :return: Numpy array of zeros. 
+ """ + return np.zeros((num_agents, self.m_size), dtype=np.float32) + + def save_memories( + self, agent_ids: List[str], memory_matrix: Optional[np.ndarray] + ) -> None: + if memory_matrix is None: + return + for index, agent_id in enumerate(agent_ids): + self.memory_dict[agent_id] = memory_matrix[index, :] + + def retrieve_memories(self, agent_ids: List[str]) -> np.ndarray: + memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32) + for index, agent_id in enumerate(agent_ids): + if agent_id in self.memory_dict: + memory_matrix[index, :] = self.memory_dict[agent_id] + return memory_matrix + + def remove_memories(self, agent_ids): + for agent_id in agent_ids: + if agent_id in self.memory_dict: + self.memory_dict.pop(agent_id) + + def make_empty_previous_action(self, num_agents): + """ + Creates empty previous action for use with RNNs and discrete control + :param num_agents: Number of agents. + :return: Numpy array of zeros. + """ + return np.zeros((num_agents, self.num_branches), dtype=np.int) + + def save_previous_action( + self, agent_ids: List[str], action_matrix: Optional[np.ndarray] + ) -> None: + if action_matrix is None: + return + for index, agent_id in enumerate(agent_ids): + self.previous_action_dict[agent_id] = action_matrix[index, :] + + def retrieve_previous_action(self, agent_ids: List[str]) -> np.ndarray: + action_matrix = np.zeros((len(agent_ids), self.num_branches), dtype=np.int) + for index, agent_id in enumerate(agent_ids): + if agent_id in self.previous_action_dict: + action_matrix[index, :] = self.previous_action_dict[agent_id] + return action_matrix + + def remove_previous_action(self, agent_ids): + for agent_id in agent_ids: + if agent_id in self.previous_action_dict: + self.previous_action_dict.pop(agent_id) + def get_action( self, decision_requests: DecisionSteps, worker_id: int = 0 ) -> ActionInfo: + raise NotImplementedError + + @abstractmethod + def update_normalization(self, vector_obs: np.ndarray) -> None: + pass + + @abstractmethod + def increment_step(self, n_steps): + pass + + @abstractmethod + def get_current_step(self): + pass + + @abstractmethod + def checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> None: + pass + + @abstractmethod + def save(self, output_filepath: str, settings: SerializationSettings) -> None: pass diff --git a/ml-agents/mlagents/trainers/policy/tf_policy.py b/ml-agents/mlagents/trainers/policy/tf_policy.py index b23a819b04..261d7bf7c9 100644 --- a/ml-agents/mlagents/trainers/policy/tf_policy.py +++ b/ml-agents/mlagents/trainers/policy/tf_policy.py @@ -1,8 +1,9 @@ from typing import Any, Dict, List, Optional, Tuple -import abc import numpy as np from distutils.version import LooseVersion +from mlagents_envs.timers import timed + from mlagents.model_serialization import SerializationSettings, export_policy_model from mlagents.tf_utils import tf from mlagents import tf_utils @@ -14,9 +15,13 @@ from mlagents.trainers.trajectory import SplitObservations from mlagents.trainers.behavior_id_utils import get_global_agent_id from mlagents_envs.base_env import DecisionSteps -from mlagents.trainers.models import ModelUtils -from mlagents.trainers.settings import TrainerSettings, NetworkSettings +from mlagents.trainers.models import ModelUtils, EncoderType +from mlagents.trainers.settings import TrainerSettings from mlagents.trainers import __version__ +from mlagents.trainers.distributions import ( + GaussianDistribution, + MultiCategoricalDistribution, +) logger = get_logger(__name__) @@ -26,6 +31,8 @@ 
# determines compatibility with inference in Barracuda. MODEL_FORMAT_VERSION = 2 +EPSILON = 1e-6 # Small value to avoid divide by zero + class UnityPolicyException(UnityException): """ @@ -48,6 +55,10 @@ def __init__( trainer_settings: TrainerSettings, model_path: str, load: bool = False, + tanh_squash: bool = False, + reparameterize: bool = False, + condition_sigma_on_obs: bool = True, + create_tf_graph: bool = True, ): """ Initialized the policy. @@ -57,67 +68,116 @@ def __init__( :param model_path: Where to load/save the model. :param load: If True, load model from model_path. Otherwise, create new model. """ - - self.m_size = 0 - self.trainer_settings = trainer_settings - self.network_settings: NetworkSettings = trainer_settings.network_settings + super().__init__( + seed, + behavior_spec, + trainer_settings, + model_path, + load, + tanh_squash, + reparameterize, + condition_sigma_on_obs, + ) # for ghost trainer save/load snapshots self.assign_phs: List[tf.Tensor] = [] self.assign_ops: List[tf.Operation] = [] - - self.inference_dict: Dict[str, tf.Tensor] = {} self.update_dict: Dict[str, tf.Tensor] = {} - self.sequence_length = 1 - self.seed = seed - self.behavior_spec = behavior_spec - - self.act_size = ( - list(behavior_spec.discrete_action_branches) - if behavior_spec.is_action_discrete() - else [behavior_spec.action_size] - ) - self.vec_obs_size = sum( - shape[0] for shape in behavior_spec.observation_shapes if len(shape) == 1 - ) - self.vis_obs_size = sum( - 1 for shape in behavior_spec.observation_shapes if len(shape) == 3 - ) + self.inference_dict: Dict[str, tf.Tensor] = {} - self.use_recurrent = self.network_settings.memory is not None - self.memory_dict: Dict[str, np.ndarray] = {} - self.num_branches = self.behavior_spec.action_size - self.previous_action_dict: Dict[str, np.array] = {} - self.normalize = self.network_settings.normalize - self.use_continuous_act = behavior_spec.is_action_continuous() - self.model_path = model_path - self.initialize_path = self.trainer_settings.init_path - self.keep_checkpoints = self.trainer_settings.keep_checkpoints self.graph = tf.Graph() self.sess = tf.Session( config=tf_utils.generate_session_config(), graph=self.graph ) self.saver: Optional[tf.Operation] = None - self.seed = seed - if self.network_settings.memory is not None: - self.m_size = self.network_settings.memory.memory_size - self.sequence_length = self.network_settings.memory.sequence_length self._initialize_tensorflow_references() - self.load = load + self.grads = None + self.update_batch: Optional[tf.Operation] = None + self.trainable_variables: List[tf.Variable] = [] + if create_tf_graph: + self.create_tf_graph() - @abc.abstractmethod def get_trainable_variables(self) -> List[tf.Variable]: """ Returns a List of the trainable variables in this policy. if create_tf_graph hasn't been called, returns empty list. """ - pass + return self.trainable_variables - @abc.abstractmethod - def create_tf_graph(self): + def create_tf_graph(self) -> None: """ Builds the tensorflow graph needed for this policy. """ - pass + with self.graph.as_default(): + tf.set_random_seed(self.seed) + _vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) + if len(_vars) > 0: + # We assume the first thing created in the graph is the Policy. If + # already populated, don't create more tensors. 
+ return + + self.create_input_placeholders() + encoded = self._create_encoder( + self.visual_in, + self.processed_vector_in, + self.h_size, + self.num_layers, + self.vis_encode_type, + ) + if self.use_continuous_act: + self._create_cc_actor( + encoded, + self.tanh_squash, + self.reparameterize, + self.condition_sigma_on_obs, + ) + else: + self._create_dc_actor(encoded) + self.trainable_variables = tf.get_collection( + tf.GraphKeys.TRAINABLE_VARIABLES, scope="policy" + ) + self.trainable_variables += tf.get_collection( + tf.GraphKeys.TRAINABLE_VARIABLES, scope="lstm" + ) # LSTMs need to be root scope for Barracuda export + + self.inference_dict = { + "action": self.output, + "log_probs": self.all_log_probs, + "entropy": self.entropy, + } + if self.use_continuous_act: + self.inference_dict["pre_action"] = self.output_pre + if self.use_recurrent: + self.inference_dict["memory_out"] = self.memory_out + + # We do an initialize to make the Policy usable out of the box. If an optimizer is needed, + # it will re-load the full graph + self._initialize_graph() + + def _create_encoder( + self, + visual_in: List[tf.Tensor], + vector_in: tf.Tensor, + h_size: int, + num_layers: int, + vis_encode_type: EncoderType, + ) -> tf.Tensor: + """ + Creates an encoder for visual and vector observations. + :param h_size: Size of hidden linear layers. + :param num_layers: Number of hidden linear layers. + :param vis_encode_type: Type of visual encoder to use if visual input. + :return: The hidden layer (tf.Tensor) after the encoder. + """ + with tf.variable_scope("policy"): + encoded = ModelUtils.create_observation_streams( + self.visual_in, + self.processed_vector_in, + 1, + h_size, + num_layers, + vis_encode_type, + )[0] + return encoded @staticmethod def _convert_version_string(version_string: str) -> Tuple[int, ...]: @@ -147,13 +207,13 @@ def _check_model_version(self, version: str) -> None: def _initialize_graph(self): with self.graph.as_default(): - self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints) + self.saver = tf.train.Saver(max_to_keep=self._keep_checkpoints) init = tf.global_variables_initializer() self.sess.run(init) def _load_graph(self, model_path: str, reset_global_steps: bool = False) -> None: with self.graph.as_default(): - self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints) + self.saver = tf.train.Saver(max_to_keep=self._keep_checkpoints) logger.info(f"Loading model from {model_path}.") ckpt = tf.train.get_checkpoint_state(model_path) if ckpt is None: @@ -222,15 +282,29 @@ def load_weights(self, values): feed_dict[assign_ph] = value self.sess.run(self.assign_ops, feed_dict=feed_dict) + @timed def evaluate( self, decision_requests: DecisionSteps, global_agent_ids: List[str] ) -> Dict[str, Any]: """ Evaluates policy for the agent experiences provided. - :param decision_requests: DecisionSteps input to network. - :return: Output from policy based on self.inference_dict. + :param decision_requests: DecisionSteps object containing inputs. + :param global_agent_ids: The global (with worker ID) agent ids of the data in the batched_step_result. + :return: Outputs from network as defined by self.inference_dict. 
""" - raise UnityPolicyException("The evaluate function was not implemented.") + feed_dict = { + self.batch_size_ph: len(decision_requests), + self.sequence_length_ph: 1, + } + if self.use_recurrent: + if not self.use_continuous_act: + feed_dict[self.prev_action] = self.retrieve_previous_action( + global_agent_ids + ) + feed_dict[self.memory_in] = self.retrieve_memories(global_agent_ids) + feed_dict = self.fill_eval_dict(feed_dict, decision_requests) + run_out = self._execute_model(feed_dict, self.inference_dict) + return run_out def get_action( self, decision_requests: DecisionSteps, worker_id: int = 0 @@ -556,3 +630,108 @@ def create_input_placeholders(self): trainable=False, dtype=tf.int32, ) + + def _create_cc_actor( + self, + encoded: tf.Tensor, + tanh_squash: bool = False, + reparameterize: bool = False, + condition_sigma_on_obs: bool = True, + ) -> None: + """ + Creates Continuous control actor-critic model. + :param h_size: Size of hidden linear layers. + :param num_layers: Number of hidden linear layers. + :param vis_encode_type: Type of visual encoder to use if visual input. + :param tanh_squash: Whether to use a tanh function, or a clipped output. + :param reparameterize: Whether we are using the resampling trick to update the policy. + """ + if self.use_recurrent: + self.memory_in = tf.placeholder( + shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in" + ) + hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder( + encoded, self.memory_in, self.sequence_length_ph, name="lstm_policy" + ) + + self.memory_out = tf.identity(memory_policy_out, name="recurrent_out") + else: + hidden_policy = encoded + + with tf.variable_scope("policy"): + distribution = GaussianDistribution( + hidden_policy, + self.act_size, + reparameterize=reparameterize, + tanh_squash=tanh_squash, + condition_sigma=condition_sigma_on_obs, + ) + + if tanh_squash: + self.output_pre = distribution.sample + self.output = tf.identity(self.output_pre, name="action") + else: + self.output_pre = distribution.sample + # Clip and scale output to ensure actions are always within [-1, 1] range. + output_post = tf.clip_by_value(self.output_pre, -3, 3) / 3 + self.output = tf.identity(output_post, name="action") + + self.selected_actions = tf.stop_gradient(self.output) + + self.all_log_probs = tf.identity(distribution.log_probs, name="action_probs") + self.entropy = distribution.entropy + + # We keep these tensors the same name, but use new nodes to keep code parallelism with discrete control. + self.total_log_probs = distribution.total_log_probs + + def _create_dc_actor(self, encoded: tf.Tensor) -> None: + """ + Creates Discrete control actor-critic model. + :param h_size: Size of hidden linear layers. + :param num_layers: Number of hidden linear layers. + :param vis_encode_type: Type of visual encoder to use if visual input. 
+ """ + if self.use_recurrent: + self.prev_action = tf.placeholder( + shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action" + ) + prev_action_oh = tf.concat( + [ + tf.one_hot(self.prev_action[:, i], self.act_size[i]) + for i in range(len(self.act_size)) + ], + axis=1, + ) + hidden_policy = tf.concat([encoded, prev_action_oh], axis=1) + + self.memory_in = tf.placeholder( + shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in" + ) + hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder( + hidden_policy, + self.memory_in, + self.sequence_length_ph, + name="lstm_policy", + ) + + self.memory_out = tf.identity(memory_policy_out, "recurrent_out") + else: + hidden_policy = encoded + + self.action_masks = tf.placeholder( + shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks" + ) + + with tf.variable_scope("policy"): + distribution = MultiCategoricalDistribution( + hidden_policy, self.act_size, self.action_masks + ) + # It's important that we are able to feed_dict a value into this tensor to get the + # right one-hot encoding, so we can't do identity on it. + self.output = distribution.sample + self.all_log_probs = tf.identity(distribution.log_probs, name="action") + self.selected_actions = tf.stop_gradient( + distribution.sample_onehot + ) # In discrete, these are onehot + self.entropy = distribution.entropy + self.total_log_probs = distribution.total_log_probs diff --git a/ml-agents/mlagents/trainers/ppo/trainer.py b/ml-agents/mlagents/trainers/ppo/trainer.py index f6b7af3a27..aa6927fd1a 100644 --- a/ml-agents/mlagents/trainers/ppo/trainer.py +++ b/ml-agents/mlagents/trainers/ppo/trainer.py @@ -9,7 +9,6 @@ from mlagents_envs.logging_util import get_logger from mlagents_envs.base_env import BehaviorSpec -from mlagents.trainers.policy.nn_policy import NNPolicy from mlagents.trainers.trainer.rl_trainer import RLTrainer from mlagents.trainers.policy.tf_policy import TFPolicy from mlagents.trainers.ppo.optimizer import PPOOptimizer @@ -52,7 +51,7 @@ def __init__( ) self.load = load self.seed = seed - self.policy: NNPolicy = None # type: ignore + self.policy: TFPolicy = None # type: ignore def _process_trajectory(self, trajectory: Trajectory) -> None: """ @@ -196,13 +195,12 @@ def create_policy( :param behavior_spec: specifications for policy construction :return policy """ - policy = NNPolicy( + policy = TFPolicy( self.seed, behavior_spec, self.trainer_settings, - self.is_training, - self.artifact_path, - self.load, + model_path=self.artifact_path, + load=self.load, condition_sigma_on_obs=False, # Faster training for PPO create_tf_graph=False, # We will create the TF graph in the Optimizer ) @@ -224,8 +222,6 @@ def add_policy( self.__class__.__name__ ) ) - if not isinstance(policy, NNPolicy): - raise RuntimeError("Non-NNPolicy passed to PPOTrainer.add_policy()") self.policy = policy self.policies[parsed_behavior_id.behavior_id] = policy self.optimizer = PPOOptimizer(self.policy, self.trainer_settings) diff --git a/ml-agents/mlagents/trainers/sac/optimizer.py b/ml-agents/mlagents/trainers/sac/optimizer.py index 2d8b94067a..b629dc153f 100644 --- a/ml-agents/mlagents/trainers/sac/optimizer.py +++ b/ml-agents/mlagents/trainers/sac/optimizer.py @@ -66,8 +66,8 @@ def __init__(self, policy: TFPolicy, trainer_params: TrainerSettings): # Non-exposed SAC parameters self.discrete_target_entropy_scale = ( - 0.2 - ) # Roughly equal to e-greedy 0.05 + 0.2 # Roughly equal to e-greedy 0.05 + ) self.continuous_target_entropy_scale = 1.0 stream_names = 
list(self.reward_signals.keys()) diff --git a/ml-agents/mlagents/trainers/sac/trainer.py b/ml-agents/mlagents/trainers/sac/trainer.py index e8cc361850..0eac648617 100644 --- a/ml-agents/mlagents/trainers/sac/trainer.py +++ b/ml-agents/mlagents/trainers/sac/trainer.py @@ -13,7 +13,6 @@ from mlagents_envs.timers import timed from mlagents_envs.base_env import BehaviorSpec from mlagents.trainers.policy.tf_policy import TFPolicy -from mlagents.trainers.policy.nn_policy import NNPolicy from mlagents.trainers.sac.optimizer import SACOptimizer from mlagents.trainers.trainer.rl_trainer import RLTrainer from mlagents.trainers.trajectory import Trajectory, SplitObservations @@ -58,7 +57,7 @@ def __init__( self.load = load self.seed = seed - self.policy: NNPolicy = None # type: ignore + self.policy: TFPolicy = None # type: ignore self.optimizer: SACOptimizer = None # type: ignore self.hyperparameters: SACSettings = cast( SACSettings, trainer_settings.hyperparameters @@ -197,11 +196,10 @@ def _update_policy(self) -> bool: def create_policy( self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec ) -> TFPolicy: - policy = NNPolicy( + policy = TFPolicy( self.seed, behavior_spec, self.trainer_settings, - self.is_training, self.artifact_path, self.load, tanh_squash=True, @@ -326,8 +324,6 @@ def add_policy( self.__class__.__name__ ) ) - if not isinstance(policy, NNPolicy): - raise RuntimeError("Non-SACPolicy passed to SACTrainer.add_policy()") self.policy = policy self.policies[parsed_behavior_id.behavior_id] = policy self.optimizer = SACOptimizer(self.policy, self.trainer_settings) diff --git a/ml-agents/mlagents/trainers/tests/test_bcmodule.py b/ml-agents/mlagents/trainers/tests/test_bcmodule.py index 67b4595591..ca4bb1c0b4 100644 --- a/ml-agents/mlagents/trainers/tests/test_bcmodule.py +++ b/ml-agents/mlagents/trainers/tests/test_bcmodule.py @@ -4,7 +4,7 @@ import numpy as np import os -from mlagents.trainers.policy.nn_policy import NNPolicy +from mlagents.trainers.policy.tf_policy import TFPolicy from mlagents.trainers.components.bc.module import BCModule from mlagents.trainers.settings import ( TrainerSettings, @@ -19,11 +19,10 @@ def create_bc_module(mock_behavior_specs, bc_settings, use_rnn, tanhresample): trainer_config.network_settings.memory = ( NetworkSettings.MemorySettings() if use_rnn else None ) - policy = NNPolicy( + policy = TFPolicy( 0, mock_behavior_specs, trainer_config, - False, "test", False, tanhresample, @@ -89,7 +88,7 @@ def test_bcmodule_constant_lr_update(is_sac): assert isinstance(item, np.float32) old_learning_rate = bc_module.current_lr - stats = bc_module.update() + _ = bc_module.update() assert old_learning_rate == bc_module.current_lr diff --git a/ml-agents/mlagents/trainers/tests/test_nn_policy.py b/ml-agents/mlagents/trainers/tests/test_nn_policy.py index 667c5976af..a09d1a3b23 100644 --- a/ml-agents/mlagents/trainers/tests/test_nn_policy.py +++ b/ml-agents/mlagents/trainers/tests/test_nn_policy.py @@ -8,7 +8,7 @@ from mlagents.tf_utils import tf -from mlagents.trainers.policy.nn_policy import NNPolicy +from mlagents.trainers.policy.tf_policy import TFPolicy from mlagents.trainers.models import EncoderType, ModelUtils, Tensor3DShape from mlagents.trainers.exception import UnityTrainerException from mlagents.trainers.tests import mock_brain as mb @@ -32,7 +32,7 @@ def create_policy_mock( model_path: str = "", load: bool = False, seed: int = 0, -) -> NNPolicy: +) -> TFPolicy: mock_spec = mb.setup_test_behavior_specs( use_discrete, use_visual, @@ -47,7 
+47,9 @@ def create_policy_mock( trainer_settings.network_settings.memory = ( NetworkSettings.MemorySettings() if use_rnn else None ) - policy = NNPolicy(seed, mock_spec, trainer_settings, False, model_path, load) + policy = TFPolicy( + seed, mock_spec, trainer_settings, model_path=model_path, load=load + ) return policy @@ -101,7 +103,7 @@ def test_version_compare(self): assert len(cm.output) == 1 -def _compare_two_policies(policy1: NNPolicy, policy2: NNPolicy) -> None: +def _compare_two_policies(policy1: TFPolicy, policy2: TFPolicy) -> None: """ Make sure two policies have the same output for the same input. """ @@ -149,11 +151,10 @@ def test_normalization(): # Change half of the obs to 0 for i in range(3): trajectory.steps[i].obs[0] = np.zeros(1, dtype=np.float32) - policy = NNPolicy( + policy = TFPolicy( 0, behavior_spec, TrainerSettings(network_settings=NetworkSettings(normalize=True)), - False, "testdir", False, ) diff --git a/ml-agents/mlagents/trainers/tests/test_ppo.py b/ml-agents/mlagents/trainers/tests/test_ppo.py index ba5b206abf..0c79758fd3 100644 --- a/ml-agents/mlagents/trainers/tests/test_ppo.py +++ b/ml-agents/mlagents/trainers/tests/test_ppo.py @@ -9,7 +9,7 @@ from mlagents.trainers.ppo.trainer import PPOTrainer, discount_rewards from mlagents.trainers.ppo.optimizer import PPOOptimizer -from mlagents.trainers.policy.nn_policy import NNPolicy +from mlagents.trainers.policy.tf_policy import TFPolicy from mlagents.trainers.agent_processor import AgentManagerQueue from mlagents.trainers.tests import mock_brain as mb from mlagents.trainers.tests.test_trajectory import make_fake_trajectory @@ -49,8 +49,8 @@ def _create_ppo_optimizer_ops_mock(dummy_config, use_rnn, use_discrete, use_visu if use_rnn else None ) - policy = NNPolicy( - 0, mock_specs, trainer_settings, False, "test", False, create_tf_graph=False + policy = TFPolicy( + 0, mock_specs, trainer_settings, "test", False, create_tf_graph=False ) optimizer = PPOOptimizer(policy, trainer_settings) return optimizer @@ -203,7 +203,7 @@ def test_trainer_increment_step(ppo_optimizer): ppo_optimizer.return_value = mock_optimizer trainer = PPOTrainer("test_brain", 0, trainer_params, True, False, 0, "0") - policy_mock = mock.Mock(spec=NNPolicy) + policy_mock = mock.Mock(spec=TFPolicy) policy_mock.get_current_step.return_value = 0 step_count = ( 5 # 10 hacked because this function is no longer called through trainer @@ -319,7 +319,7 @@ def test_add_get_policy(ppo_optimizer, dummy_config): ppo_optimizer.return_value = mock_optimizer trainer = PPOTrainer("test_policy", 0, dummy_config, True, False, 0, "0") - policy = mock.Mock(spec=NNPolicy) + policy = mock.Mock(spec=TFPolicy) policy.get_current_step.return_value = 2000 behavior_id = BehaviorIdentifiers.from_name_behavior_id(trainer.brain_name) @@ -329,11 +329,6 @@ def test_add_get_policy(ppo_optimizer, dummy_config): # Make sure the summary steps were loaded properly assert trainer.get_step == 2000 - # Test incorrect class of policy - policy = mock.Mock() - with pytest.raises(RuntimeError): - trainer.add_policy(behavior_id, policy) - if __name__ == "__main__": pytest.main() diff --git a/ml-agents/mlagents/trainers/tests/test_reward_signals.py b/ml-agents/mlagents/trainers/tests/test_reward_signals.py index 9023bf86c0..d13cb2674b 100644 --- a/ml-agents/mlagents/trainers/tests/test_reward_signals.py +++ b/ml-agents/mlagents/trainers/tests/test_reward_signals.py @@ -2,7 +2,7 @@ import copy import os import mlagents.trainers.tests.mock_brain as mb -from mlagents.trainers.policy.nn_policy 
import NNPolicy +from mlagents.trainers.policy.tf_policy import TFPolicy from mlagents.trainers.sac.optimizer import SACOptimizer from mlagents.trainers.ppo.optimizer import PPOOptimizer from mlagents.trainers.tests.test_simple_rl import PPO_CONFIG, SAC_CONFIG @@ -69,8 +69,8 @@ def create_optimizer_mock( if use_rnn else None ) - policy = NNPolicy( - 0, mock_specs, trainer_settings, False, "test", False, create_tf_graph=False + policy = TFPolicy( + 0, mock_specs, trainer_settings, "test", False, create_tf_graph=False ) if trainer_settings.trainer_type == TrainerType.SAC: optimizer = SACOptimizer(policy, trainer_settings) diff --git a/ml-agents/mlagents/trainers/tests/test_sac.py b/ml-agents/mlagents/trainers/tests/test_sac.py index 023110d8b0..3ac4b91826 100644 --- a/ml-agents/mlagents/trainers/tests/test_sac.py +++ b/ml-agents/mlagents/trainers/tests/test_sac.py @@ -7,7 +7,7 @@ from mlagents.trainers.sac.trainer import SACTrainer from mlagents.trainers.sac.optimizer import SACOptimizer -from mlagents.trainers.policy.nn_policy import NNPolicy +from mlagents.trainers.policy.tf_policy import TFPolicy from mlagents.trainers.agent_processor import AgentManagerQueue from mlagents.trainers.tests import mock_brain as mb from mlagents.trainers.tests.mock_brain import setup_test_behavior_specs @@ -46,8 +46,8 @@ def create_sac_optimizer_mock(dummy_config, use_rnn, use_discrete, use_visual): if use_rnn else None ) - policy = NNPolicy( - 0, mock_brain, trainer_settings, False, "test", False, create_tf_graph=False + policy = TFPolicy( + 0, mock_brain, trainer_settings, "test", False, create_tf_graph=False ) optimizer = SACOptimizer(policy, trainer_settings) return optimizer @@ -134,7 +134,7 @@ def test_add_get_policy(sac_optimizer, dummy_config): sac_optimizer.return_value = mock_optimizer trainer = SACTrainer("test", 0, dummy_config, True, False, 0, "0") - policy = mock.Mock(spec=NNPolicy) + policy = mock.Mock(spec=TFPolicy) policy.get_current_step.return_value = 2000 behavior_id = BehaviorIdentifiers.from_name_behavior_id(trainer.brain_name) trainer.add_policy(behavior_id, policy) @@ -143,11 +143,6 @@ def test_add_get_policy(sac_optimizer, dummy_config): # Make sure the summary steps were loaded properly assert trainer.get_step == 2000 - # Test incorrect class of policy - policy = mock.Mock() - with pytest.raises(RuntimeError): - trainer.add_policy(behavior_id, policy) - def test_advance(dummy_config): specs = setup_test_behavior_specs(