From 09d9c9ce9f6ba293f46c453a115c5371cb405ca5 Mon Sep 17 00:00:00 2001 From: Christopher Goy Date: Thu, 6 Aug 2020 12:28:40 -0700 Subject: [PATCH 1/9] First pass of actuator changes without changing the environments. --- com.unity.ml-agents/Runtime/Actuators.meta | 11 +- .../Runtime/Actuators/ActionSegment.cs | 2 +- .../Runtime/Actuators/ActionSpec.cs | 2 +- .../Runtime/Actuators/IActionReceiver.cs | 4 +- .../Runtime/Actuators/IActuator.cs | 2 +- .../Runtime/Actuators/IDiscreteActionMask.cs | 2 +- .../Runtime/Actuators/VectorActuator.cs | 2 +- com.unity.ml-agents/Runtime/Agent.cs | 237 ++++++++++++------ .../Runtime/Agent.deprecated.cs | 38 +++ .../Runtime/Agent.deprecated.cs.meta | 3 + .../Runtime/Communicator/GrpcExtensions.cs | 14 +- .../Runtime/Communicator/RpcCommunicator.cs | 2 +- .../Runtime/DecisionRequester.cs | 6 +- .../Runtime/DiscreteActionMasker.cs | 118 +-------- .../Runtime/Policies/BarracudaPolicy.cs | 12 +- .../Runtime/Policies/BehaviorParameters.cs | 37 ++- .../Runtime/Policies/HeuristicPolicy.cs | 19 +- .../Runtime/Policies/IPolicy.cs | 2 +- .../Runtime/Policies/RemotePolicy.cs | 11 +- .../Tests/Editor/BehaviorParameterTests.cs | 2 +- .../Tests/Editor/EditModeTestActionMasker.cs | 142 ----------- .../Editor/EditModeTestActionMasker.cs.meta | 11 - .../Tests/Editor/MLAgentsEditModeTest.cs | 2 +- 23 files changed, 303 insertions(+), 378 deletions(-) create mode 100644 com.unity.ml-agents/Runtime/Agent.deprecated.cs create mode 100644 com.unity.ml-agents/Runtime/Agent.deprecated.cs.meta delete mode 100644 com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs delete mode 100644 com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs.meta diff --git a/com.unity.ml-agents/Runtime/Actuators.meta b/com.unity.ml-agents/Runtime/Actuators.meta index 96bbfb99b3..588dd29561 100644 --- a/com.unity.ml-agents/Runtime/Actuators.meta +++ b/com.unity.ml-agents/Runtime/Actuators.meta @@ -1,8 +1,3 @@ -fileFormatVersion: 2 -guid: 26733e59183b6479e8f0e892a8bf09a4 -folderAsset: yes -DefaultImporter: - externalObjects: {} - userData: - assetBundleName: - assetBundleVariant: +fileFormatVersion: 2 +guid: 528e32f655ff4952b857b698e3efdd8e +timeCreated: 1592848300 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs b/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs index feb06a708d..9e18bd95e9 100644 --- a/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs +++ b/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs @@ -11,7 +11,7 @@ namespace Unity.MLAgents.Actuators /// the offset into the original array, and an length. /// /// The type of object stored in the underlying - internal readonly struct ActionSegment : IEnumerable, IEquatable> + public readonly struct ActionSegment : IEnumerable, IEquatable> where T : struct { /// diff --git a/com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs b/com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs index fbee0c4476..5b9175680f 100644 --- a/com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs +++ b/com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs @@ -8,7 +8,7 @@ namespace Unity.MLAgents.Actuators /// /// Defines the structure of an Action Space to be used by the Actuator system. 
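A usage sketch (an editorial aside, not part of the diff): ActionSegment<T> becomes public in this patch. This uses only the (array, offset, length) constructor and read indexer that appear in this series:

    using Unity.MLAgents.Actuators;

    // A segment is a zero-copy view over part of a larger action array;
    // indexing is relative to the segment's offset.
    var backing = new float[] { 0f, 1f, 2f, 3f, 4f, 5f };
    var segment = new ActionSegment<float>(backing, 2, 3); // offset 2, length 3
    float first = segment[0]; // reads backing[2], i.e. 2f
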
/// - internal readonly struct ActionSpec + public readonly struct ActionSpec { /// diff --git a/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs b/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs index 4e2a251f10..9ff7623ecc 100644 --- a/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs +++ b/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs @@ -7,7 +7,7 @@ namespace Unity.MLAgents.Actuators /// A structure that wraps the s for a particular and is /// used when is called. /// - internal readonly struct ActionBuffers + public readonly struct ActionBuffers { /// /// An empty action buffer. @@ -62,7 +62,7 @@ public override int GetHashCode() /// /// An interface that describes an object that can receive actions from a Reinforcement Learning network. /// - internal interface IActionReceiver + public interface IActionReceiver { /// diff --git a/com.unity.ml-agents/Runtime/Actuators/IActuator.cs b/com.unity.ml-agents/Runtime/Actuators/IActuator.cs index eedb940a36..5988f0aab3 100644 --- a/com.unity.ml-agents/Runtime/Actuators/IActuator.cs +++ b/com.unity.ml-agents/Runtime/Actuators/IActuator.cs @@ -6,7 +6,7 @@ namespace Unity.MLAgents.Actuators /// /// Abstraction that facilitates the execution of actions. /// - internal interface IActuator : IActionReceiver + public interface IActuator : IActionReceiver { int TotalNumberOfActions { get; } diff --git a/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs b/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs index 7cb0e99f72..30f8425792 100644 --- a/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs +++ b/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs @@ -5,7 +5,7 @@ namespace Unity.MLAgents.Actuators /// /// Interface for writing a mask to disable discrete actions for agents for the next decision. /// - internal interface IDiscreteActionMask + public interface IDiscreteActionMask { /// /// Modifies an action mask for discrete control agents. diff --git a/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs b/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs index e2635c4164..7e8d6d7642 100644 --- a/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs +++ b/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs @@ -4,7 +4,7 @@ namespace Unity.MLAgents.Actuators { - internal class VectorActuator : IActuator + public class VectorActuator : IActuator { IActionReceiver m_ActionReceiver; diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index cad4575f0b..443c2895d2 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -1,8 +1,10 @@ using System; using System.Collections.Generic; using System.Collections.ObjectModel; +using System.Linq; using UnityEngine; using Unity.Barracuda; +using Unity.MLAgents.Actuators; using Unity.MLAgents.Sensors; using Unity.MLAgents.Sensors.Reflection; using Unity.MLAgents.Demonstrations; @@ -48,15 +50,30 @@ internal struct AgentInfo /// to separate between different agents in the environment. /// public int episodeId; - } - /// - /// Struct that contains the action information sent from the Brain to the - /// Agent. 
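For illustration (outside the diff): ActionBuffers, made public here as well, pairs the continuous and discrete actions for one decision and replaces the single float[] payload of the old AgentAction struct. A sketch with the segment-based constructor from this file:

    using Unity.MLAgents.Actuators;

    var buffers = new ActionBuffers(
        new ActionSegment<float>(new[] { 0.5f, -1.0f }, 0, 2), // continuous actions
        new ActionSegment<int>(new[] { 2 }, 0, 1));            // one discrete branch
    float move = buffers.ContinuousActions[0]; // 0.5f
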
- /// - internal struct AgentAction - { - public float[] vectorActions; + public void ClearActions() + { + Array.Clear(storedVectorActions, 0, storedVectorActions.Length); + } + + public void CopyActions(float[] continuousActions, int[] discreteActions) + { + var start = 0; + if (continuousActions != null) + { + Array.Copy(continuousActions, 0, storedVectorActions, start, continuousActions.Length); + start = continuousActions.Length; + } + if (start >= storedVectorActions.Length) + { + return; + } + + if (continuousActions != null) + { + Array.Copy(discreteActions, 0, storedVectorActions, continuousActions.Length, discreteActions.Length); + } + } } /// @@ -106,7 +123,7 @@ internal struct AgentAction /// can only take an action when it touches the ground, so several frames might elapse between /// one decision and the need for the next. /// - /// Use the function to implement the actions your agent can take, + /// Use the function to implement the actions your agent can take, /// such as moving to reach a goal or interacting with its environment. /// /// When you call on an agent or the agent reaches its count, @@ -155,7 +172,7 @@ internal struct AgentAction "docs/Learning-Environment-Design-Agents.md")] [Serializable] [RequireComponent(typeof(BehaviorParameters))] - public class Agent : MonoBehaviour, ISerializationCallbackReceiver + public partial class Agent : MonoBehaviour, ISerializationCallbackReceiver, IActionReceiver { IPolicy m_Brain; BehaviorParameters m_PolicyFactory; @@ -222,9 +239,6 @@ internal struct AgentParameters /// Current Agent information (message sent to Brain). AgentInfo m_Info; - /// Current Agent action (message sent from Brain). - AgentAction m_Action; - /// Represents the reward the agent accumulated during the current step. /// It is reset to 0 at the beginning of every step. /// Should be set to a positive value when the agent performs a "good" @@ -281,6 +295,24 @@ internal struct AgentParameters /// internal VectorSensor collectObservationsSensor; + /// + /// List of IActuators that this Agent will delegate actions to if any exist. + /// + ActuatorManager m_Actuators; + + /// + /// DiscreteVectorActuator which is used by default if no other sensors exist on this Agent. This VectorSensor will + /// delegate its actions to by default in order to keep backward compatibility + /// with the current behavior of Agent. + /// + IActuator m_VectorActuator; + + /// + /// This is used to avoid allocation of a float array every frame if users are still using the old + /// OnActionReceived method. + /// + float[] m_LegacyActionCache; + /// /// Called when the attached [GameObject] becomes enabled and active. /// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html @@ -385,7 +417,6 @@ public void LazyInitialize() m_PolicyFactory = GetComponent(); m_Info = new AgentInfo(); - m_Action = new AgentAction(); sensors = new List(); Academy.Instance.AgentIncrementStep += AgentIncrementStep; @@ -402,6 +433,13 @@ public void LazyInitialize() InitializeSensors(); } + using (TimerStack.Instance.Scoped("InitializeActuators")) + { + InitializeActuators(); + } + + m_Info.storedVectorActions = new float[m_Actuators.TotalNumberOfActions]; + // The first time the Academy resets, all Agents in the scene will be // forced to reset through the event. 
// To avoid the Agent resetting twice, the Agents will not begin their @@ -624,7 +662,7 @@ public void SetReward(float reward) /// set the reward assigned to the current step with a specific value rather than /// increasing or decreasing it. /// - /// Typically, you assign rewards in the Agent subclass's + /// Typically, you assign rewards in the Agent subclass's /// implementation after carrying out the received action and evaluating its success. /// /// Rewards are used during reinforcement learning; they are ignored during inference. @@ -701,7 +739,7 @@ public void RequestDecision() /// /// Call `RequestAction()` to repeat the previous action returned by the agent's /// most recent decision. A new decision is not requested. When you call this function, - /// the Agent instance invokes with the + /// the Agent instance invokes with the /// existing action vector. /// /// You can use `RequestAction()` in situations where an agent must take an action @@ -728,16 +766,7 @@ public void RequestAction() /// at the end of an episode. void ResetData() { - var param = m_PolicyFactory.BrainParameters; - m_ActionMasker = new DiscreteActionMasker(param); - // If we haven't initialized vectorActions, initialize to 0. This should only - // happen during the creation of the Agent. In subsequent episodes, vectorAction - // should stay the previous action before the Done(), so that it is properly recorded. - if (m_Action.vectorActions == null) - { - m_Action.vectorActions = new float[param.NumActions]; - m_Info.storedVectorActions = new float[param.NumActions]; - } + m_Actuators?.ResetData(); } /// @@ -765,11 +794,11 @@ public virtual void Initialize() {} /// control of an agent using keyboard, mouse, or game controller input. /// /// Your heuristic implementation can use any decision making logic you specify. Assign decision - /// values to the float[] array, , passed to your function as a parameter. + /// values to the float[] array, , passed to your function as a parameter. /// The same array will be reused between steps. It is up to the user to initialize /// the values on each call, for example by calling `Array.Clear(actionsOut, 0, actionsOut.Length);`. /// Add values to the array at the same indexes as they are used in your - /// function, which receives this array and + /// function, which receives this array and /// implements the corresponding agent behavior. See [Actions] for more information /// about agent actions. /// Note : Do not create a new float array of action in the `Heuristic()` method, @@ -801,22 +830,41 @@ public virtual void Initialize() {} /// You can also use the [Input System package], which provides a more flexible and /// configurable input system. /// - /// public override void Heuristic(float[] actionsOut) + /// public override void Heuristic(float[] continuousActionsOut, int[] discreteActionsOut) /// { - /// actionsOut[0] = Input.GetAxis("Horizontal"); - /// actionsOut[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f; - /// actionsOut[2] = Input.GetAxis("Vertical"); + /// continuousActions[0] = Input.GetAxis("Horizontal"); + /// continuousActions[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f; + /// continuousActions[2] = Input.GetAxis("Vertical"); /// } /// /// [Input Manager]: https://docs.unity3d.com/Manual/class-InputManager.html /// [Input System package]: https://docs.unity3d.com/Packages/com.unity.inputsystem@1.0/manual/index.html /// - /// Array for the output actions. 
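A sketch (not part of the change) of a heuristic override under the split signature introduced just below, written inside an Agent subclass; the Input calls are Unity's standard input APIs, and the values are written into the *Out arrays supplied by the policy:

    public override void Heuristic(float[] continuousActionsOut, int[] discreteActionsOut)
    {
        continuousActionsOut[0] = Input.GetAxis("Horizontal");
        continuousActionsOut[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
        continuousActionsOut[2] = Input.GetAxis("Vertical");
        // Discrete branches would be written to discreteActionsOut the same way.
    }
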
- /// - public virtual void Heuristic(float[] actionsOut) + /// Array to write the continuous actions to. + /// Array to write the discreteActions to. + /// + public virtual void Heuristic(float[] continuousActionsOut, int[] discreteActionsOut) { Debug.LogWarning("Heuristic method called but not implemented. Returning placeholder actions."); - Array.Clear(actionsOut, 0, actionsOut.Length); + // For backward compatibility + switch (m_PolicyFactory.BrainParameters.VectorActionSpaceType) + { + case SpaceType.Continuous: + #pragma warning disable CS0618 + Heuristic(continuousActionsOut); + #pragma warning restore CS0618 + Array.Clear(discreteActionsOut, 0, discreteActionsOut.Length); + break; + case SpaceType.Discrete: + var convertedOut = Array.ConvertAll(discreteActionsOut, x => (float)x); + #pragma warning disable CS0618 + Heuristic(convertedOut); + #pragma warning restore CS0618 + var convertedBackToInt = Array.ConvertAll(convertedOut, x => (int)x); + Array.Copy(convertedBackToInt, 0, discreteActionsOut, 0, discreteActionsOut.Length); + Array.Clear(continuousActionsOut, 0, continuousActionsOut.Length); + break; + } } /// @@ -875,6 +923,7 @@ internal void InitializeSensors() #if DEBUG // Make sure the names are actually unique + for (var i = 0; i < sensors.Count - 1; i++) { Debug.Assert( @@ -884,6 +933,32 @@ internal void InitializeSensors() #endif } + void InitializeActuators() + { + ActuatorComponent[] attachedActuators; + if (m_PolicyFactory.UseChildActuators) + { + attachedActuators = GetComponentsInChildren(); + } + else + { + attachedActuators = GetComponents(); + } + + // Support legacy OnActionReceived + var param = m_PolicyFactory.BrainParameters; + m_VectorActuator = new VectorActuator(this, param.VectorActionSize, param.VectorActionSpaceType); + m_Actuators = new ActuatorManager(attachedActuators.Length + 1); + m_LegacyActionCache = new float[m_VectorActuator.TotalNumberOfActions]; + + m_Actuators.Add(m_VectorActuator); + + foreach (var actuatorComponent in attachedActuators) + { + m_Actuators.Add(actuatorComponent.CreateActuator()); + } + } + /// /// Sends the Agent info to the linked Brain. 
/// @@ -902,13 +977,13 @@ void SendInfoToBrain() if (m_Info.done) { - Array.Clear(m_Info.storedVectorActions, 0, m_Info.storedVectorActions.Length); + m_Info.ClearActions(); } else { - Array.Copy(m_Action.vectorActions, m_Info.storedVectorActions, m_Action.vectorActions.Length); + m_Info.CopyActions(m_Actuators.StoredContinuousActions, m_Actuators.StoredDiscreteActions); } - m_ActionMasker.ResetMask(); + UpdateSensors(); using (TimerStack.Instance.Scoped("CollectObservations")) { @@ -918,11 +993,11 @@ void SendInfoToBrain() { if (m_PolicyFactory.BrainParameters.VectorActionSpaceType == SpaceType.Discrete) { - CollectDiscreteActionMasks(m_ActionMasker); + m_Actuators.WriteActionMask(); } } - m_Info.discreteActionMasks = m_ActionMasker.GetMask(); + m_Info.discreteActionMasks = m_Actuators.DiscreteActionMask?.GetMask(); m_Info.reward = m_Reward; m_Info.done = false; m_Info.maxStepReached = false; @@ -1029,11 +1104,20 @@ public ReadOnlyCollection GetObservations() /// /// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Learning-Environment-Design-Agents.md#actions /// - /// - public virtual void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker) + /// + public virtual void WriteDiscreteActionMask(IDiscreteActionMask actionMask) { + if (m_ActionMasker == null) + { + m_ActionMasker = new DiscreteActionMasker(actionMask); + } + #pragma warning disable 618 + CollectDiscreteActionMasks(m_ActionMasker); + #pragma warning restore 618 } + ActionSpec IActionReceiver.ActionSpec { get; } + /// /// Implement `OnActionReceived()` to specify agent behavior at every step, based /// on the provided action. @@ -1049,7 +1133,7 @@ public virtual void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker /// three values in the action array to use as the force components. During /// training, the agent's policy learns to set those particular elements of /// the array to maximize the training rewards the agent receives. (Of course, - /// if you implement a function, it must use the same + /// if you implement a function, it must use the same /// elements of the action array for the same purpose since there is no learning /// involved.) /// @@ -1099,12 +1183,34 @@ public virtual void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker /// /// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Learning-Environment-Design-Agents.md#actions /// - /// - /// An array containing the action vector. The length of the array is specified - /// by the of the agent's associated - /// component. + /// + /// Struct containing the buffers of actions to be executed at this step. /// - public virtual void OnActionReceived(float[] vectorAction) {} + public virtual void OnActionReceived(ActionBuffers actions) + { + // Copy the actions into our local array and call the original method for + // backward compatibility. + // For now we need to check which array has the actions in them in order to pass it back to the old method. 
+ if (actions.ContinuousActions.Length > 0) + { + Array.Copy(actions.ContinuousActions.Array, + actions.ContinuousActions.Offset, + m_LegacyActionCache, + 0, + actions.ContinuousActions.Length); + } + else if (actions.DiscreteActions.Length > 0) + { + Array.Copy(actions.DiscreteActions.Array, + actions.DiscreteActions.Offset, + m_LegacyActionCache, + 0, + actions.DiscreteActions.Length); + } + #pragma warning disable CS0618 + OnActionReceived(m_LegacyActionCache); + #pragma warning restore CS0618 + } /// /// Implement `OnEpisodeBegin()` to set up an Agent instance at the beginning @@ -1114,16 +1220,14 @@ public virtual void OnActionReceived(float[] vectorAction) {} /// public virtual void OnEpisodeBegin() {} - /// - /// Returns the last action that was decided on by the Agent. - /// - /// - /// The last action that was decided by the Agent (or null if no decision has been made). - /// - /// - public float[] GetAction() + public float[] GetStoredContinuousActions() { - return m_Action.vectorActions; + return m_Actuators.StoredContinuousActions; + } + + public int[] GetStoredDiscreteActions() + { + return m_Actuators.StoredDiscreteActions; } /// @@ -1177,7 +1281,7 @@ void AgentStep() if ((m_RequestAction) && (m_Brain != null)) { m_RequestAction = false; - OnActionReceived(m_Action.vectorActions); + m_Actuators.ExecuteActions(); } if ((m_StepCount >= MaxStep) && (MaxStep > 0)) @@ -1189,20 +1293,13 @@ void AgentStep() void DecideAction() { - if (m_Action.vectorActions == null) + if (m_Actuators.StoredContinuousActions == null) { ResetData(); } - var action = m_Brain?.DecideAction(); - - if (action == null) - { - Array.Clear(m_Action.vectorActions, 0, m_Action.vectorActions.Length); - } - else - { - Array.Copy(action, m_Action.vectorActions, action.Length); - } + var action = m_Brain?.DecideAction() ?? (continuousActions : Array.Empty(), discreteActions : Array.Empty()); + m_Info.CopyActions(action.continuousActions, action.discreteActions); + m_Actuators.UpdateActions(action.continuousActions, action.discreteActions); } } } diff --git a/com.unity.ml-agents/Runtime/Agent.deprecated.cs b/com.unity.ml-agents/Runtime/Agent.deprecated.cs new file mode 100644 index 0000000000..cbd785400a --- /dev/null +++ b/com.unity.ml-agents/Runtime/Agent.deprecated.cs @@ -0,0 +1,38 @@ +using System; +using UnityEngine; + +namespace Unity.MLAgents +{ + public partial class Agent + { + // [Obsolete("CollectDiscreteActionMasks has been deprecated. Please use WriteDiscreteActionMask instead.", false)] + public virtual void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker) + { + } + + // [Obsolete("The Heuristic(float[]) method has been deprecated. Please use Heuristic(float[], int[]) instead.")] + public virtual void Heuristic(float[] continuousActionsOut) + { + Debug.LogWarning("Heuristic method called but not implemented. Returning placeholder actions."); + Array.Clear(continuousActionsOut, 0, continuousActionsOut.Length); + } + + // [Obsolete("The OnActionReceived(float[]) method has been deprecated" + + // " Please use OnActionReceived(ActionSegment, ActionSegment).", false)] + public virtual void OnActionReceived(float[] vectorAction) {} + + /// + /// Returns the last action that was decided on by the Agent. + /// + /// + /// The last action that was decided by the Agent (or null if no decision has been made). 
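A sketch (outside the diff) of consuming a decision through the new ActionBuffers callback. The movement and logging are invented for illustration; the length checks show how to read each buffer safely when a behavior has only one action type:

    public override void OnActionReceived(ActionBuffers actions)
    {
        // One continuous axis drives movement, if present.
        float forward = actions.ContinuousActions.Length > 0
            ? actions.ContinuousActions[0]
            : 0f;
        transform.Translate(0f, 0f, forward * Time.deltaTime);

        // A discrete branch, if present, selects an option.
        if (actions.DiscreteActions.Length > 0)
        {
            int option = actions.DiscreteActions[0];
            Debug.Log("selected option " + option); // stand-in for a real effect
        }
    }
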
+ /// + /// + // [Obsolete("GetAction has been deprecated, please use GetStoredContinuousActions, Or GetStoredDiscreteActions.")] + public float[] GetAction() + { + return m_Info.storedVectorActions; + } + + } +} diff --git a/com.unity.ml-agents/Runtime/Agent.deprecated.cs.meta b/com.unity.ml-agents/Runtime/Agent.deprecated.cs.meta new file mode 100644 index 0000000000..767483f7f7 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Agent.deprecated.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 9650a482703b47db8cd7fb2df8caa1bf +timeCreated: 1595613441 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs index e41353fc44..deb2bb891d 100644 --- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs +++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs @@ -175,20 +175,12 @@ public static UnityRLInitParameters ToUnityRLInitParameters(this UnityRLInitiali } #region AgentAction - public static AgentAction ToAgentAction(this AgentActionProto aap) + public static List ToAgentActionList(this UnityRLInputProto.Types.ListAgentActionProto proto) { - return new AgentAction - { - vectorActions = aap.VectorActions.ToArray() - }; - } - - public static List ToAgentActionList(this UnityRLInputProto.Types.ListAgentActionProto proto) - { - var agentActions = new List(proto.Value.Count); + var agentActions = new List(proto.Value.Count); foreach (var ap in proto.Value) { - agentActions.Add(ap.ToAgentAction()); + agentActions.Add(ap.VectorActions.ToArray()); } return agentActions; } diff --git a/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs b/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs index 6c9ce6655b..afc4544242 100644 --- a/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs +++ b/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs @@ -412,7 +412,7 @@ void SendBatchedMessageHelper() var agentId = m_OrderedAgentsRequestingDecisions[brainName][i]; if (m_LastActionsReceived[brainName].ContainsKey(agentId)) { - m_LastActionsReceived[brainName][agentId] = agentAction.vectorActions; + m_LastActionsReceived[brainName][agentId] = agentAction; } } } diff --git a/com.unity.ml-agents/Runtime/DecisionRequester.cs b/com.unity.ml-agents/Runtime/DecisionRequester.cs index fabdaecc50..fc7cc55afd 100644 --- a/com.unity.ml-agents/Runtime/DecisionRequester.cs +++ b/com.unity.ml-agents/Runtime/DecisionRequester.cs @@ -26,7 +26,7 @@ public class DecisionRequester : MonoBehaviour /// that the Agent will request a decision every 5 Academy steps. /// [Range(1, 20)] [Tooltip("The frequency with which the agent requests a decision. A DecisionPeriod " + - "of 5 means that the Agent will request a decision every 5 Academy steps.")] + "of 5 means that the Agent will request a decision every 5 Academy steps.")] public int DecisionPeriod = 5; /// @@ -34,8 +34,8 @@ public class DecisionRequester : MonoBehaviour /// it does not request a decision. Has no effect when DecisionPeriod is set to 1. /// [Tooltip("Indicates whether or not the agent will take an action during the Academy " + - "steps where it does not request a decision. Has no effect when DecisionPeriod " + - "is set to 1.")] + "steps where it does not request a decision. 
Has no effect when DecisionPeriod " + + "is set to 1.")] [FormerlySerializedAs("RepeatAction")] public bool TakeActionsBetweenDecisions = true; diff --git a/com.unity.ml-agents/Runtime/DiscreteActionMasker.cs b/com.unity.ml-agents/Runtime/DiscreteActionMasker.cs index 1a9b322a98..c5a69a3c31 100644 --- a/com.unity.ml-agents/Runtime/DiscreteActionMasker.cs +++ b/com.unity.ml-agents/Runtime/DiscreteActionMasker.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; -using System.Linq; -using Unity.MLAgents.Policies; +using Unity.MLAgents.Actuators; namespace Unity.MLAgents { @@ -15,19 +14,13 @@ namespace Unity.MLAgents /// may be illegal. For example, if an agent is adjacent to a wall or other obstacle /// you could mask any actions that direct the agent to move into the blocked space. /// - public class DiscreteActionMasker + public class DiscreteActionMasker : IDiscreteActionMask { - /// When using discrete control, is the starting indices of the actions - /// when all the branches are concatenated with each other. - int[] m_StartingActionIndices; + IDiscreteActionMask m_Delegate; - bool[] m_CurrentMask; - - readonly BrainParameters m_BrainParameters; - - internal DiscreteActionMasker(BrainParameters brainParameters) + internal DiscreteActionMasker(IDiscreteActionMask actionMask) { - m_BrainParameters = brainParameters; + m_Delegate = actionMask; } /// @@ -46,109 +39,22 @@ internal DiscreteActionMasker(BrainParameters brainParameters) /// The indices of the masked actions. public void SetMask(int branch, IEnumerable actionIndices) { - // If the branch does not exist, raise an error - if (branch >= m_BrainParameters.VectorActionSize.Length) - throw new UnityAgentsException( - "Invalid Action Masking : Branch " + branch + " does not exist."); - - var totalNumberActions = m_BrainParameters.VectorActionSize.Sum(); - - // By default, the masks are null. If we want to specify a new mask, we initialize - // the actionMasks with trues. - if (m_CurrentMask == null) - { - m_CurrentMask = new bool[totalNumberActions]; - } - - // If this is the first time the masked actions are used, we generate the starting - // indices for each branch. - if (m_StartingActionIndices == null) - { - m_StartingActionIndices = Utilities.CumSum(m_BrainParameters.VectorActionSize); - } - - // Perform the masking - foreach (var actionIndex in actionIndices) - { - if (actionIndex >= m_BrainParameters.VectorActionSize[branch]) - { - throw new UnityAgentsException( - "Invalid Action Masking: Action Mask is too large for specified branch."); - } - m_CurrentMask[actionIndex + m_StartingActionIndices[branch]] = true; - } - } - - /// - /// Get the current mask for an agent. - /// - /// A mask for the agent. A boolean array of length equal to the total number of - /// actions. - internal bool[] GetMask() - { - if (m_CurrentMask != null) - { - AssertMask(); - } - return m_CurrentMask; + m_Delegate.WriteMask(branch, actionIndices); } - /// - /// Makes sure that the current mask is usable. - /// - void AssertMask() + public void WriteMask(int branch, IEnumerable actionIndices) { - // Action Masks can only be used in Discrete Control. 
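For illustration (not part of the change): masking through the new interface. WriteMask takes a branch index and the action indices to disable for the next decision, exactly as the delegating calls below forward them:

    public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
    {
        // Illustrative: forbid actions 1 and 2 on branch 0, e.g. because the
        // agent is standing against a wall and cannot move that way.
        actionMask.WriteMask(0, new[] { 1, 2 });
    }
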
- if (m_BrainParameters.VectorActionSpaceType != SpaceType.Discrete) - { - throw new UnityAgentsException( - "Invalid Action Masking : Can only set action mask for Discrete Control."); - } - - var numBranches = m_BrainParameters.VectorActionSize.Length; - for (var branchIndex = 0; branchIndex < numBranches; branchIndex++) - { - if (AreAllActionsMasked(branchIndex)) - { - throw new UnityAgentsException( - "Invalid Action Masking : All the actions of branch " + branchIndex + - " are masked."); - } - } + m_Delegate.WriteMask(branch, actionIndices); } - /// - /// Resets the current mask for an agent. - /// - internal void ResetMask() + public bool[] GetMask() { - if (m_CurrentMask != null) - { - Array.Clear(m_CurrentMask, 0, m_CurrentMask.Length); - } + return m_Delegate.GetMask(); } - /// - /// Checks if all the actions in the input branch are masked. - /// - /// The index of the branch to check. - /// True if all the actions of the branch are masked. - bool AreAllActionsMasked(int branch) + public void ResetMask() { - if (m_CurrentMask == null) - { - return false; - } - var start = m_StartingActionIndices[branch]; - var end = m_StartingActionIndices[branch + 1]; - for (var i = start; i < end; i++) - { - if (!m_CurrentMask[i]) - { - return false; - } - } - return true; + m_Delegate.ResetMask(); } } } diff --git a/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs b/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs index 6c6bcec4c1..dd4c802493 100644 --- a/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs +++ b/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs @@ -1,3 +1,4 @@ +using System; using Unity.Barracuda; using System.Collections.Generic; using Unity.MLAgents.Inference; @@ -36,6 +37,7 @@ internal class BarracudaPolicy : IPolicy /// Sensor shapes for the associated Agents. All Agents must have the same shapes for their Sensors. /// List m_SensorShapes; + SpaceType m_SapceType; /// public BarracudaPolicy( @@ -45,6 +47,7 @@ public BarracudaPolicy( { var modelRunner = Academy.Instance.GetOrCreateModelRunner(model, brainParameters, inferenceDevice); m_ModelRunner = modelRunner; + m_SapceType = brainParameters.VectorActionSpaceType; } /// @@ -55,10 +58,15 @@ public void RequestDecision(AgentInfo info, List sensors) } /// - public float[] DecideAction() + public (float[], int[]) DecideAction() { m_ModelRunner?.DecideBatch(); - return m_ModelRunner?.GetAction(m_AgentId); + var actions = m_ModelRunner?.GetAction(m_AgentId); + if (m_SapceType == SpaceType.Continuous) + { + return (actions, Array.Empty()); + } + return (Array.Empty(), actions == null ? Array.Empty() : Array.ConvertAll(actions, x => (int)x)); } public void Dispose() diff --git a/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs index b8d38f49d4..72b0038a9a 100644 --- a/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs +++ b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs @@ -150,6 +150,11 @@ public string BehaviorName [Tooltip("Use all Sensor components attached to child GameObjects of this Agent.")] bool m_UseChildSensors = true; + [HideInInspector] + [SerializeField] + [Tooltip("Use all Actuator components attached to child GameObjects of this Agent.")] + bool m_UseChildActuators = true; + /// /// Whether or not to use all the sensor components attached to child GameObjects of the agent. /// Note that changing this after the Agent has been initialized will not have any effect. 
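A small sketch (outside the diff): the UseChildActuators flag added in this file mirrors UseChildSensors and, like it, must be set before the Agent initializes to have any effect. From a component on the same GameObject:

    var behaviorParams = GetComponent<BehaviorParameters>();
    behaviorParams.UseChildActuators = false; // ignore actuators on child GameObjects
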
@@ -160,6 +165,16 @@ public bool UseChildSensors set { m_UseChildSensors = value; } } + /// + /// Whether or not to use all the actuator components attached to child GameObjects of the agent. + /// Note that changing this after the Agent has been initialized will not have any effect. + /// + public bool UseChildActuators + { + get { return m_UseChildActuators; } + set { m_UseChildActuators = value; } + } + [HideInInspector, SerializeField] ObservableAttributeOptions m_ObservableAttributeHandling = ObservableAttributeOptions.Ignore; @@ -185,7 +200,7 @@ internal IPolicy GeneratePolicy(HeuristicPolicy.ActionGenerator heuristic) switch (m_BehaviorType) { case BehaviorType.HeuristicOnly: - return new HeuristicPolicy(heuristic, m_BrainParameters.NumActions); + return GenerateHeuristicPolicy(heuristic); case BehaviorType.InferenceOnly: { if (m_Model == null) @@ -209,13 +224,29 @@ internal IPolicy GeneratePolicy(HeuristicPolicy.ActionGenerator heuristic) } else { - return new HeuristicPolicy(heuristic, m_BrainParameters.NumActions); + return GenerateHeuristicPolicy(heuristic); } default: - return new HeuristicPolicy(heuristic, m_BrainParameters.NumActions); + return GenerateHeuristicPolicy(heuristic); } } + internal IPolicy GenerateHeuristicPolicy(HeuristicPolicy.ActionGenerator heuristic) + { + var numContinuousActions = 0; + var numDiscreteActions = 0; + if (m_BrainParameters.VectorActionSpaceType == SpaceType.Continuous) + { + numContinuousActions = m_BrainParameters.NumActions; + } + else if (m_BrainParameters.VectorActionSpaceType == SpaceType.Discrete) + { + numDiscreteActions = m_BrainParameters.NumActions; + } + + return new HeuristicPolicy(heuristic, numContinuousActions, numDiscreteActions); + } + internal void UpdateAgentPolicy() { var agent = GetComponent(); diff --git a/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs index 12d304e17f..32d111121f 100644 --- a/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs +++ b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs @@ -12,9 +12,10 @@ namespace Unity.MLAgents.Policies /// internal class HeuristicPolicy : IPolicy { - public delegate void ActionGenerator(float[] actionsOut); + public delegate void ActionGenerator(float[] continuousActionsOut, int[] discreteActionsOut); ActionGenerator m_Heuristic; - float[] m_LastDecision; + float[] m_LastContinuousDecision; + int[] m_LastDiscreteDecision; bool m_Done; bool m_DecisionRequested; @@ -23,10 +24,11 @@ internal class HeuristicPolicy : IPolicy /// - public HeuristicPolicy(ActionGenerator heuristic, int numActions) + public HeuristicPolicy(ActionGenerator heuristic, int numContinuousActions, int numDiscreteActions) { m_Heuristic = heuristic; - m_LastDecision = new float[numActions]; + m_LastContinuousDecision = new float[numContinuousActions]; + m_LastDiscreteDecision = new int[numDiscreteActions]; } /// @@ -35,18 +37,17 @@ public void RequestDecision(AgentInfo info, List sensors) StepSensors(sensors); m_Done = info.done; m_DecisionRequested = true; - } /// - public float[] DecideAction() + public (float[], int[]) DecideAction() { if (!m_Done && m_DecisionRequested) { - m_Heuristic.Invoke(m_LastDecision); + m_Heuristic.Invoke(m_LastContinuousDecision, m_LastDiscreteDecision); } m_DecisionRequested = false; - return m_LastDecision; + return (m_LastContinuousDecision, m_LastDiscreteDecision); } public void Dispose() @@ -110,7 +111,7 @@ public void RemoveAt(int index) public float this[int index] { get { return 0.0f; } - 
set { } + set {} } } diff --git a/com.unity.ml-agents/Runtime/Policies/IPolicy.cs b/com.unity.ml-agents/Runtime/Policies/IPolicy.cs index 7203853589..3dc85ffccb 100644 --- a/com.unity.ml-agents/Runtime/Policies/IPolicy.cs +++ b/com.unity.ml-agents/Runtime/Policies/IPolicy.cs @@ -26,6 +26,6 @@ internal interface IPolicy : IDisposable /// it must be taken now. The Brain is expected to update the actions /// of the Agents at this point the latest. /// - float[] DecideAction(); + (float[] continuousActions, int[] discreteActions) DecideAction(); } } diff --git a/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs b/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs index 2f88d37f53..dd6fb9935d 100644 --- a/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs +++ b/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs @@ -13,6 +13,7 @@ internal class RemotePolicy : IPolicy { int m_AgentId; string m_FullyQualifiedBehaviorName; + SpaceType m_SpaceType; internal ICommunicator m_Communicator; @@ -23,6 +24,7 @@ public RemotePolicy( { m_FullyQualifiedBehaviorName = fullyQualifiedBehaviorName; m_Communicator = Academy.Instance.Communicator; + m_SpaceType = brainParameters.VectorActionSpaceType; m_Communicator.SubscribeBrain(m_FullyQualifiedBehaviorName, brainParameters); } @@ -34,10 +36,15 @@ public void RequestDecision(AgentInfo info, List sensors) } /// - public float[] DecideAction() + public (float[], int[]) DecideAction() { m_Communicator?.DecideBatch(); - return m_Communicator?.GetActions(m_FullyQualifiedBehaviorName, m_AgentId); + var actions = m_Communicator?.GetActions(m_FullyQualifiedBehaviorName, m_AgentId); + if (m_SpaceType == SpaceType.Continuous) + { + return (actions, Array.Empty()); + } + return (Array.Empty(), Array.ConvertAll(actions, x => (int)x)); } public void Dispose() diff --git a/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs b/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs index 7364e3173a..5b18e6427f 100644 --- a/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs +++ b/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs @@ -7,7 +7,7 @@ namespace Unity.MLAgents.Tests [TestFixture] public class BehaviorParameterTests { - static void DummyHeuristic(float[] actionsOut) + static void DummyHeuristic(float[] actionsOut, int[] discreteActionsOut) { // No-op } diff --git a/com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs b/com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs deleted file mode 100644 index 0f5c24edbd..0000000000 --- a/com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs +++ /dev/null @@ -1,142 +0,0 @@ -using NUnit.Framework; -using Unity.MLAgents.Policies; - -namespace Unity.MLAgents.Tests -{ - public class EditModeTestActionMasker - { - [Test] - public void Contruction() - { - var bp = new BrainParameters(); - var masker = new DiscreteActionMasker(bp); - Assert.IsNotNull(masker); - } - - [Test] - public void FailsWithContinuous() - { - var bp = new BrainParameters(); - bp.VectorActionSpaceType = SpaceType.Continuous; - bp.VectorActionSize = new[] {4}; - var masker = new DiscreteActionMasker(bp); - masker.SetMask(0, new[] {0}); - Assert.Catch(() => masker.GetMask()); - } - - [Test] - public void NullMask() - { - var bp = new BrainParameters(); - bp.VectorActionSpaceType = SpaceType.Discrete; - var masker = new DiscreteActionMasker(bp); - var mask = masker.GetMask(); - Assert.IsNull(mask); - } - - [Test] - public void FirstBranchMask() - { - var bp = new BrainParameters(); - bp.VectorActionSpaceType = 
SpaceType.Discrete; - bp.VectorActionSize = new[] {4, 5, 6}; - var masker = new DiscreteActionMasker(bp); - var mask = masker.GetMask(); - Assert.IsNull(mask); - masker.SetMask(0, new[] {1, 2, 3}); - mask = masker.GetMask(); - Assert.IsFalse(mask[0]); - Assert.IsTrue(mask[1]); - Assert.IsTrue(mask[2]); - Assert.IsTrue(mask[3]); - Assert.IsFalse(mask[4]); - Assert.AreEqual(mask.Length, 15); - } - - [Test] - public void SecondBranchMask() - { - var bp = new BrainParameters - { - VectorActionSpaceType = SpaceType.Discrete, - VectorActionSize = new[] { 4, 5, 6 } - }; - var masker = new DiscreteActionMasker(bp); - masker.SetMask(1, new[] {1, 2, 3}); - var mask = masker.GetMask(); - Assert.IsFalse(mask[0]); - Assert.IsFalse(mask[4]); - Assert.IsTrue(mask[5]); - Assert.IsTrue(mask[6]); - Assert.IsTrue(mask[7]); - Assert.IsFalse(mask[8]); - Assert.IsFalse(mask[9]); - } - - [Test] - public void MaskReset() - { - var bp = new BrainParameters - { - VectorActionSpaceType = SpaceType.Discrete, - VectorActionSize = new[] { 4, 5, 6 } - }; - var masker = new DiscreteActionMasker(bp); - masker.SetMask(1, new[] {1, 2, 3}); - masker.ResetMask(); - var mask = masker.GetMask(); - for (var i = 0; i < 15; i++) - { - Assert.IsFalse(mask[i]); - } - } - - [Test] - public void ThrowsError() - { - var bp = new BrainParameters - { - VectorActionSpaceType = SpaceType.Discrete, - VectorActionSize = new[] { 4, 5, 6 } - }; - var masker = new DiscreteActionMasker(bp); - - Assert.Catch( - () => masker.SetMask(0, new[] {5})); - Assert.Catch( - () => masker.SetMask(1, new[] {5})); - masker.SetMask(2, new[] {5}); - Assert.Catch( - () => masker.SetMask(3, new[] {1})); - masker.GetMask(); - masker.ResetMask(); - masker.SetMask(0, new[] {0, 1, 2, 3}); - Assert.Catch( - () => masker.GetMask()); - } - - [Test] - public void MultipleMaskEdit() - { - var bp = new BrainParameters(); - bp.VectorActionSpaceType = SpaceType.Discrete; - bp.VectorActionSize = new[] {4, 5, 6}; - var masker = new DiscreteActionMasker(bp); - masker.SetMask(0, new[] {0, 1}); - masker.SetMask(0, new[] {3}); - masker.SetMask(2, new[] {1}); - var mask = masker.GetMask(); - for (var i = 0; i < 15; i++) - { - if ((i == 0) || (i == 1) || (i == 3) || (i == 10)) - { - Assert.IsTrue(mask[i]); - } - else - { - Assert.IsFalse(mask[i]); - } - } - } - } -} diff --git a/com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs.meta b/com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs.meta deleted file mode 100644 index 1325856f78..0000000000 --- a/com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs.meta +++ /dev/null @@ -1,11 +0,0 @@ -fileFormatVersion: 2 -guid: 2e2810ee6c8c64fb39abdf04b5d17f50 -MonoImporter: - externalObjects: {} - serializedVersion: 2 - defaultReferences: [] - executionOrder: 0 - icon: {instanceID: 0} - userData: - assetBundleName: - assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs index 28a7bc3368..599cf3e64c 100644 --- a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs +++ b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs @@ -23,7 +23,7 @@ public void RequestDecision(AgentInfo info, List sensors) OnRequestDecision?.Invoke(); } - public float[] DecideAction() { return new float[0]; } + public (float[] continuousActions, int[] discreteActions) DecideAction() { return (new float[0], new int[0]); } public void Dispose() {} } From 9049469b06f94732d6fd8ed9d2c7713ebd0626bc Mon Sep 17 00:00:00 2001 From: Christopher Goy 
Date: Thu, 6 Aug 2020 12:42:42 -0700 Subject: [PATCH 2/9] Fix up some weird copy pasta. --- com.unity.ml-agents/Runtime/Agent.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 443c2895d2..d0b2c42f04 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -69,9 +69,9 @@ public void CopyActions(float[] continuousActions, int[] discreteActions) return; } - if (continuousActions != null) + if (discreteActions != null) { - Array.Copy(discreteActions, 0, storedVectorActions, continuousActions.Length, discreteActions.Length); + Array.Copy(discreteActions, 0, storedVectorActions, start, discreteActions.Length); } } } From dd6e6f5b1a607069516546cc9e6f9a78ef1e5621 Mon Sep 17 00:00:00 2001 From: Chris Goy Date: Thu, 6 Aug 2020 14:03:01 -0700 Subject: [PATCH 3/9] Update name to VectorActuator Co-authored-by: Chris Elion --- com.unity.ml-agents/Runtime/Agent.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index d0b2c42f04..9dd929aa51 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -301,7 +301,7 @@ internal struct AgentParameters ActuatorManager m_Actuators; /// - /// DiscreteVectorActuator which is used by default if no other sensors exist on this Agent. This VectorSensor will + /// VectorActuator which is used by default if no other sensors exist on this Agent. This VectorSensor will /// delegate its actions to by default in order to keep backward compatibility /// with the current behavior of Agent. /// From 15879d2b66bcedbbb53c5a28f6a3a50285c34a0b Mon Sep 17 00:00:00 2001 From: Christopher Goy Date: Thu, 6 Aug 2020 14:13:53 -0700 Subject: [PATCH 4/9] Address PR feedback. Rename member, avoid using SpaceType when writing the action mask in Agent. --- .../Runtime/Actuators/ActuatorManager.cs | 10 ++++-- com.unity.ml-agents/Runtime/Agent.cs | 33 +++++++++---------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs b/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs index a1b953118f..57ca0755b8 100644 --- a/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs +++ b/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs @@ -148,9 +148,12 @@ public void WriteActionMask() for (var i = 0; i < m_Actuators.Count; i++) { var actuator = m_Actuators[i]; - m_DiscreteActionMask.CurrentBranchOffset = offset; - actuator.WriteDiscreteActionMask(m_DiscreteActionMask); - offset += actuator.ActionSpec.NumDiscreteActions; + if (actuator.ActionSpec.NumDiscreteActions > 0) + { + m_DiscreteActionMask.CurrentBranchOffset = offset; + actuator.WriteDiscreteActionMask(m_DiscreteActionMask); + offset += actuator.ActionSpec.NumDiscreteActions; + } } } @@ -208,6 +211,7 @@ public void ResetData() { m_Actuators[i].ResetData(); } + m_DiscreteActionMask.ResetMask(); } diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 9dd929aa51..3e4e93d37c 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -298,7 +298,7 @@ internal struct AgentParameters /// /// List of IActuators that this Agent will delegate actions to if any exist. /// - ActuatorManager m_Actuators; + ActuatorManager m_ActuatorManager; /// /// VectorActuator which is used by default if no other sensors exist on this Agent. 
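To make the guard added to WriteActionMask above concrete, a hypothetical trace with two actuators (names invented): actuator A is continuous-only and actuator B has two discrete branches.

    // i = 0: actuator A, ActionSpec.NumDiscreteActions == 0 -> skipped entirely
    // i = 1: actuator B, ActionSpec.NumDiscreteActions == 2 ->
    //        CurrentBranchOffset = 0, B.WriteDiscreteActionMask(mask), offset += 2
    // Before this change, A would also have had CurrentBranchOffset assigned and
    // WriteDiscreteActionMask called despite having no discrete branches to mask.
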
This VectorSensor will @@ -438,7 +438,7 @@ public void LazyInitialize() InitializeActuators(); } - m_Info.storedVectorActions = new float[m_Actuators.TotalNumberOfActions]; + m_Info.storedVectorActions = new float[m_ActuatorManager.TotalNumberOfActions]; // The first time the Academy resets, all Agents in the scene will be // forced to reset through the event. @@ -766,7 +766,7 @@ public void RequestAction() /// at the end of an episode. void ResetData() { - m_Actuators?.ResetData(); + m_ActuatorManager?.ResetData(); } /// @@ -948,14 +948,14 @@ void InitializeActuators() // Support legacy OnActionReceived var param = m_PolicyFactory.BrainParameters; m_VectorActuator = new VectorActuator(this, param.VectorActionSize, param.VectorActionSpaceType); - m_Actuators = new ActuatorManager(attachedActuators.Length + 1); + m_ActuatorManager = new ActuatorManager(attachedActuators.Length + 1); m_LegacyActionCache = new float[m_VectorActuator.TotalNumberOfActions]; - m_Actuators.Add(m_VectorActuator); + m_ActuatorManager.Add(m_VectorActuator); foreach (var actuatorComponent in attachedActuators) { - m_Actuators.Add(actuatorComponent.CreateActuator()); + m_ActuatorManager.Add(actuatorComponent.CreateActuator()); } } @@ -981,7 +981,7 @@ void SendInfoToBrain() } else { - m_Info.CopyActions(m_Actuators.StoredContinuousActions, m_Actuators.StoredDiscreteActions); + m_Info.CopyActions(m_ActuatorManager.StoredContinuousActions, m_ActuatorManager.StoredDiscreteActions); } UpdateSensors(); @@ -991,13 +991,10 @@ void SendInfoToBrain() } using (TimerStack.Instance.Scoped("CollectDiscreteActionMasks")) { - if (m_PolicyFactory.BrainParameters.VectorActionSpaceType == SpaceType.Discrete) - { - m_Actuators.WriteActionMask(); - } + m_ActuatorManager.WriteActionMask(); } - m_Info.discreteActionMasks = m_Actuators.DiscreteActionMask?.GetMask(); + m_Info.discreteActionMasks = m_ActuatorManager.DiscreteActionMask?.GetMask(); m_Info.reward = m_Reward; m_Info.done = false; m_Info.maxStepReached = false; @@ -1133,7 +1130,7 @@ public virtual void WriteDiscreteActionMask(IDiscreteActionMask actionMask) /// three values in the action array to use as the force components. During /// training, the agent's policy learns to set those particular elements of /// the array to maximize the training rewards the agent receives. (Of course, - /// if you implement a function, it must use the same + /// if you implement a function, it must use the same /// elements of the action array for the same purpose since there is no learning /// involved.) /// @@ -1222,12 +1219,12 @@ public virtual void OnEpisodeBegin() {} public float[] GetStoredContinuousActions() { - return m_Actuators.StoredContinuousActions; + return m_ActuatorManager.StoredContinuousActions; } public int[] GetStoredDiscreteActions() { - return m_Actuators.StoredDiscreteActions; + return m_ActuatorManager.StoredDiscreteActions; } /// @@ -1281,7 +1278,7 @@ void AgentStep() if ((m_RequestAction) && (m_Brain != null)) { m_RequestAction = false; - m_Actuators.ExecuteActions(); + m_ActuatorManager.ExecuteActions(); } if ((m_StepCount >= MaxStep) && (MaxStep > 0)) @@ -1293,13 +1290,13 @@ void AgentStep() void DecideAction() { - if (m_Actuators.StoredContinuousActions == null) + if (m_ActuatorManager.StoredContinuousActions == null) { ResetData(); } var action = m_Brain?.DecideAction() ?? 
(continuousActions: Array.Empty<float>(), discreteActions: Array.Empty<int>());
 m_Info.CopyActions(action.continuousActions, action.discreteActions);
 m_ActuatorManager.UpdateActions(action.continuousActions, action.discreteActions);
 }
 }
}

From a543343e7222635f3d015a118bf38cf1cb564935 Mon Sep 17 00:00:00 2001
From: Christopher Goy
Date: Thu, 6 Aug 2020 14:15:17 -0700
Subject: [PATCH 5/9] Revert meta file change.

---
 com.unity.ml-agents/Runtime/Actuators.meta | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/com.unity.ml-agents/Runtime/Actuators.meta b/com.unity.ml-agents/Runtime/Actuators.meta
index 588dd29561..6297a1d492 100644
--- a/com.unity.ml-agents/Runtime/Actuators.meta
+++ b/com.unity.ml-agents/Runtime/Actuators.meta
@@ -1,3 +1,8 @@
-fileFormatVersion: 2
+fileFormatVersion: 2
 guid: 528e32f655ff4952b857b698e3efdd8e
-timeCreated: 1592848300
\ No newline at end of file
+folderAsset: yes
+DefaultImporter:
+  externalObjects: {}
+  userData:
+  assetBundleName:
+  assetBundleVariant:

From fe3fd6db42b515da1a298390d025d25c3d84c9fe Mon Sep 17 00:00:00 2001
From: Christopher Goy
Date: Thu, 6 Aug 2020 14:19:00 -0700
Subject: [PATCH 6/9] Revert meta file completely.

---
 com.unity.ml-agents/Runtime/Actuators.meta | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/com.unity.ml-agents/Runtime/Actuators.meta b/com.unity.ml-agents/Runtime/Actuators.meta
index 6297a1d492..96bbfb99b3 100644
--- a/com.unity.ml-agents/Runtime/Actuators.meta
+++ b/com.unity.ml-agents/Runtime/Actuators.meta
@@ -1,5 +1,5 @@
 fileFormatVersion: 2
-guid: 528e32f655ff4952b857b698e3efdd8e
+guid: 26733e59183b6479e8f0e892a8bf09a4
 folderAsset: yes
 DefaultImporter:
   externalObjects: {}
   userData:
   assetBundleName:
   assetBundleVariant:

From 7269c3aae12d74c15d7b0ebc9fb1ec0541f28318 Mon Sep 17 00:00:00 2001
From: Christopher Goy
Date: Fri, 7 Aug 2020 15:48:29 -0700
Subject: [PATCH 7/9] Use ActionBuffers more widely.

---
 .../Runtime/Actuators/ActionSegment.cs        | 23 ++++++
 .../Runtime/Actuators/ActuatorManager.cs      | 41 ++++++----
 .../Runtime/Actuators/IActionReceiver.cs      | 37 +++++++++
 com.unity.ml-agents/Runtime/Agent.cs          | 80 ++++++++-----------
 .../Runtime/Policies/BarracudaPolicy.cs       | 14 +++-
 .../Runtime/Policies/HeuristicPolicy.cs       | 17 ++--
 .../Runtime/Policies/IPolicy.cs               |  3 +-
 .../Runtime/Policies/RemotePolicy.cs          | 10 ++-
 .../Editor/Actuators/ActuatorManagerTests.cs  | 46 +++++------
 .../Tests/Editor/BehaviorParameterTests.cs    |  3 +-
 .../Tests/Editor/MLAgentsEditModeTest.cs      |  4 +-
 11 files changed, 176 insertions(+), 102 deletions(-)

diff --git a/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs b/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
index 9e18bd95e9..d27997ce5d 100644
--- a/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
+++ b/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
@@ -42,6 +42,13 @@ static void CheckParameters(T[] actionArray, int offset, int length)
 #endif
 }

+ /// <summary>
+ /// Construct an <see cref="ActionSegment{T}"/> with just an actionArray. The <see cref="Offset"/> will
+ /// be set to 0 and the <see cref="Length"/> will be set to `actionArray.Length`.
+ /// </summary>
+ /// <param name="actionArray">The action array to use for this segment.</param>
+ public ActionSegment(T[] actionArray) : this(actionArray, 0, actionArray.Length) {}
+
 /// <summary>
 /// Construct an <see cref="ActionSegment{T}"/> with an underlying array
 /// and offset, and a length.
@@ -78,6 +85,22 @@ public T this[int index] } return Array[Offset + index]; } + set + { + if (index < 0 || index > Length) + { + throw new IndexOutOfRangeException($"Index out of bounds, expected a number between 0 and {Length}"); + } + Array[Offset + index] = value; + } + } + + /// + /// Sets the segment of the backing array to all zeros. + /// + public void Clear() + { + System.Array.Clear(Array, Offset, Length); } /// diff --git a/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs b/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs index 57ca0755b8..77777b10f3 100644 --- a/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs +++ b/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs @@ -50,12 +50,14 @@ internal class ActuatorManager : IList /// /// Returns the previously stored actions for the actuators in this list. /// - public float[] StoredContinuousActions { get; private set; } + // public float[] StoredContinuousActions { get; private set; } /// /// Returns the previously stored actions for the actuators in this list. /// - public int[] StoredDiscreteActions { get; private set; } + // public int[] StoredDiscreteActions { get; private set; } + + public ActionBuffers StoredActions { get; private set; } /// /// Create an ActuatorList with a preset capacity. @@ -99,8 +101,11 @@ internal void ReadyActuatorsForExecution(IList actuators, int numCont // Sort the Actuators by name to ensure determinism SortActuators(); - StoredContinuousActions = numContinuousActions == 0 ? Array.Empty() : new float[numContinuousActions]; - StoredDiscreteActions = numDiscreteBranches == 0 ? Array.Empty() : new int[numDiscreteBranches]; + var continuousActions = numContinuousActions == 0 ? ActionSegment.Empty : + new ActionSegment(new float[numContinuousActions]); + var discreteActions = numDiscreteBranches == 0 ? ActionSegment.Empty : new ActionSegment(new int[numDiscreteBranches]); + + StoredActions = new ActionBuffers(continuousActions, discreteActions); m_DiscreteActionMask = new ActuatorDiscreteActionMask(actuators, sumOfDiscreteBranches, numDiscreteBranches); m_ReadyForExecution = true; } @@ -113,18 +118,19 @@ internal void ReadyActuatorsForExecution(IList actuators, int numCont /// continuous actions for the IActuators in this list. /// The action buffer which contains all of the /// discrete actions for the IActuators in this list. 
- public void UpdateActions(float[] continuousActionBuffer, int[] discreteActionBuffer) + public void UpdateActions(ActionBuffers actions) { ReadyActuatorsForExecution(); - UpdateActionArray(continuousActionBuffer, StoredContinuousActions); - UpdateActionArray(discreteActionBuffer, StoredDiscreteActions); + UpdateActionArray(actions.ContinuousActions, StoredActions.ContinuousActions); + UpdateActionArray(actions.DiscreteActions, StoredActions.DiscreteActions); } - static void UpdateActionArray(T[] sourceActionBuffer, T[] destination) + static void UpdateActionArray(ActionSegment sourceActionBuffer, ActionSegment destination) + where T : struct { - if (sourceActionBuffer == null || sourceActionBuffer.Length == 0) + if (sourceActionBuffer.Length <= 0) { - Array.Clear(destination, 0, destination.Length); + destination.Clear(); } else { @@ -132,7 +138,11 @@ static void UpdateActionArray(T[] sourceActionBuffer, T[] destination) $"sourceActionBuffer:{sourceActionBuffer.Length} is a different" + $" size than destination: {destination.Length}."); - Array.Copy(sourceActionBuffer, destination, destination.Length); + Array.Copy(sourceActionBuffer.Array, + sourceActionBuffer.Offset, + destination.Array, + destination.Offset, + destination.Length); } } @@ -176,7 +186,7 @@ public void ExecuteActions() var continuousActions = ActionSegment.Empty; if (numContinuousActions > 0) { - continuousActions = new ActionSegment(StoredContinuousActions, + continuousActions = new ActionSegment(StoredActions.ContinuousActions.Array, continuousStart, numContinuousActions); } @@ -184,7 +194,7 @@ public void ExecuteActions() var discreteActions = ActionSegment.Empty; if (numDiscreteActions > 0) { - discreteActions = new ActionSegment(StoredDiscreteActions, + discreteActions = new ActionSegment(StoredActions.DiscreteActions.Array, discreteStart, numDiscreteActions); } @@ -196,7 +206,7 @@ public void ExecuteActions() } /// - /// Resets the and buffers to be all + /// Resets the to be all /// zeros and calls on each managed by this object. /// public void ResetData() @@ -205,8 +215,7 @@ public void ResetData() { return; } - Array.Clear(StoredContinuousActions, 0, StoredContinuousActions.Length); - Array.Clear(StoredDiscreteActions, 0, StoredDiscreteActions.Length); + StoredActions.Clear(); for (var i = 0; i < m_Actuators.Count; i++) { m_Actuators[i].ResetData(); diff --git a/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs b/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs index 9ff7623ecc..166079f524 100644 --- a/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs +++ b/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs @@ -24,6 +24,9 @@ public readonly struct ActionBuffers /// public ActionSegment DiscreteActions { get; } + public ActionBuffers(float[] continuousActions, int[] discreteActions) + : this(new ActionSegment(continuousActions), new ActionSegment(discreteActions)) { } + /// /// Construct an instance with the continuous and discrete actions that will /// be used. @@ -49,6 +52,12 @@ public override bool Equals(object obj) ab.DiscreteActions.SequenceEqual(DiscreteActions); } + public void Clear() + { + ContinuousActions.Clear(); + DiscreteActions.Clear(); + } + /// public override int GetHashCode() { @@ -57,6 +66,34 @@ public override int GetHashCode() return (ContinuousActions.GetHashCode() * 397) ^ DiscreteActions.GetHashCode(); } } + + /// + /// Packs the continuous and discrete actions into one float array. 
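For illustration (not part of the change): the packing layout implemented below is continuous values first, then discrete values widened to float, matching the old single-array action format:

    var buffers = new ActionBuffers(new[] { 0.5f }, new[] { 3 });
    var packed = new float[2];
    buffers.PackActions(packed); // packed is now { 0.5f, 3f }
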
The array passed into this method
+ /// must have a Length that is greater than or equal to the sum of the Lengths of
+ /// and .
+ /// 
+ /// A float array to pack actions into whose length is greater than or
+ /// equal to the addition of the Lengths of this object's and
+ /// segments.
+ public void PackActions(in float[] destination)
+ {
+
+ var start = 0;
+ if (ContinuousActions.Length > 0)
+ {
+ Array.Copy(ContinuousActions.Array, ContinuousActions.Offset, destination, start, ContinuousActions.Length);
+ start = ContinuousActions.Length;
+ }
+ if (start >= destination.Length)
+ {
+ return;
+ }
+
+ if (DiscreteActions.Length > 0)
+ {
+ Array.Copy(DiscreteActions.Array, DiscreteActions.Offset, destination, start, DiscreteActions.Length);
+ }
+ }
 }

 /// 
diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs
index 3e4e93d37c..d4610e1895 100644
--- a/com.unity.ml-agents/Runtime/Agent.cs
+++ b/com.unity.ml-agents/Runtime/Agent.cs
@@ -56,23 +56,9 @@ public void ClearActions()
 Array.Clear(storedVectorActions, 0, storedVectorActions.Length);
 }

- public void CopyActions(float[] continuousActions, int[] discreteActions)
+ public void CopyActions(ActionBuffers actionBuffers)
 {
- var start = 0;
- if (continuousActions != null)
- {
- Array.Copy(continuousActions, 0, storedVectorActions, start, continuousActions.Length);
- start = continuousActions.Length;
- }
- if (start >= storedVectorActions.Length)
- {
- return;
- }
-
- if (discreteActions != null)
- {
- Array.Copy(discreteActions, 0, storedVectorActions, start, discreteActions.Length);
- }
+ actionBuffers.PackActions(storedVectorActions);
 }
 }

@@ -794,7 +780,8 @@ public virtual void Initialize() {}
 /// control of an agent using keyboard, mouse, or game controller input.
 /// 
 /// Your heuristic implementation can use any decision making logic you specify. Assign decision
- /// values to the float[] array, , passed to your function as a parameter.
+ /// values to the and
+ /// arrays, passed to your function as a parameter.
 /// The same array will be reused between steps. It is up to the user to initialize
 /// the values on each call, for example by calling `Array.Clear(actionsOut, 0, actionsOut.Length);`.
 /// Add values to the array at the same indexes as they are used in your
@@ -830,39 +817,42 @@ public virtual void Initialize() {}
 /// You can also use the [Input System package], which provides a more flexible and
 /// configurable input system.
 /// 
- /// public override void Heuristic(float[] continuousActionsOut, int[] discreteActionsOut)
+ /// public override void Heuristic(ActionBuffers actionsOut)
 /// {
- /// continuousActions[0] = Input.GetAxis("Horizontal");
- /// continuousActions[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
- /// continuousActions[2] = Input.GetAxis("Vertical");
+ /// actionsOut.ContinuousActions[0] = Input.GetAxis("Horizontal");
+ /// actionsOut.ContinuousActions[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
+ /// actionsOut.ContinuousActions[2] = Input.GetAxis("Vertical");
 /// }
 /// 
 /// [Input Manager]: https://docs.unity3d.com/Manual/class-InputManager.html
 /// [Input System package]: https://docs.unity3d.com/Packages/com.unity.inputsystem@1.0/manual/index.html
 /// 
- /// Array to write the continuous actions to.
- /// Array to write the discreteActions to.
+ /// The which contain the continuous and
+ /// discrete action buffers to write to. 
/// 
- public virtual void Heuristic(float[] continuousActionsOut, int[] discreteActionsOut)
+ public virtual void Heuristic(in ActionBuffers actionsOut)
 {
- Debug.LogWarning("Heuristic method called but not implemented. Returning placeholder actions.");
 // For backward compatibility
 switch (m_PolicyFactory.BrainParameters.VectorActionSpaceType)
 {
 case SpaceType.Continuous:
 #pragma warning disable CS0618
- Heuristic(continuousActionsOut);
+ Heuristic(actionsOut.ContinuousActions.Array);
 #pragma warning restore CS0618
- Array.Clear(discreteActionsOut, 0, discreteActionsOut.Length);
+ actionsOut.DiscreteActions.Clear();
 break;
 case SpaceType.Discrete:
- var convertedOut = Array.ConvertAll(discreteActionsOut, x => (float)x);
- #pragma warning disable CS0618
+ var convertedOut = Array.ConvertAll(actionsOut.DiscreteActions.Array, x => (float)x);
+#pragma warning disable CS0618
 Heuristic(convertedOut);
- #pragma warning restore CS0618
- var convertedBackToInt = Array.ConvertAll(convertedOut, x => (int)x);
- Array.Copy(convertedBackToInt, 0, discreteActionsOut, 0, discreteActionsOut.Length);
- Array.Clear(continuousActionsOut, 0, continuousActionsOut.Length);
+#pragma warning restore CS0618
+ var backToInt = Array.ConvertAll(convertedOut, x => (int)x);
+ Array.Copy(backToInt,
+ 0,
+ actionsOut.DiscreteActions.Array,
+ actionsOut.DiscreteActions.Offset,
+ actionsOut.DiscreteActions.Length);
+ actionsOut.ContinuousActions.Clear();
 break;
 }
 }
@@ -981,7 +971,7 @@ void SendInfoToBrain()
 }
 else
 {
- m_Info.CopyActions(m_ActuatorManager.StoredContinuousActions, m_ActuatorManager.StoredDiscreteActions);
+ m_ActuatorManager.StoredActions.PackActions(m_Info.storedVectorActions);
 }

 UpdateSensors();
@@ -1077,7 +1067,7 @@ public virtual void CollectObservations(VectorSensor sensor)
 /// 
 /// Returns a read-only view of the observations that were generated in
 /// . This is mainly useful inside of a
- /// method to avoid recomputing the observations.
+ /// method to avoid recomputing the observations.
 /// 
 /// A read-only view of the observations list.
 public ReadOnlyCollection GetObservations()
@@ -1217,14 +1207,12 @@ public virtual void OnActionReceived(ActionBuffers actions)
 /// 
 public virtual void OnEpisodeBegin() {}

- public float[] GetStoredContinuousActions()
- {
- return m_ActuatorManager.StoredContinuousActions;
- }
-
- public int[] GetStoredDiscreteActions()
+ /// 
+ /// Gets the most recently stored ActionBuffers for this agent.
+ /// 
+ public ActionBuffers GetStoredActionBuffers()
 {
- return m_ActuatorManager.StoredDiscreteActions;
+ return m_ActuatorManager.StoredActions;
 }

 /// 
@@ -1290,13 +1278,13 @@ void AgentStep()
 void DecideAction()
 {
- if (m_ActuatorManager.StoredContinuousActions == null)
+ if (m_ActuatorManager.StoredActions.ContinuousActions.Array == null)
 {
 ResetData();
 }
- var action = m_Brain?.DecideAction() ?? (continuousActions : Array.Empty(), discreteActions : Array.Empty());
- m_Info.CopyActions(action.continuousActions, action.discreteActions);
- m_ActuatorManager.UpdateActions(action.continuousActions, action.discreteActions);
+ var actions = m_Brain?.DecideAction() ?? 
new ActionBuffers();
+ m_Info.CopyActions(actions);
+ m_ActuatorManager.UpdateActions(actions);
 }
 }
 }
diff --git a/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs b/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
index dd4c802493..47da5d9d14 100644
--- a/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
+++ b/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
@@ -1,6 +1,7 @@
 using System;
 using Unity.Barracuda;
 using System.Collections.Generic;
+using Unity.MLAgents.Actuators;
 using Unity.MLAgents.Inference;
 using Unity.MLAgents.Sensors;
@@ -30,6 +31,7 @@ public enum InferenceDevice
 internal class BarracudaPolicy : IPolicy
 {
 protected ModelRunner m_ModelRunner;
+ ActionBuffers m_LastActionBuffer;

 int m_AgentId;

@@ -58,15 +60,21 @@ public void RequestDecision(AgentInfo info, List sensors)
 }

 /// 
- public (float[], int[]) DecideAction()
+ public ref readonly ActionBuffers DecideAction()
 {
 m_ModelRunner?.DecideBatch();
 var actions = m_ModelRunner?.GetAction(m_AgentId);
 if (m_SapceType == SpaceType.Continuous)
 {
- return (actions, Array.Empty());
+ m_LastActionBuffer = new ActionBuffers(actions, Array.Empty());
+ return ref m_LastActionBuffer;
 }
- return (Array.Empty(), actions == null ? Array.Empty() : Array.ConvertAll(actions, x => (int)x));
+
+ // Need to use ConvertAll since you cannot copy a float[] into an int[].
+ m_LastActionBuffer = new ActionBuffers(ActionSegment.Empty, actions == null ? ActionSegment.Empty
+ : new ActionSegment(Array.ConvertAll(actions,
+ x => (int)x)));
+ return ref m_LastActionBuffer;
 }

 public void Dispose()
diff --git a/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
index 32d111121f..232f459eb0 100644
--- a/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
+++ b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
@@ -1,6 +1,7 @@
 using System.Collections.Generic;
 using System;
 using System.Collections;
+using Unity.MLAgents.Actuators;
 using Unity.MLAgents.Sensors;

 namespace Unity.MLAgents.Policies
@@ -12,10 +13,9 @@ namespace Unity.MLAgents.Policies
 /// 
 internal class HeuristicPolicy : IPolicy
 {
- public delegate void ActionGenerator(float[] continuousActionsOut, int[] discreteActionsOut);
+ public delegate void ActionGenerator(in ActionBuffers actionBuffers);
 ActionGenerator m_Heuristic;
- float[] m_LastContinuousDecision;
- int[] m_LastDiscreteDecision;
+ ActionBuffers m_ActionBuffers;

 bool m_Done;
 bool m_DecisionRequested;
@@ -27,8 +27,9 @@ internal class HeuristicPolicy : IPolicy
 public HeuristicPolicy(ActionGenerator heuristic, int numContinuousActions, int numDiscreteActions)
 {
 m_Heuristic = heuristic;
- m_LastContinuousDecision = new float[numContinuousActions];
- m_LastDiscreteDecision = new int[numDiscreteActions];
+ var continuousDecision = new ActionSegment(new float[numContinuousActions], 0, numContinuousActions);
+ var discreteDecision = new ActionSegment(new int[numDiscreteActions], 0, numDiscreteActions);
+ m_ActionBuffers = new ActionBuffers(continuousDecision, discreteDecision);
 }

 /// 
@@ -40,14 +41,14 @@ public void RequestDecision(AgentInfo info, List sensors)
 }

 /// 
- public (float[], int[]) DecideAction()
+ public ref readonly ActionBuffers DecideAction()
 {
 if (!m_Done && m_DecisionRequested)
 {
- m_Heuristic.Invoke(m_LastContinuousDecision, m_LastDiscreteDecision);
+ m_Heuristic.Invoke(m_ActionBuffers);
 }
 m_DecisionRequested = false;
- return (m_LastContinuousDecision, m_LastDiscreteDecision);
+ return ref m_ActionBuffers;
 }

 public void 
Dispose()
diff --git a/com.unity.ml-agents/Runtime/Policies/IPolicy.cs b/com.unity.ml-agents/Runtime/Policies/IPolicy.cs
index 3dc85ffccb..4079a1f25a 100644
--- a/com.unity.ml-agents/Runtime/Policies/IPolicy.cs
+++ b/com.unity.ml-agents/Runtime/Policies/IPolicy.cs
@@ -1,5 +1,6 @@
 using System;
 using System.Collections.Generic;
+using Unity.MLAgents.Actuators;
 using Unity.MLAgents.Sensors;

 namespace Unity.MLAgents.Policies
@@ -26,6 +27,6 @@ internal interface IPolicy : IDisposable
 /// it must be taken now. The Brain is expected to update the actions
 /// of the Agents by this point at the latest.
 /// 
- (float[] continuousActions, int[] discreteActions) DecideAction();
+ ref readonly ActionBuffers DecideAction();
 }
 }
diff --git a/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs b/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
index dd6fb9935d..31354e1387 100644
--- a/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
+++ b/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
@@ -1,6 +1,7 @@
 using UnityEngine;
 using System.Collections.Generic;
 using System;
+using Unity.MLAgents.Actuators;
 using Unity.MLAgents.Sensors;

 namespace Unity.MLAgents.Policies
@@ -14,6 +15,7 @@ internal class RemotePolicy : IPolicy
 int m_AgentId;
 string m_FullyQualifiedBehaviorName;
 SpaceType m_SpaceType;
+ ActionBuffers m_LasActionBuffer;

 internal ICommunicator m_Communicator;

@@ -36,15 +38,17 @@ public void RequestDecision(AgentInfo info, List sensors)
 }

 /// 
- public (float[], int[]) DecideAction()
+ public ref readonly ActionBuffers DecideAction()
 {
 m_Communicator?.DecideBatch();
 var actions = m_Communicator?.GetActions(m_FullyQualifiedBehaviorName, m_AgentId);
 if (m_SpaceType == SpaceType.Continuous)
 {
- return (actions, Array.Empty());
+ m_LasActionBuffer = new ActionBuffers(actions, Array.Empty());
+ return ref m_LasActionBuffer;
 }
- return (Array.Empty(), Array.ConvertAll(actions, x => (int)x));
+ m_LasActionBuffer = new ActionBuffers(Array.Empty(), Array.ConvertAll(actions ?? 
Array.Empty(), x => (int)x)); + return ref m_LasActionBuffer; } public void Dispose() diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs index 974c1fd35d..c115ae9879 100644 --- a/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs @@ -29,14 +29,14 @@ public void TestEnsureBufferSizeContinuous() actuator1ActionSpaceDef.SumOfDiscreteBranchSizes + actuator2ActionSpaceDef.SumOfDiscreteBranchSizes, actuator1ActionSpaceDef.NumDiscreteActions + actuator2ActionSpaceDef.NumDiscreteActions); - manager.UpdateActions(new[] - { 0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f }, Array.Empty()); + manager.UpdateActions(new ActionBuffers(new[] + { 0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f }, Array.Empty())); Assert.IsTrue(12 == manager.NumContinuousActions); Assert.IsTrue(0 == manager.NumDiscreteActions); Assert.IsTrue(0 == manager.SumOfDiscreteBranchSizes); - Assert.IsTrue(12 == manager.StoredContinuousActions.Length); - Assert.IsTrue(0 == manager.StoredDiscreteActions.Length); + Assert.IsTrue(12 == manager.StoredActions.ContinuousActions.Length); + Assert.IsTrue(0 == manager.StoredActions.DiscreteActions.Length); } [Test] @@ -54,14 +54,14 @@ public void TestEnsureBufferDiscrete() actuator1ActionSpaceDef.SumOfDiscreteBranchSizes + actuator2ActionSpaceDef.SumOfDiscreteBranchSizes, actuator1ActionSpaceDef.NumDiscreteActions + actuator2ActionSpaceDef.NumDiscreteActions); - manager.UpdateActions(Array.Empty(), - new[] { 0, 1, 2, 3, 4, 5, 6}); + manager.UpdateActions(new ActionBuffers(Array.Empty(), + new[] { 0, 1, 2, 3, 4, 5, 6})); Assert.IsTrue(0 == manager.NumContinuousActions); Assert.IsTrue(7 == manager.NumDiscreteActions); Assert.IsTrue(13 == manager.SumOfDiscreteBranchSizes); - Assert.IsTrue(0 == manager.StoredContinuousActions.Length); - Assert.IsTrue(7 == manager.StoredDiscreteActions.Length); + Assert.IsTrue(0 == manager.StoredActions.ContinuousActions.Length); + Assert.IsTrue(7 == manager.StoredActions.DiscreteActions.Length); } [Test] @@ -98,8 +98,8 @@ public void TestExecuteActionsDiscrete() manager.Add(actuator2); var discreteActionBuffer = new[] { 0, 1, 2, 3, 4, 5, 6}; - manager.UpdateActions(Array.Empty(), - discreteActionBuffer); + manager.UpdateActions(new ActionBuffers(Array.Empty(), + discreteActionBuffer)); manager.ExecuteActions(); var actuator1Actions = actuator1.LastActionBuffer.DiscreteActions; @@ -118,8 +118,8 @@ public void TestExecuteActionsContinuous() manager.Add(actuator2); var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f}; - manager.UpdateActions(continuousActionBuffer, - Array.Empty()); + manager.UpdateActions(new ActionBuffers(continuousActionBuffer, + Array.Empty())); manager.ExecuteActions(); var actuator1Actions = actuator1.LastActionBuffer.ContinuousActions; @@ -149,10 +149,10 @@ public void TestUpdateActionsContinuous() manager.Add(actuator1); manager.Add(actuator2); var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f}; - manager.UpdateActions(continuousActionBuffer, - Array.Empty()); + manager.UpdateActions(new ActionBuffers(continuousActionBuffer, + Array.Empty())); - Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(continuousActionBuffer)); + Assert.IsTrue(manager.StoredActions.ContinuousActions.SequenceEqual(continuousActionBuffer)); } [Test] @@ -165,12 +165,12 @@ public void TestUpdateActionsDiscrete() manager.Add(actuator1); 
manager.Add(actuator2); var discreteActionBuffer = new[] { 0, 1, 2, 3, 4, 5}; - manager.UpdateActions(Array.Empty(), - discreteActionBuffer); + manager.UpdateActions(new ActionBuffers(Array.Empty(), + discreteActionBuffer)); - Debug.Log(manager.StoredDiscreteActions); + Debug.Log(manager.StoredActions.DiscreteActions); Debug.Log(discreteActionBuffer); - Assert.IsTrue(manager.StoredDiscreteActions.SequenceEqual(discreteActionBuffer)); + Assert.IsTrue(manager.StoredActions.DiscreteActions.SequenceEqual(discreteActionBuffer)); } [Test] @@ -261,14 +261,14 @@ public void TestResetData() manager.Add(actuator1); manager.Add(actuator2); var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f}; - manager.UpdateActions(continuousActionBuffer, - Array.Empty()); + manager.UpdateActions(new ActionBuffers(continuousActionBuffer, + Array.Empty())); - Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(continuousActionBuffer)); + Assert.IsTrue(manager.StoredActions.ContinuousActions.SequenceEqual(continuousActionBuffer)); Assert.IsTrue(manager.NumContinuousActions == 6); manager.ResetData(); - Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(new[] { 0f, 0f, 0f, 0f, 0f, 0f})); + Assert.IsTrue(manager.StoredActions.ContinuousActions.SequenceEqual(new[] { 0f, 0f, 0f, 0f, 0f, 0f})); } [Test] diff --git a/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs b/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs index 5b18e6427f..aa0a87ed32 100644 --- a/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs +++ b/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs @@ -1,4 +1,5 @@ using NUnit.Framework; +using Unity.MLAgents.Actuators; using UnityEngine; using Unity.MLAgents.Policies; @@ -7,7 +8,7 @@ namespace Unity.MLAgents.Tests [TestFixture] public class BehaviorParameterTests { - static void DummyHeuristic(float[] actionsOut, int[] discreteActionsOut) + static void DummyHeuristic(in ActionBuffers actionsOut) { // No-op } diff --git a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs index 599cf3e64c..b51b8fc2e6 100644 --- a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs +++ b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs @@ -3,6 +3,7 @@ using NUnit.Framework; using System.Reflection; using System.Collections.Generic; +using Unity.MLAgents.Actuators; using Unity.MLAgents.Sensors; using Unity.MLAgents.Sensors.Reflection; using Unity.MLAgents.Policies; @@ -14,6 +15,7 @@ internal class TestPolicy : IPolicy { public Action OnRequestDecision; ObservationWriter m_ObsWriter = new ObservationWriter(); + static ActionBuffers s_EmptyActionBuffers = new ActionBuffers(Array.Empty(), Array.Empty()); public void RequestDecision(AgentInfo info, List sensors) { foreach (var sensor in sensors) @@ -23,7 +25,7 @@ public void RequestDecision(AgentInfo info, List sensors) OnRequestDecision?.Invoke(); } - public (float[] continuousActions, int[] discreteActions) DecideAction() { return (new float[0], new int[0]); } + public ref readonly ActionBuffers DecideAction() { return ref s_EmptyActionBuffers; } public void Dispose() {} } From b75e33ab964f314d490c2f6488c0d97992a59295 Mon Sep 17 00:00:00 2001 From: Christopher Goy Date: Thu, 13 Aug 2020 14:09:39 -0700 Subject: [PATCH 8/9] PR feedback. 
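A quick usage sketch for reviewers (illustrative only, not part of this patch): CustomAgent is a hypothetical subclass, and it assumes a behavior configured with three continuous actions, mirroring the Heuristic example in the Agent.cs docs. Because the new ActionSegment indexer setter writes through to the shared backing array, heuristics can fill the segments in place without allocating:

    using Unity.MLAgents;
    using Unity.MLAgents.Actuators;
    using UnityEngine;

    public class CustomAgent : Agent
    {
        // Write decisions into the typed segments instead of a raw float[].
        public override void Heuristic(in ActionBuffers actionsOut)
        {
            var continuousActions = actionsOut.ContinuousActions;
            continuousActions[0] = Input.GetAxis("Horizontal");
            continuousActions[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
            continuousActions[2] = Input.GetAxis("Vertical");
        }

        // Actions arrive pre-split into continuous and discrete segments, so
        // user code no longer depends on the packing order of a flat array.
        public override void OnActionReceived(ActionBuffers actions)
        {
            var strafe = actions.ContinuousActions[0];
            var jump = actions.ContinuousActions[1];
            var forward = actions.ContinuousActions[2];
            // Apply strafe, jump, and forward to the agent's body here.
        }
    }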
--- .../Runtime/Actuators/ActionSegment.cs | 10 ++-- .../Runtime/Actuators/IActionReceiver.cs | 47 +++++++++++++++---- com.unity.ml-agents/Runtime/Agent.cs | 22 +-------- .../Runtime/Agent.deprecated.cs | 12 ++--- .../Runtime/Policies/BarracudaPolicy.cs | 11 ++--- .../Runtime/Policies/RemotePolicy.cs | 10 ++-- 6 files changed, 62 insertions(+), 50 deletions(-) diff --git a/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs b/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs index d27997ce5d..8b492ef09e 100644 --- a/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs +++ b/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs @@ -30,14 +30,14 @@ namespace Unity.MLAgents.Actuators /// public static ActionSegment Empty = new ActionSegment(System.Array.Empty(), 0, 0); - static void CheckParameters(T[] actionArray, int offset, int length) + static void CheckParameters(IReadOnlyCollection actionArray, int offset, int length) { #if DEBUG - if (offset + length > actionArray.Length) + if (offset + length > actionArray.Count) { throw new ArgumentOutOfRangeException(nameof(offset), $"Arguments offset: {offset} and length: {length} " + - $"are out of bounds of actionArray: {actionArray.Length}."); + $"are out of bounds of actionArray: {actionArray.Count}."); } #endif } @@ -47,7 +47,7 @@ static void CheckParameters(T[] actionArray, int offset, int length) /// be set to 0 and the will be set to `actionArray.Length`. /// /// The action array to use for the this segment. - public ActionSegment(T[] actionArray) : this(actionArray, 0, actionArray.Length) {} + public ActionSegment(T[] actionArray) : this(actionArray, 0, actionArray.Length) { } /// /// Construct an with an underlying array @@ -58,7 +58,9 @@ public ActionSegment(T[] actionArray) : this(actionArray, 0, actionArray.Length) /// The length of the segment. public ActionSegment(T[] actionArray, int offset, int length) { +#if DEBUG CheckParameters(actionArray, offset, length); +#endif Array = actionArray; Offset = offset; Length = length; diff --git a/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs b/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs index 166079f524..88559cb0f3 100644 --- a/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs +++ b/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs @@ -1,5 +1,6 @@ using System; using System.Linq; +using UnityEngine; namespace Unity.MLAgents.Actuators { @@ -24,6 +25,21 @@ public readonly struct ActionBuffers /// public ActionSegment DiscreteActions { get; } + /// + /// Create an instance with discrete actions stored as a float array. This exists + /// to achieve backward compatibility with the former Agent methods which used a float array for both continuous + /// and discrete actions. + /// + /// The float array of discrete actions. + /// An instance initialized with a + /// initialized from a float array. + public static ActionBuffers FromDiscreteActions(float[] discreteActions) + { + return new ActionBuffers(ActionSegment.Empty, discreteActions == null ? ActionSegment.Empty + : new ActionSegment(Array.ConvertAll(discreteActions, + x => (int)x))); + } + public ActionBuffers(float[] continuousActions, int[] discreteActions) : this(new ActionSegment(continuousActions), new ActionSegment(discreteActions)) { } @@ -39,6 +55,15 @@ public ActionBuffers(ActionSegment continuousActions, ActionSegment DiscreteActions = discreteActions; } + /// + /// Clear the and segments to be all zeros. 
+ /// 
+ public void Clear()
+ {
+ ContinuousActions.Clear();
+ DiscreteActions.Clear();
+ }
+
 /// 
 public override bool Equals(object obj)
 {
@@ -52,12 +77,6 @@ public override bool Equals(object obj)
 ab.DiscreteActions.SequenceEqual(DiscreteActions);
 }

- public void Clear()
- {
- ContinuousActions.Clear();
- DiscreteActions.Clear();
- }
-
 /// 
 public override int GetHashCode()
 {
@@ -77,11 +96,19 @@ public override int GetHashCode()
 /// segments.
 public void PackActions(in float[] destination)
 {
+ Debug.Assert(destination.Length >= ContinuousActions.Length + DiscreteActions.Length,
+ $"argument '{nameof(destination)}' is not large enough to pack the actions into.\n" +
+ $"{nameof(destination)}.Length: {destination.Length}\n" +
+ $"{nameof(ContinuousActions)}.Length + {nameof(DiscreteActions)}.Length: {ContinuousActions.Length + DiscreteActions.Length}");

 var start = 0;
 if (ContinuousActions.Length > 0)
 {
- Array.Copy(ContinuousActions.Array, ContinuousActions.Offset, destination, start, ContinuousActions.Length);
+ Array.Copy(ContinuousActions.Array,
+ ContinuousActions.Offset,
+ destination,
+ start,
+ ContinuousActions.Length);
 start = ContinuousActions.Length;
 }
 if (start >= destination.Length)
@@ -91,7 +118,11 @@ public void PackActions(in float[] destination)

 if (DiscreteActions.Length > 0)
 {
- Array.Copy(DiscreteActions.Array, DiscreteActions.Offset, destination, start, DiscreteActions.Length);
+ Array.Copy(DiscreteActions.Array,
+ DiscreteActions.Offset,
+ destination,
+ start,
+ DiscreteActions.Length);
 }
 }
 }
diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs
index d4610e1895..9f456a172c 100644
--- a/com.unity.ml-agents/Runtime/Agent.cs
+++ b/com.unity.ml-agents/Runtime/Agent.cs
@@ -288,7 +288,7 @@ internal struct AgentParameters

 /// 
 /// VectorActuator which is used by default if no other actuators exist on this Agent. This VectorActuator will
- /// delegate its actions to by default in order to keep backward compatibility
+ /// delegate its actions to by default in order to keep backward compatibility
 /// with the current behavior of Agent.
 /// 
 IActuator m_VectorActuator;
@@ -1175,25 +1175,7 @@ public virtual void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
 /// 
 public virtual void OnActionReceived(ActionBuffers actions)
 {
- // Copy the actions into our local array and call the original method for
- // backward compatibility.
- // For now we need to check which array has the actions in them in order to pass it back to the old method.
- if (actions.ContinuousActions.Length > 0)
- {
- Array.Copy(actions.ContinuousActions.Array,
- actions.ContinuousActions.Offset,
- m_LegacyActionCache,
- 0,
- actions.ContinuousActions.Length);
- }
- else if (actions.DiscreteActions.Length > 0)
- {
- Array.Copy(actions.DiscreteActions.Array,
- actions.DiscreteActions.Offset,
- m_LegacyActionCache,
- 0,
- actions.DiscreteActions.Length);
- }
+ actions.PackActions(m_LegacyActionCache);
 #pragma warning disable CS0618
 OnActionReceived(m_LegacyActionCache);
 #pragma warning restore CS0618
diff --git a/com.unity.ml-agents/Runtime/Agent.deprecated.cs b/com.unity.ml-agents/Runtime/Agent.deprecated.cs
index cbd785400a..8a09370957 100644
--- a/com.unity.ml-agents/Runtime/Agent.deprecated.cs
+++ b/com.unity.ml-agents/Runtime/Agent.deprecated.cs
@@ -5,20 +5,20 @@ namespace Unity.MLAgents
 {
 public partial class Agent
 {
- // [Obsolete("CollectDiscreteActionMasks has been deprecated. 
Please use WriteDiscreteActionMask instead.", false)]
 public virtual void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
 {
 }

- // [Obsolete("The Heuristic(float[]) method has been deprecated. Please use Heuristic(float[], int[]) instead.")]
- public virtual void Heuristic(float[] continuousActionsOut)
+ /// 
+ /// This method passes in a float array that is to be populated with actions. The actions
+ /// array is reused between steps, so implementations should overwrite all of its values on each call.
+ /// 
+ public virtual void Heuristic(float[] actionsOut)
 {
 Debug.LogWarning("Heuristic method called but not implemented. Returning placeholder actions.");
- Array.Clear(continuousActionsOut, 0, continuousActionsOut.Length);
+ Array.Clear(actionsOut, 0, actionsOut.Length);
 }

- // [Obsolete("The OnActionReceived(float[]) method has been deprecated" +
- // " Please use OnActionReceived(ActionSegment, ActionSegment).", false)]
 public virtual void OnActionReceived(float[] vectorAction) {}

 /// 
diff --git a/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs b/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
index 47da5d9d14..b583e4aa39 100644
--- a/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
+++ b/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
@@ -39,7 +39,7 @@ internal class BarracudaPolicy : IPolicy
 /// Sensor shapes for the associated Agents. All Agents must have the same shapes for their Sensors.
 /// 
 List m_SensorShapes;
- SpaceType m_SapceType;
+ SpaceType m_SpaceType;

 /// 
 public BarracudaPolicy(
@@ -49,7 +49,7 @@ public BarracudaPolicy(
 {
 var modelRunner = Academy.Instance.GetOrCreateModelRunner(model, brainParameters, inferenceDevice);
 m_ModelRunner = modelRunner;
- m_SapceType = brainParameters.VectorActionSpaceType;
+ m_SpaceType = brainParameters.VectorActionSpaceType;
 }

 /// 
@@ -64,16 +64,13 @@ public ref readonly ActionBuffers DecideAction()
 {
 m_ModelRunner?.DecideBatch();
 var actions = m_ModelRunner?.GetAction(m_AgentId);
- if (m_SapceType == SpaceType.Continuous)
+ if (m_SpaceType == SpaceType.Continuous)
 {
 m_LastActionBuffer = new ActionBuffers(actions, Array.Empty());
 return ref m_LastActionBuffer;
 }

- // Need to use ConvertAll since you cannot copy a float[] into an int[].
- m_LastActionBuffer = new ActionBuffers(ActionSegment.Empty, actions == null ? ActionSegment.Empty
- : new ActionSegment(Array.ConvertAll(actions,
- x => (int)x)));
+ m_LastActionBuffer = ActionBuffers.FromDiscreteActions(actions);
 return ref m_LastActionBuffer;
 }

diff --git a/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs b/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
index 31354e1387..f8c7c76eaa 100644
--- a/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
+++ b/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
@@ -15,7 +15,7 @@ internal class RemotePolicy : IPolicy
 int m_AgentId;
 string m_FullyQualifiedBehaviorName;
 SpaceType m_SpaceType;
- ActionBuffers m_LasActionBuffer;
+ ActionBuffers m_LastActionBuffer;

 internal ICommunicator m_Communicator;

@@ -44,11 +44,11 @@ public ref readonly ActionBuffers DecideAction()
 var actions = m_Communicator?.GetActions(m_FullyQualifiedBehaviorName, m_AgentId);
 if (m_SpaceType == SpaceType.Continuous)
 {
- m_LasActionBuffer = new ActionBuffers(actions, Array.Empty());
- return ref m_LasActionBuffer;
+ m_LastActionBuffer = new ActionBuffers(actions, Array.Empty());
+ return ref m_LastActionBuffer;
 }
- m_LasActionBuffer = new ActionBuffers(Array.Empty(), Array.ConvertAll(actions ?? 
Array.Empty(), x => (int)x)); - return ref m_LasActionBuffer; + m_LastActionBuffer = ActionBuffers.FromDiscreteActions(actions); + return ref m_LastActionBuffer; } public void Dispose() From 8fbca5def8ed9385c334843ff2a9b08d5f86729f Mon Sep 17 00:00:00 2001 From: Christopher Goy Date: Thu, 13 Aug 2020 14:22:34 -0700 Subject: [PATCH 9/9] Remove extra allocation. Remove pragma directives. --- com.unity.ml-agents/Runtime/Agent.cs | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 9f456a172c..5e34c4986e 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -836,22 +836,17 @@ public virtual void Heuristic(in ActionBuffers actionsOut) switch (m_PolicyFactory.BrainParameters.VectorActionSpaceType) { case SpaceType.Continuous: - #pragma warning disable CS0618 Heuristic(actionsOut.ContinuousActions.Array); - #pragma warning restore CS0618 actionsOut.DiscreteActions.Clear(); break; case SpaceType.Discrete: var convertedOut = Array.ConvertAll(actionsOut.DiscreteActions.Array, x => (float)x); -#pragma warning disable CS0618 Heuristic(convertedOut); -#pragma warning restore CS0618 - var backToInt = Array.ConvertAll(convertedOut, x => (int)x); - Array.Copy(backToInt, - 0, - actionsOut.DiscreteActions.Array, - actionsOut.DiscreteActions.Offset, - actionsOut.DiscreteActions.Length); + var discreteActionSegment = actionsOut.DiscreteActions; + for (var i = 0; i < actionsOut.DiscreteActions.Length; i++) + { + discreteActionSegment[i] = (int)convertedOut[i]; + } actionsOut.ContinuousActions.Clear(); break; } @@ -1098,9 +1093,7 @@ public virtual void WriteDiscreteActionMask(IDiscreteActionMask actionMask) { m_ActionMasker = new DiscreteActionMasker(actionMask); } - #pragma warning disable 618 CollectDiscreteActionMasks(m_ActionMasker); - #pragma warning restore 618 } ActionSpec IActionReceiver.ActionSpec { get; } @@ -1176,9 +1169,7 @@ public virtual void WriteDiscreteActionMask(IDiscreteActionMask actionMask) public virtual void OnActionReceived(ActionBuffers actions) { actions.PackActions(m_LegacyActionCache); - #pragma warning disable CS0618 OnActionReceived(m_LegacyActionCache); - #pragma warning restore CS0618 } ///
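As a sanity check on the packing order that the legacy float[] path relies on, here is a minimal edit-mode test sketch (hypothetical, not part of this series; the class and test names are made up, and it assumes ActionBuffers and PackActions remain publicly accessible). Continuous values are copied into the destination first, then the discrete branch values are widened to float:

    using NUnit.Framework;
    using Unity.MLAgents.Actuators;

    public class PackActionsExampleTest
    {
        [Test]
        public void PacksContinuousValuesBeforeDiscreteOnes()
        {
            // Two continuous actions followed by two discrete branch selections.
            var buffers = new ActionBuffers(new[] { 0.5f, -0.25f }, new[] { 2, 1 });

            // The destination must hold at least
            // ContinuousActions.Length + DiscreteActions.Length values.
            var packed = new float[4];
            buffers.PackActions(packed);

            // Continuous values come first, then the discrete values as floats.
            Assert.AreEqual(new[] { 0.5f, -0.25f, 2f, 1f }, packed);
        }
    }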