diff --git a/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs b/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
index feb06a708d..8b492ef09e 100644
--- a/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
+++ b/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
@@ -11,7 +11,7 @@ namespace Unity.MLAgents.Actuators
/// the offset into the original array, and a length.
///
/// The type of object stored in the underlying array.
- internal readonly struct ActionSegment<T> : IEnumerable<T>, IEquatable<ActionSegment<T>>
+ public readonly struct ActionSegment<T> : IEnumerable<T>, IEquatable<ActionSegment<T>>
where T : struct
{
///
@@ -30,18 +30,25 @@ namespace Unity.MLAgents.Actuators
///
public static ActionSegment<T> Empty = new ActionSegment<T>(System.Array.Empty<T>(), 0, 0);
- static void CheckParameters(T[] actionArray, int offset, int length)
+ static void CheckParameters(IReadOnlyCollection<T> actionArray, int offset, int length)
{
#if DEBUG
- if (offset + length > actionArray.Length)
+ if (offset + length > actionArray.Count)
{
throw new ArgumentOutOfRangeException(nameof(offset),
$"Arguments offset: {offset} and length: {length} " +
- $"are out of bounds of actionArray: {actionArray.Length}.");
+ $"are out of bounds of actionArray: {actionArray.Count}.");
}
#endif
}
+ ///
+ /// Construct an <see cref="ActionSegment{T}"/> with just an actionArray. The <see cref="Offset"/> will
+ /// be set to 0 and the <see cref="Length"/> will be set to `actionArray.Length`.
+ ///
+ /// The action array to use for this segment.
+ public ActionSegment(T[] actionArray) : this(actionArray, 0, actionArray.Length) { }
+
///
/// Construct an <see cref="ActionSegment{T}"/> with an underlying array
/// and offset, and a length.
@@ -51,7 +58,9 @@ static void CheckParameters(T[] actionArray, int offset, int length)
/// The length of the segment.
public ActionSegment(T[] actionArray, int offset, int length)
{
+#if DEBUG
CheckParameters(actionArray, offset, length);
+#endif
Array = actionArray;
Offset = offset;
Length = length;
@@ -78,6 +87,22 @@ public T this[int index]
}
return Array[Offset + index];
}
+ set
+ {
+ if (index < 0 || index >= Length)
+ {
+ throw new IndexOutOfRangeException($"Index out of bounds, expected a number between 0 and {Length}");
+ }
+ Array[Offset + index] = value;
+ }
+ }
+
+ ///
+ /// Sets the segment of the backing array to all zeros.
+ ///
+ public void Clear()
+ {
+ System.Array.Clear(Array, Offset, Length);
}
///
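As a quick orientation for the newly public ActionSegment{T} surface, here is a minimal usage sketch (variable names are illustrative; assumes `using Unity.MLAgents.Actuators;`):

    // One backing array shared by two non-overlapping segments.
    var backing = new float[5];
    var continuousPart = new ActionSegment<float>(backing, 0, 3);
    var extraPart = new ActionSegment<float>(backing, 3, 2);
    continuousPart[0] = 1.5f;   // writes backing[0] through the new setter
    extraPart[1] = -1f;         // writes backing[4]
    extraPart.Clear();          // zeros backing[3] and backing[4] only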
diff --git a/com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs b/com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs
index fbee0c4476..5b9175680f 100644
--- a/com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs
+++ b/com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs
@@ -8,7 +8,7 @@ namespace Unity.MLAgents.Actuators
///
/// Defines the structure of an Action Space to be used by the Actuator system.
///
- internal readonly struct ActionSpec
+ public readonly struct ActionSpec
{
///
diff --git a/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs b/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
index a1b953118f..77777b10f3 100644
--- a/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
+++ b/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
@@ -50,12 +50,14 @@ internal class ActuatorManager : IList<IActuator>
///
/// Returns the previously stored actions for the actuators in this list.
///
- public float[] StoredContinuousActions { get; private set; }
+ // public float[] StoredContinuousActions { get; private set; }
///
/// Returns the previously stored actions for the actuators in this list.
///
- public int[] StoredDiscreteActions { get; private set; }
+ // public int[] StoredDiscreteActions { get; private set; }
+
+ public ActionBuffers StoredActions { get; private set; }
///
/// Create an ActuatorList with a preset capacity.
@@ -99,8 +101,11 @@ internal void ReadyActuatorsForExecution(IList<IActuator> actuators, int numCont
// Sort the Actuators by name to ensure determinism
SortActuators();
- StoredContinuousActions = numContinuousActions == 0 ? Array.Empty<float>() : new float[numContinuousActions];
- StoredDiscreteActions = numDiscreteBranches == 0 ? Array.Empty<int>() : new int[numDiscreteBranches];
+ var continuousActions = numContinuousActions == 0 ? ActionSegment<float>.Empty :
+ new ActionSegment<float>(new float[numContinuousActions]);
+ var discreteActions = numDiscreteBranches == 0 ? ActionSegment<int>.Empty : new ActionSegment<int>(new int[numDiscreteBranches]);
+
+ StoredActions = new ActionBuffers(continuousActions, discreteActions);
m_DiscreteActionMask = new ActuatorDiscreteActionMask(actuators, sumOfDiscreteBranches, numDiscreteBranches);
m_ReadyForExecution = true;
}
@@ -113,18 +118,19 @@ internal void ReadyActuatorsForExecution(IList<IActuator> actuators, int numCont
/// continuous actions for the IActuators in this list.
/// The action buffer which contains all of the
/// discrete actions for the IActuators in this list.
- public void UpdateActions(float[] continuousActionBuffer, int[] discreteActionBuffer)
+ public void UpdateActions(ActionBuffers actions)
{
ReadyActuatorsForExecution();
- UpdateActionArray(continuousActionBuffer, StoredContinuousActions);
- UpdateActionArray(discreteActionBuffer, StoredDiscreteActions);
+ UpdateActionArray(actions.ContinuousActions, StoredActions.ContinuousActions);
+ UpdateActionArray(actions.DiscreteActions, StoredActions.DiscreteActions);
}
- static void UpdateActionArray<T>(T[] sourceActionBuffer, T[] destination)
+ static void UpdateActionArray<T>(ActionSegment<T> sourceActionBuffer, ActionSegment<T> destination)
+ where T : struct
{
- if (sourceActionBuffer == null || sourceActionBuffer.Length == 0)
+ if (sourceActionBuffer.Length <= 0)
{
- Array.Clear(destination, 0, destination.Length);
+ destination.Clear();
}
else
{
@@ -132,7 +138,11 @@ static void UpdateActionArray<T>(T[] sourceActionBuffer, T[] destination)
$"sourceActionBuffer:{sourceActionBuffer.Length} is a different" +
$" size than destination: {destination.Length}.");
- Array.Copy(sourceActionBuffer, destination, destination.Length);
+ Array.Copy(sourceActionBuffer.Array,
+ sourceActionBuffer.Offset,
+ destination.Array,
+ destination.Offset,
+ destination.Length);
}
}
@@ -148,9 +158,12 @@ public void WriteActionMask()
for (var i = 0; i < m_Actuators.Count; i++)
{
var actuator = m_Actuators[i];
- m_DiscreteActionMask.CurrentBranchOffset = offset;
- actuator.WriteDiscreteActionMask(m_DiscreteActionMask);
- offset += actuator.ActionSpec.NumDiscreteActions;
+ if (actuator.ActionSpec.NumDiscreteActions > 0)
+ {
+ m_DiscreteActionMask.CurrentBranchOffset = offset;
+ actuator.WriteDiscreteActionMask(m_DiscreteActionMask);
+ offset += actuator.ActionSpec.NumDiscreteActions;
+ }
}
}
@@ -173,7 +186,7 @@ public void ExecuteActions()
var continuousActions = ActionSegment<float>.Empty;
if (numContinuousActions > 0)
{
- continuousActions = new ActionSegment<float>(StoredContinuousActions,
+ continuousActions = new ActionSegment<float>(StoredActions.ContinuousActions.Array,
continuousStart,
numContinuousActions);
}
@@ -181,7 +194,7 @@ public void ExecuteActions()
var discreteActions = ActionSegment<int>.Empty;
if (numDiscreteActions > 0)
{
- discreteActions = new ActionSegment<int>(StoredDiscreteActions,
+ discreteActions = new ActionSegment<int>(StoredActions.DiscreteActions.Array,
discreteStart,
numDiscreteActions);
}
@@ -193,7 +206,7 @@ public void ExecuteActions()
}
///
- /// Resets the <see cref="StoredContinuousActions"/> and <see cref="StoredDiscreteActions"/> buffers to be all
+ /// Resets the <see cref="StoredActions"/> to be all
/// zeros and calls <see cref="IActuator.ResetData"/> on each <see cref="IActuator"/> managed by this object.
///
public void ResetData()
@@ -202,12 +215,12 @@ public void ResetData()
{
return;
}
- Array.Clear(StoredContinuousActions, 0, StoredContinuousActions.Length);
- Array.Clear(StoredDiscreteActions, 0, StoredDiscreteActions.Length);
+ StoredActions.Clear();
for (var i = 0; i < m_Actuators.Count; i++)
{
m_Actuators[i].ResetData();
}
+ m_DiscreteActionMask.ResetMask();
}
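A rough sketch of the new ActionBuffers-based flow through ActuatorManager (the actuator instance and action counts are illustrative; assumes an IActuator that declares exactly two continuous actions and `using System; using Unity.MLAgents.Actuators;`):

    var manager = new ActuatorManager(1);
    manager.Add(myContinuousActuator);            // hypothetical IActuator with 2 continuous actions
    manager.UpdateActions(new ActionBuffers(
        new[] { 0.5f, -0.25f },                   // continuous actions
        Array.Empty<int>()));                     // no discrete branches
    manager.ExecuteActions();                     // slices StoredActions into per-actuator segments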
diff --git a/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs b/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs
index 4e2a251f10..88559cb0f3 100644
--- a/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs
+++ b/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs
@@ -1,5 +1,6 @@
using System;
using System.Linq;
+using UnityEngine;
namespace Unity.MLAgents.Actuators
{
@@ -7,7 +8,7 @@ namespace Unity.MLAgents.Actuators
/// A structure that wraps the <see cref="ActionSegment{T}"/>s for a particular <see cref="IActionReceiver"/> and is
/// used when <see cref="IActionReceiver.OnActionReceived"/> is called.
///
- internal readonly struct ActionBuffers
+ public readonly struct ActionBuffers
{
///
/// An empty action buffer.
@@ -24,6 +25,24 @@ internal readonly struct ActionBuffers
///
public ActionSegment<int> DiscreteActions { get; }
+ ///
+ /// Create an <see cref="ActionBuffers"/> instance with discrete actions stored as a float array. This exists
+ /// to achieve backward compatibility with the former Agent methods which used a float array for both continuous
+ /// and discrete actions.
+ ///
+ /// The float array of discrete actions.
+ /// An <see cref="ActionBuffers"/> instance initialized with a <see cref="DiscreteActions"/> segment
+ /// initialized from a float array.
+ public static ActionBuffers FromDiscreteActions(float[] discreteActions)
+ {
+ return new ActionBuffers(ActionSegment<float>.Empty, discreteActions == null ? ActionSegment<int>.Empty
+ : new ActionSegment<int>(Array.ConvertAll(discreteActions,
+ x => (int)x)));
+ }
+
+ public ActionBuffers(float[] continuousActions, int[] discreteActions)
+ : this(new ActionSegment<float>(continuousActions), new ActionSegment<int>(discreteActions)) { }
+
///
/// Construct an instance with the continuous and discrete actions that will
/// be used.
@@ -36,6 +55,15 @@ public ActionBuffers(ActionSegment<float> continuousActions, ActionSegment<int>
DiscreteActions = discreteActions;
}
+ ///
+ /// Clear the <see cref="ContinuousActions"/> and <see cref="DiscreteActions"/> segments to be all zeros.
+ ///
+ public void Clear()
+ {
+ ContinuousActions.Clear();
+ DiscreteActions.Clear();
+ }
+
///
public override bool Equals(object obj)
{
@@ -57,12 +85,52 @@ public override int GetHashCode()
return (ContinuousActions.GetHashCode() * 397) ^ DiscreteActions.GetHashCode();
}
}
+
+ ///
+ /// Packs the continuous and discrete actions into one float array. The array passed into this method
+ /// must have a Length that is greater than or equal to the sum of the Lengths of
+ /// <see cref="ContinuousActions"/> and <see cref="DiscreteActions"/>.
+ ///
+ /// A float array to pack actions into whose length is greater than or
+ /// equal to the sum of the Lengths of this object's <see cref="ContinuousActions"/> and
+ /// <see cref="DiscreteActions"/> segments.
+ public void PackActions(in float[] destination)
+ {
+ Debug.Assert(destination.Length >= ContinuousActions.Length + DiscreteActions.Length,
+ $"argument '{nameof(destination)}' is not large enough to pack the actions into.\n" +
+ $"{nameof(destination)}.Length: {destination.Length}\n" +
+ $"{nameof(ContinuousActions)}.Length + {nameof(DiscreteActions)}.Length: {ContinuousActions.Length + DiscreteActions.Length}");
+
+ var start = 0;
+ if (ContinuousActions.Length > 0)
+ {
+ Array.Copy(ContinuousActions.Array,
+ ContinuousActions.Offset,
+ destination,
+ start,
+ ContinuousActions.Length);
+ start = ContinuousActions.Length;
+ }
+ if (start >= destination.Length)
+ {
+ return;
+ }
+
+ if (DiscreteActions.Length > 0)
+ {
+ Array.Copy(DiscreteActions.Array,
+ DiscreteActions.Offset,
+ destination,
+ start,
+ DiscreteActions.Length);
+ }
+ }
}
///
/// An interface that describes an object that can receive actions from a Reinforcement Learning network.
///
- internal interface IActionReceiver
+ public interface IActionReceiver
{
///
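The PackActions() helper added above flattens both segments into one float array, continuous values first and discrete values cast to float. A small sketch of the expected result (values are illustrative):

    var buffers = new ActionBuffers(
        new[] { 0.25f, -0.5f },     // continuous
        new[] { 2, 0 });            // discrete
    var packed = new float[4];
    buffers.PackActions(packed);    // packed now holds { 0.25f, -0.5f, 2f, 0f }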
diff --git a/com.unity.ml-agents/Runtime/Actuators/IActuator.cs b/com.unity.ml-agents/Runtime/Actuators/IActuator.cs
index eedb940a36..5988f0aab3 100644
--- a/com.unity.ml-agents/Runtime/Actuators/IActuator.cs
+++ b/com.unity.ml-agents/Runtime/Actuators/IActuator.cs
@@ -6,7 +6,7 @@ namespace Unity.MLAgents.Actuators
///
/// Abstraction that facilitates the execution of actions.
///
- internal interface IActuator : IActionReceiver
+ public interface IActuator : IActionReceiver
{
int TotalNumberOfActions { get; }
diff --git a/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs b/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs
index 7cb0e99f72..30f8425792 100644
--- a/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs
+++ b/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs
@@ -5,7 +5,7 @@ namespace Unity.MLAgents.Actuators
///
/// Interface for writing a mask to disable discrete actions for agents for the next decision.
///
- internal interface IDiscreteActionMask
+ public interface IDiscreteActionMask
{
///
/// Modifies an action mask for discrete control agents.
diff --git a/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs b/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs
index e2635c4164..7e8d6d7642 100644
--- a/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs
+++ b/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs
@@ -4,7 +4,7 @@
namespace Unity.MLAgents.Actuators
{
- internal class VectorActuator : IActuator
+ public class VectorActuator : IActuator
{
IActionReceiver m_ActionReceiver;
diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs
index cad4575f0b..5e34c4986e 100644
--- a/com.unity.ml-agents/Runtime/Agent.cs
+++ b/com.unity.ml-agents/Runtime/Agent.cs
@@ -1,8 +1,10 @@
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
+using System.Linq;
using UnityEngine;
using Unity.Barracuda;
+using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
using Unity.MLAgents.Demonstrations;
@@ -48,15 +50,16 @@ internal struct AgentInfo
/// to separate between different agents in the environment.
///
public int episodeId;
- }
- ///
- /// Struct that contains the action information sent from the Brain to the
- /// Agent.
- ///
- internal struct AgentAction
- {
- public float[] vectorActions;
+ public void ClearActions()
+ {
+ Array.Clear(storedVectorActions, 0, storedVectorActions.Length);
+ }
+
+ public void CopyActions(ActionBuffers actionBuffers)
+ {
+ actionBuffers.PackActions(storedVectorActions);
+ }
}
///
@@ -106,7 +109,7 @@ internal struct AgentAction
/// can only take an action when it touches the ground, so several frames might elapse between
/// one decision and the need for the next.
///
- /// Use the <see cref="OnActionReceived(float[])"/> function to implement the actions your agent can take,
+ /// Use the <see cref="OnActionReceived(ActionBuffers)"/> function to implement the actions your agent can take,
/// such as moving to reach a goal or interacting with its environment.
///
/// When you call <see cref="EndEpisode"/> on an agent or the agent reaches its <see cref="MaxStep"/> count,
@@ -155,7 +158,7 @@ internal struct AgentAction
"docs/Learning-Environment-Design-Agents.md")]
[Serializable]
[RequireComponent(typeof(BehaviorParameters))]
- public class Agent : MonoBehaviour, ISerializationCallbackReceiver
+ public partial class Agent : MonoBehaviour, ISerializationCallbackReceiver, IActionReceiver
{
IPolicy m_Brain;
BehaviorParameters m_PolicyFactory;
@@ -222,9 +225,6 @@ internal struct AgentParameters
/// Current Agent information (message sent to Brain).
AgentInfo m_Info;
- /// Current Agent action (message sent from Brain).
- AgentAction m_Action;
-
/// Represents the reward the agent accumulated during the current step.
/// It is reset to 0 at the beginning of every step.
/// Should be set to a positive value when the agent performs a "good"
@@ -281,6 +281,24 @@ internal struct AgentParameters
///
internal VectorSensor collectObservationsSensor;
+ ///
+ /// List of IActuators that this Agent will delegate actions to if any exist.
+ ///
+ ActuatorManager m_ActuatorManager;
+
+ ///
+ /// VectorActuator which is used by default if no other actuators exist on this Agent. This VectorActuator will
+ /// delegate its actions to the Agent's <see cref="OnActionReceived(ActionBuffers)"/> by default in order to keep backward compatibility
+ /// with the current behavior of Agent.
+ ///
+ IActuator m_VectorActuator;
+
+ ///
+ /// This is used to avoid allocation of a float array every frame if users are still using the old
+ /// OnActionReceived method.
+ ///
+ float[] m_LegacyActionCache;
+
///
/// Called when the attached [GameObject] becomes enabled and active.
/// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html
@@ -385,7 +403,6 @@ public void LazyInitialize()
m_PolicyFactory = GetComponent<BehaviorParameters>();
m_Info = new AgentInfo();
- m_Action = new AgentAction();
sensors = new List<ISensor>();
Academy.Instance.AgentIncrementStep += AgentIncrementStep;
@@ -402,6 +419,13 @@ public void LazyInitialize()
InitializeSensors();
}
+ using (TimerStack.Instance.Scoped("InitializeActuators"))
+ {
+ InitializeActuators();
+ }
+
+ m_Info.storedVectorActions = new float[m_ActuatorManager.TotalNumberOfActions];
+
// The first time the Academy resets, all Agents in the scene will be
// forced to reset through the event.
// To avoid the Agent resetting twice, the Agents will not begin their
@@ -624,7 +648,7 @@ public void SetReward(float reward)
/// set the reward assigned to the current step with a specific value rather than
/// increasing or decreasing it.
///
- /// Typically, you assign rewards in the Agent subclass's <see cref="OnActionReceived(float[])"/>
+ /// Typically, you assign rewards in the Agent subclass's <see cref="OnActionReceived(ActionBuffers)"/>
/// implementation after carrying out the received action and evaluating its success.
///
/// Rewards are used during reinforcement learning; they are ignored during inference.
@@ -701,7 +725,7 @@ public void RequestDecision()
///
/// Call `RequestAction()` to repeat the previous action returned by the agent's
/// most recent decision. A new decision is not requested. When you call this function,
- /// the Agent instance invokes <see cref="OnActionReceived(float[])"/> with the
+ /// the Agent instance invokes <see cref="OnActionReceived(ActionBuffers)"/> with the
/// existing action vector.
///
/// You can use `RequestAction()` in situations where an agent must take an action
@@ -728,16 +752,7 @@ public void RequestAction()
/// at the end of an episode.
void ResetData()
{
- var param = m_PolicyFactory.BrainParameters;
- m_ActionMasker = new DiscreteActionMasker(param);
- // If we haven't initialized vectorActions, initialize to 0. This should only
- // happen during the creation of the Agent. In subsequent episodes, vectorAction
- // should stay the previous action before the Done(), so that it is properly recorded.
- if (m_Action.vectorActions == null)
- {
- m_Action.vectorActions = new float[param.NumActions];
- m_Info.storedVectorActions = new float[param.NumActions];
- }
+ m_ActuatorManager?.ResetData();
}
///
@@ -765,11 +780,12 @@ public virtual void Initialize() {}
/// control of an agent using keyboard, mouse, or game controller input.
///
/// Your heuristic implementation can use any decision making logic you specify. Assign decision
- /// values to the float[] array, <paramref name="actionsOut"/>, passed to your function as a parameter.
+ /// values to the <see cref="ActionBuffers.ContinuousActions"/> and <see cref="ActionBuffers.DiscreteActions"/>
+ /// arrays of <paramref name="actionsOut"/>, passed to your function as a parameter.
/// The same array will be reused between steps. It is up to the user to initialize
/// the values on each call, for example by calling `Array.Clear(actionsOut, 0, actionsOut.Length);`.
/// Add values to the array at the same indexes as they are used in your
- /// <see cref="OnActionReceived(float[])"/> function, which receives this array and
+ /// <see cref="OnActionReceived(ActionBuffers)"/> function, which receives this array and
/// implements the corresponding agent behavior. See [Actions] for more information
/// about agent actions.
/// Note : Do not create a new float array of action in the `Heuristic()` method,
@@ -801,22 +817,39 @@ public virtual void Initialize() {}
/// You can also use the [Input System package], which provides a more flexible and
/// configurable input system.
///
- /// public override void Heuristic(float[] actionsOut)
+ /// public override void Heuristic(ActionBuffers actionsOut)
/// {
- /// actionsOut[0] = Input.GetAxis("Horizontal");
- /// actionsOut[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
- /// actionsOut[2] = Input.GetAxis("Vertical");
+ /// actionsOut.ContinuousActions[0] = Input.GetAxis("Horizontal");
+ /// actionsOut.ContinuousActions[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
+ /// actionsOut.ContinuousActions[2] = Input.GetAxis("Vertical");
/// }
///
/// [Input Manager]: https://docs.unity3d.com/Manual/class-InputManager.html
/// [Input System package]: https://docs.unity3d.com/Packages/com.unity.inputsystem@1.0/manual/index.html
///
- /// Array for the output actions.
- ///
- public virtual void Heuristic(float[] actionsOut)
+ /// The <see cref="ActionBuffers"/> which contain the continuous and
+ /// discrete action buffers to write to.
+ ///
+ public virtual void Heuristic(in ActionBuffers actionsOut)
{
- Debug.LogWarning("Heuristic method called but not implemented. Returning placeholder actions.");
- Array.Clear(actionsOut, 0, actionsOut.Length);
+ // For backward compatibility
+ switch (m_PolicyFactory.BrainParameters.VectorActionSpaceType)
+ {
+ case SpaceType.Continuous:
+ Heuristic(actionsOut.ContinuousActions.Array);
+ actionsOut.DiscreteActions.Clear();
+ break;
+ case SpaceType.Discrete:
+ var convertedOut = Array.ConvertAll(actionsOut.DiscreteActions.Array, x => (float)x);
+ Heuristic(convertedOut);
+ var discreteActionSegment = actionsOut.DiscreteActions;
+ for (var i = 0; i < actionsOut.DiscreteActions.Length; i++)
+ {
+ discreteActionSegment[i] = (int)convertedOut[i];
+ }
+ actionsOut.ContinuousActions.Clear();
+ break;
+ }
}
///
@@ -875,6 +908,7 @@ internal void InitializeSensors()
#if DEBUG
// Make sure the names are actually unique
+
for (var i = 0; i < sensors.Count - 1; i++)
{
Debug.Assert(
@@ -884,6 +918,32 @@ internal void InitializeSensors()
#endif
}
+ void InitializeActuators()
+ {
+ ActuatorComponent[] attachedActuators;
+ if (m_PolicyFactory.UseChildActuators)
+ {
+ attachedActuators = GetComponentsInChildren<ActuatorComponent>();
+ }
+ else
+ {
+ attachedActuators = GetComponents<ActuatorComponent>();
+ }
+
+ // Support legacy OnActionReceived
+ var param = m_PolicyFactory.BrainParameters;
+ m_VectorActuator = new VectorActuator(this, param.VectorActionSize, param.VectorActionSpaceType);
+ m_ActuatorManager = new ActuatorManager(attachedActuators.Length + 1);
+ m_LegacyActionCache = new float[m_VectorActuator.TotalNumberOfActions];
+
+ m_ActuatorManager.Add(m_VectorActuator);
+
+ foreach (var actuatorComponent in attachedActuators)
+ {
+ m_ActuatorManager.Add(actuatorComponent.CreateActuator());
+ }
+ }
+
///
/// Sends the Agent info to the linked Brain.
///
@@ -902,13 +962,13 @@ void SendInfoToBrain()
if (m_Info.done)
{
- Array.Clear(m_Info.storedVectorActions, 0, m_Info.storedVectorActions.Length);
+ m_Info.ClearActions();
}
else
{
- Array.Copy(m_Action.vectorActions, m_Info.storedVectorActions, m_Action.vectorActions.Length);
+ m_ActuatorManager.StoredActions.PackActions(m_Info.storedVectorActions);
}
- m_ActionMasker.ResetMask();
+
UpdateSensors();
using (TimerStack.Instance.Scoped("CollectObservations"))
{
@@ -916,13 +976,10 @@ void SendInfoToBrain()
}
using (TimerStack.Instance.Scoped("CollectDiscreteActionMasks"))
{
- if (m_PolicyFactory.BrainParameters.VectorActionSpaceType == SpaceType.Discrete)
- {
- CollectDiscreteActionMasks(m_ActionMasker);
- }
+ m_ActuatorManager.WriteActionMask();
}
- m_Info.discreteActionMasks = m_ActionMasker.GetMask();
+ m_Info.discreteActionMasks = m_ActuatorManager.DiscreteActionMask?.GetMask();
m_Info.reward = m_Reward;
m_Info.done = false;
m_Info.maxStepReached = false;
@@ -1005,7 +1062,7 @@ public virtual void CollectObservations(VectorSensor sensor)
///
/// Returns a read-only view of the observations that were generated in
/// <see cref="CollectObservations(VectorSensor)"/>. This is mainly useful inside of a
- /// <see cref="Heuristic(float[])"/> method to avoid recomputing the observations.
+ /// <see cref="Heuristic(ActionBuffers)"/> method to avoid recomputing the observations.
///
/// A read-only view of the observations list.
public ReadOnlyCollection<float> GetObservations()
@@ -1029,11 +1086,18 @@ public ReadOnlyCollection<float> GetObservations()
///
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Learning-Environment-Design-Agents.md#actions
///
- ///
- public virtual void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
+ ///
+ public virtual void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
+ if (m_ActionMasker == null)
+ {
+ m_ActionMasker = new DiscreteActionMasker(actionMask);
+ }
+ CollectDiscreteActionMasks(m_ActionMasker);
}
+ ActionSpec IActionReceiver.ActionSpec { get; }
+
///
/// Implement `OnActionReceived()` to specify agent behavior at every step, based
/// on the provided action.
@@ -1049,7 +1113,7 @@ public virtual void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker
/// three values in the action array to use as the force components. During
/// training, the agent's policy learns to set those particular elements of
/// the array to maximize the training rewards the agent receives. (Of course,
- /// if you implement a <see cref="Heuristic(float[])"/> function, it must use the same
+ /// if you implement a <see cref="Heuristic(ActionBuffers)"/> function, it must use the same
/// elements of the action array for the same purpose since there is no learning
/// involved.)
///
@@ -1099,12 +1163,14 @@ public virtual void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker
///
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Learning-Environment-Design-Agents.md#actions
///
- ///
- /// An array containing the action vector. The length of the array is specified
- /// by the <see cref="BrainParameters"/> of the agent's associated
- /// <see cref="BehaviorParameters"/> component.
+ ///
+ /// Struct containing the buffers of actions to be executed at this step.
///
- public virtual void OnActionReceived(float[] vectorAction) {}
+ public virtual void OnActionReceived(ActionBuffers actions)
+ {
+ actions.PackActions(m_LegacyActionCache);
+ OnActionReceived(m_LegacyActionCache);
+ }
///
/// Implement `OnEpisodeBegin()` to set up an Agent instance at the beginning
@@ -1115,15 +1181,11 @@ public virtual void OnActionReceived(float[] vectorAction) {}
public virtual void OnEpisodeBegin() {}
///
- /// Returns the last action that was decided on by the Agent.
+ /// Gets the last ActionBuffers for this agent.
///
- ///
- /// The last action that was decided by the Agent (or null if no decision has been made).
- ///
- ///
- public float[] GetAction()
+ public ActionBuffers GetStoredContinuousActions()
{
- return m_Action.vectorActions;
+ return m_ActuatorManager.StoredActions;
}
///
@@ -1177,7 +1239,7 @@ void AgentStep()
if ((m_RequestAction) && (m_Brain != null))
{
m_RequestAction = false;
- OnActionReceived(m_Action.vectorActions);
+ m_ActuatorManager.ExecuteActions();
}
if ((m_StepCount >= MaxStep) && (MaxStep > 0))
@@ -1189,20 +1251,13 @@ void AgentStep()
void DecideAction()
{
- if (m_Action.vectorActions == null)
+ if (m_ActuatorManager.StoredActions.ContinuousActions.Array == null)
{
ResetData();
}
- var action = m_Brain?.DecideAction();
-
- if (action == null)
- {
- Array.Clear(m_Action.vectorActions, 0, m_Action.vectorActions.Length);
- }
- else
- {
- Array.Copy(action, m_Action.vectorActions, action.Length);
- }
+ var actions = m_Brain?.DecideAction() ?? new ActionBuffers();
+ m_Info.CopyActions(actions);
+ m_ActuatorManager.UpdateActions(actions);
}
}
}
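Assuming a behavior configured with two continuous actions, an Agent written against the new API looks roughly like this (class, axis, and variable names are illustrative):

    // assumes: using UnityEngine; using Unity.MLAgents; using Unity.MLAgents.Actuators;
    public class MoveToTargetAgent : Agent
    {
        public override void OnActionReceived(ActionBuffers actions)
        {
            var moveX = actions.ContinuousActions[0];
            var moveZ = actions.ContinuousActions[1];
            // apply the movement to the agent's body here
        }

        public override void Heuristic(in ActionBuffers actionsOut)
        {
            var continuousActions = actionsOut.ContinuousActions;
            continuousActions[0] = Input.GetAxis("Horizontal");
            continuousActions[1] = Input.GetAxis("Vertical");
        }
    }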
diff --git a/com.unity.ml-agents/Runtime/Agent.deprecated.cs b/com.unity.ml-agents/Runtime/Agent.deprecated.cs
new file mode 100644
index 0000000000..8a09370957
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Agent.deprecated.cs
@@ -0,0 +1,38 @@
+using System;
+using UnityEngine;
+
+namespace Unity.MLAgents
+{
+ public partial class Agent
+ {
+ public virtual void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
+ {
+ }
+
+ ///
+ /// This method passes in a float array that is to be populated with actions.
+ ///
+ ///
+ public virtual void Heuristic(float[] actionsOut)
+ {
+ Debug.LogWarning("Heuristic method called but not implemented. Returning placeholder actions.");
+ Array.Clear(actionsOut, 0, actionsOut.Length);
+ }
+
+ public virtual void OnActionReceived(float[] vectorAction) {}
+
+ ///
+ /// Returns the last action that was decided on by the Agent.
+ ///
+ ///
+ /// The last action that was decided by the Agent (or null if no decision has been made).
+ ///
+ ///
+ // [Obsolete("GetAction has been deprecated, please use GetStoredContinuousActions or GetStoredDiscreteActions.")]
+ public float[] GetAction()
+ {
+ return m_Info.storedVectorActions;
+ }
+
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Agent.deprecated.cs.meta b/com.unity.ml-agents/Runtime/Agent.deprecated.cs.meta
new file mode 100644
index 0000000000..767483f7f7
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Agent.deprecated.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 9650a482703b47db8cd7fb2df8caa1bf
+timeCreated: 1595613441
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
index e41353fc44..deb2bb891d 100644
--- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
+++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
@@ -175,20 +175,12 @@ public static UnityRLInitParameters ToUnityRLInitParameters(this UnityRLInitiali
}
#region AgentAction
- public static AgentAction ToAgentAction(this AgentActionProto aap)
+ public static List<float[]> ToAgentActionList(this UnityRLInputProto.Types.ListAgentActionProto proto)
{
- return new AgentAction
- {
- vectorActions = aap.VectorActions.ToArray()
- };
- }
-
- public static List<AgentAction> ToAgentActionList(this UnityRLInputProto.Types.ListAgentActionProto proto)
- {
- var agentActions = new List<AgentAction>(proto.Value.Count);
+ var agentActions = new List<float[]>(proto.Value.Count);
foreach (var ap in proto.Value)
{
- agentActions.Add(ap.ToAgentAction());
+ agentActions.Add(ap.VectorActions.ToArray());
}
return agentActions;
}
diff --git a/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs b/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs
index 6c9ce6655b..afc4544242 100644
--- a/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs
+++ b/com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs
@@ -412,7 +412,7 @@ void SendBatchedMessageHelper()
var agentId = m_OrderedAgentsRequestingDecisions[brainName][i];
if (m_LastActionsReceived[brainName].ContainsKey(agentId))
{
- m_LastActionsReceived[brainName][agentId] = agentAction.vectorActions;
+ m_LastActionsReceived[brainName][agentId] = agentAction;
}
}
}
diff --git a/com.unity.ml-agents/Runtime/DecisionRequester.cs b/com.unity.ml-agents/Runtime/DecisionRequester.cs
index fabdaecc50..fc7cc55afd 100644
--- a/com.unity.ml-agents/Runtime/DecisionRequester.cs
+++ b/com.unity.ml-agents/Runtime/DecisionRequester.cs
@@ -26,7 +26,7 @@ public class DecisionRequester : MonoBehaviour
/// that the Agent will request a decision every 5 Academy steps.
[Range(1, 20)]
[Tooltip("The frequency with which the agent requests a decision. A DecisionPeriod " +
- "of 5 means that the Agent will request a decision every 5 Academy steps.")]
+ "of 5 means that the Agent will request a decision every 5 Academy steps.")]
public int DecisionPeriod = 5;
///
@@ -34,8 +34,8 @@ public class DecisionRequester : MonoBehaviour
/// it does not request a decision. Has no effect when DecisionPeriod is set to 1.
///
[Tooltip("Indicates whether or not the agent will take an action during the Academy " +
- "steps where it does not request a decision. Has no effect when DecisionPeriod " +
- "is set to 1.")]
+ "steps where it does not request a decision. Has no effect when DecisionPeriod " +
+ "is set to 1.")]
[FormerlySerializedAs("RepeatAction")]
public bool TakeActionsBetweenDecisions = true;
diff --git a/com.unity.ml-agents/Runtime/DiscreteActionMasker.cs b/com.unity.ml-agents/Runtime/DiscreteActionMasker.cs
index 1a9b322a98..c5a69a3c31 100644
--- a/com.unity.ml-agents/Runtime/DiscreteActionMasker.cs
+++ b/com.unity.ml-agents/Runtime/DiscreteActionMasker.cs
@@ -1,7 +1,6 @@
using System;
using System.Collections.Generic;
-using System.Linq;
-using Unity.MLAgents.Policies;
+using Unity.MLAgents.Actuators;
namespace Unity.MLAgents
{
@@ -15,19 +14,13 @@ namespace Unity.MLAgents
/// may be illegal. For example, if an agent is adjacent to a wall or other obstacle
/// you could mask any actions that direct the agent to move into the blocked space.
///
- public class DiscreteActionMasker
+ public class DiscreteActionMasker : IDiscreteActionMask
{
- /// When using discrete control, is the starting indices of the actions
- /// when all the branches are concatenated with each other.
- int[] m_StartingActionIndices;
+ IDiscreteActionMask m_Delegate;
- bool[] m_CurrentMask;
-
- readonly BrainParameters m_BrainParameters;
-
- internal DiscreteActionMasker(BrainParameters brainParameters)
+ internal DiscreteActionMasker(IDiscreteActionMask actionMask)
{
- m_BrainParameters = brainParameters;
+ m_Delegate = actionMask;
}
///
@@ -46,109 +39,22 @@ internal DiscreteActionMasker(BrainParameters brainParameters)
/// The indices of the masked actions.
public void SetMask(int branch, IEnumerable<int> actionIndices)
{
- // If the branch does not exist, raise an error
- if (branch >= m_BrainParameters.VectorActionSize.Length)
- throw new UnityAgentsException(
- "Invalid Action Masking : Branch " + branch + " does not exist.");
-
- var totalNumberActions = m_BrainParameters.VectorActionSize.Sum();
-
- // By default, the masks are null. If we want to specify a new mask, we initialize
- // the actionMasks with trues.
- if (m_CurrentMask == null)
- {
- m_CurrentMask = new bool[totalNumberActions];
- }
-
- // If this is the first time the masked actions are used, we generate the starting
- // indices for each branch.
- if (m_StartingActionIndices == null)
- {
- m_StartingActionIndices = Utilities.CumSum(m_BrainParameters.VectorActionSize);
- }
-
- // Perform the masking
- foreach (var actionIndex in actionIndices)
- {
- if (actionIndex >= m_BrainParameters.VectorActionSize[branch])
- {
- throw new UnityAgentsException(
- "Invalid Action Masking: Action Mask is too large for specified branch.");
- }
- m_CurrentMask[actionIndex + m_StartingActionIndices[branch]] = true;
- }
- }
-
- ///
- /// Get the current mask for an agent.
- ///
- /// A mask for the agent. A boolean array of length equal to the total number of
- /// actions.
- internal bool[] GetMask()
- {
- if (m_CurrentMask != null)
- {
- AssertMask();
- }
- return m_CurrentMask;
+ m_Delegate.WriteMask(branch, actionIndices);
}
- ///
- /// Makes sure that the current mask is usable.
- ///
- void AssertMask()
+ public void WriteMask(int branch, IEnumerable<int> actionIndices)
{
- // Action Masks can only be used in Discrete Control.
- if (m_BrainParameters.VectorActionSpaceType != SpaceType.Discrete)
- {
- throw new UnityAgentsException(
- "Invalid Action Masking : Can only set action mask for Discrete Control.");
- }
-
- var numBranches = m_BrainParameters.VectorActionSize.Length;
- for (var branchIndex = 0; branchIndex < numBranches; branchIndex++)
- {
- if (AreAllActionsMasked(branchIndex))
- {
- throw new UnityAgentsException(
- "Invalid Action Masking : All the actions of branch " + branchIndex +
- " are masked.");
- }
- }
+ m_Delegate.WriteMask(branch, actionIndices);
}
- ///
- /// Resets the current mask for an agent.
- ///
- internal void ResetMask()
+ public bool[] GetMask()
{
- if (m_CurrentMask != null)
- {
- Array.Clear(m_CurrentMask, 0, m_CurrentMask.Length);
- }
+ return m_Delegate.GetMask();
}
- ///
- /// Checks if all the actions in the input branch are masked.
- ///
- /// The index of the branch to check.
- /// True if all the actions of the branch are masked.
- bool AreAllActionsMasked(int branch)
+ public void ResetMask()
{
- if (m_CurrentMask == null)
- {
- return false;
- }
- var start = m_StartingActionIndices[branch];
- var end = m_StartingActionIndices[branch + 1];
- for (var i = start; i < end; i++)
- {
- if (!m_CurrentMask[i])
- {
- return false;
- }
- }
- return true;
+ m_Delegate.ResetMask();
}
}
}
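With the masker now delegating to an IDiscreteActionMask, existing Agent subclasses can keep overriding the legacy callback; SetMask simply forwards to WriteMask on the underlying mask. A small sketch inside an Agent subclass (branch and index values are illustrative):

    public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
    {
        // Forbid action 2 of branch 0 for the next decision.
        actionMasker.SetMask(0, new[] { 2 });
    }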
diff --git a/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs b/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
index 6c6bcec4c1..b583e4aa39 100644
--- a/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
+++ b/com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
@@ -1,5 +1,7 @@
+using System;
using Unity.Barracuda;
using System.Collections.Generic;
+using Unity.MLAgents.Actuators;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Sensors;
@@ -29,6 +31,7 @@ public enum InferenceDevice
internal class BarracudaPolicy : IPolicy
{
protected ModelRunner m_ModelRunner;
+ ActionBuffers m_LastActionBuffer;
int m_AgentId;
@@ -36,6 +39,7 @@ internal class BarracudaPolicy : IPolicy
/// Sensor shapes for the associated Agents. All Agents must have the same shapes for their Sensors.
///
List<int[]> m_SensorShapes;
+ SpaceType m_SpaceType;
///
public BarracudaPolicy(
@@ -45,6 +49,7 @@ public BarracudaPolicy(
{
var modelRunner = Academy.Instance.GetOrCreateModelRunner(model, brainParameters, inferenceDevice);
m_ModelRunner = modelRunner;
+ m_SpaceType = brainParameters.VectorActionSpaceType;
}
///
@@ -55,10 +60,18 @@ public void RequestDecision(AgentInfo info, List<ISensor> sensors)
}
///
- public float[] DecideAction()
+ public ref readonly ActionBuffers DecideAction()
{
m_ModelRunner?.DecideBatch();
- return m_ModelRunner?.GetAction(m_AgentId);
+ var actions = m_ModelRunner?.GetAction(m_AgentId);
+ if (m_SpaceType == SpaceType.Continuous)
+ {
+ m_LastActionBuffer = new ActionBuffers(actions, Array.Empty<int>());
+ return ref m_LastActionBuffer;
+ }
+
+ m_LastActionBuffer = ActionBuffers.FromDiscreteActions(actions);
+ return ref m_LastActionBuffer;
}
public void Dispose()
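Both inference policies now wrap the model's flat float[] output in an ActionBuffers before it reaches the actuator pipeline. A small sketch of what FromDiscreteActions produces (values are illustrative):

    float[] modelOutput = { 2f, 0f, 1f };                          // discrete branch choices as floats
    var buffers = ActionBuffers.FromDiscreteActions(modelOutput);
    // buffers.ContinuousActions.Length == 0
    // buffers.DiscreteActions holds { 2, 0, 1 } after the float-to-int cast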
diff --git a/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
index b8d38f49d4..72b0038a9a 100644
--- a/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
+++ b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
@@ -150,6 +150,11 @@ public string BehaviorName
[Tooltip("Use all Sensor components attached to child GameObjects of this Agent.")]
bool m_UseChildSensors = true;
+ [HideInInspector]
+ [SerializeField]
+ [Tooltip("Use all Actuator components attached to child GameObjects of this Agent.")]
+ bool m_UseChildActuators = true;
+
///
/// Whether or not to use all the sensor components attached to child GameObjects of the agent.
/// Note that changing this after the Agent has been initialized will not have any effect.
@@ -160,6 +165,16 @@ public bool UseChildSensors
set { m_UseChildSensors = value; }
}
+ ///
+ /// Whether or not to use all the actuator components attached to child GameObjects of the agent.
+ /// Note that changing this after the Agent has been initialized will not have any effect.
+ ///
+ public bool UseChildActuators
+ {
+ get { return m_UseChildActuators; }
+ set { m_UseChildActuators = value; }
+ }
+
[HideInInspector, SerializeField]
ObservableAttributeOptions m_ObservableAttributeHandling = ObservableAttributeOptions.Ignore;
@@ -185,7 +200,7 @@ internal IPolicy GeneratePolicy(HeuristicPolicy.ActionGenerator heuristic)
switch (m_BehaviorType)
{
case BehaviorType.HeuristicOnly:
- return new HeuristicPolicy(heuristic, m_BrainParameters.NumActions);
+ return GenerateHeuristicPolicy(heuristic);
case BehaviorType.InferenceOnly:
{
if (m_Model == null)
@@ -209,13 +224,29 @@ internal IPolicy GeneratePolicy(HeuristicPolicy.ActionGenerator heuristic)
}
else
{
- return new HeuristicPolicy(heuristic, m_BrainParameters.NumActions);
+ return GenerateHeuristicPolicy(heuristic);
}
default:
- return new HeuristicPolicy(heuristic, m_BrainParameters.NumActions);
+ return GenerateHeuristicPolicy(heuristic);
}
}
+ internal IPolicy GenerateHeuristicPolicy(HeuristicPolicy.ActionGenerator heuristic)
+ {
+ var numContinuousActions = 0;
+ var numDiscreteActions = 0;
+ if (m_BrainParameters.VectorActionSpaceType == SpaceType.Continuous)
+ {
+ numContinuousActions = m_BrainParameters.NumActions;
+ }
+ else if (m_BrainParameters.VectorActionSpaceType == SpaceType.Discrete)
+ {
+ numDiscreteActions = m_BrainParameters.NumActions;
+ }
+
+ return new HeuristicPolicy(heuristic, numContinuousActions, numDiscreteActions);
+ }
+
internal void UpdateAgentPolicy()
{
var agent = GetComponent<Agent>();
diff --git a/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
index 12d304e17f..232f459eb0 100644
--- a/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
+++ b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
@@ -1,6 +1,7 @@
using System.Collections.Generic;
using System;
using System.Collections;
+using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Policies
@@ -12,9 +13,9 @@ namespace Unity.MLAgents.Policies
///
internal class HeuristicPolicy : IPolicy
{
- public delegate void ActionGenerator(float[] actionsOut);
+ public delegate void ActionGenerator(in ActionBuffers actionBuffers);
ActionGenerator m_Heuristic;
- float[] m_LastDecision;
+ ActionBuffers m_ActionBuffers;
bool m_Done;
bool m_DecisionRequested;
@@ -23,10 +24,12 @@ internal class HeuristicPolicy : IPolicy
///
- public HeuristicPolicy(ActionGenerator heuristic, int numActions)
+ public HeuristicPolicy(ActionGenerator heuristic, int numContinuousActions, int numDiscreteActions)
{
m_Heuristic = heuristic;
- m_LastDecision = new float[numActions];
+ var continuousDecision = new ActionSegment<float>(new float[numContinuousActions], 0, numContinuousActions);
+ var discreteDecision = new ActionSegment<int>(new int[numDiscreteActions], 0, numDiscreteActions);
+ m_ActionBuffers = new ActionBuffers(continuousDecision, discreteDecision);
}
///
@@ -35,18 +38,17 @@ public void RequestDecision(AgentInfo info, List<ISensor> sensors)
StepSensors(sensors);
m_Done = info.done;
m_DecisionRequested = true;
-
}
///
- public float[] DecideAction()
+ public ref readonly ActionBuffers DecideAction()
{
if (!m_Done && m_DecisionRequested)
{
- m_Heuristic.Invoke(m_LastDecision);
+ m_Heuristic.Invoke(m_ActionBuffers);
}
m_DecisionRequested = false;
- return m_LastDecision;
+ return ref m_ActionBuffers;
}
public void Dispose()
@@ -110,7 +112,7 @@ public void RemoveAt(int index)
public float this[int index]
{
get { return 0.0f; }
- set { }
+ set {}
}
}
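Under the new delegate signature, a heuristic callback writes into the reused ActionBuffers instead of returning a float[]. A minimal sketch, usable from inside the package since HeuristicPolicy is internal (the method name and constant action value are illustrative):

    static void ForwardHeuristic(in ActionBuffers actionsOut)
    {
        var continuousActions = actionsOut.ContinuousActions;
        continuousActions[0] = 1.0f;    // always drive forward
    }

    var policy = new HeuristicPolicy(ForwardHeuristic, numContinuousActions: 1, numDiscreteActions: 0);
    // policy.DecideAction() now returns a ref readonly ActionBuffers filled by ForwardHeuristic.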
diff --git a/com.unity.ml-agents/Runtime/Policies/IPolicy.cs b/com.unity.ml-agents/Runtime/Policies/IPolicy.cs
index 7203853589..4079a1f25a 100644
--- a/com.unity.ml-agents/Runtime/Policies/IPolicy.cs
+++ b/com.unity.ml-agents/Runtime/Policies/IPolicy.cs
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
+using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Policies
@@ -26,6 +27,6 @@ internal interface IPolicy : IDisposable
/// it must be taken now. The Brain is expected to update the actions
/// of the Agents at this point the latest.
///
- float[] DecideAction();
+ ref readonly ActionBuffers DecideAction();
}
}
diff --git a/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs b/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
index 2f88d37f53..f8c7c76eaa 100644
--- a/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
+++ b/com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
@@ -1,6 +1,7 @@
using UnityEngine;
using System.Collections.Generic;
using System;
+using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Policies
@@ -13,6 +14,8 @@ internal class RemotePolicy : IPolicy
{
int m_AgentId;
string m_FullyQualifiedBehaviorName;
+ SpaceType m_SpaceType;
+ ActionBuffers m_LastActionBuffer;
internal ICommunicator m_Communicator;
@@ -23,6 +26,7 @@ public RemotePolicy(
{
m_FullyQualifiedBehaviorName = fullyQualifiedBehaviorName;
m_Communicator = Academy.Instance.Communicator;
+ m_SpaceType = brainParameters.VectorActionSpaceType;
m_Communicator.SubscribeBrain(m_FullyQualifiedBehaviorName, brainParameters);
}
@@ -34,10 +38,17 @@ public void RequestDecision(AgentInfo info, List<ISensor> sensors)
}
///
- public float[] DecideAction()
+ public ref readonly ActionBuffers DecideAction()
{
m_Communicator?.DecideBatch();
- return m_Communicator?.GetActions(m_FullyQualifiedBehaviorName, m_AgentId);
+ var actions = m_Communicator?.GetActions(m_FullyQualifiedBehaviorName, m_AgentId);
+ if (m_SpaceType == SpaceType.Continuous)
+ {
+ m_LastActionBuffer = new ActionBuffers(actions, Array.Empty<int>());
+ return ref m_LastActionBuffer;
+ }
+ m_LastActionBuffer = ActionBuffers.FromDiscreteActions(actions);
+ return ref m_LastActionBuffer;
}
public void Dispose()
diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs
index 974c1fd35d..c115ae9879 100644
--- a/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs
+++ b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs
@@ -29,14 +29,14 @@ public void TestEnsureBufferSizeContinuous()
actuator1ActionSpaceDef.SumOfDiscreteBranchSizes + actuator2ActionSpaceDef.SumOfDiscreteBranchSizes,
actuator1ActionSpaceDef.NumDiscreteActions + actuator2ActionSpaceDef.NumDiscreteActions);
- manager.UpdateActions(new[]
- { 0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f }, Array.Empty<int>());
+ manager.UpdateActions(new ActionBuffers(new[]
+ { 0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f }, Array.Empty<int>()));
Assert.IsTrue(12 == manager.NumContinuousActions);
Assert.IsTrue(0 == manager.NumDiscreteActions);
Assert.IsTrue(0 == manager.SumOfDiscreteBranchSizes);
- Assert.IsTrue(12 == manager.StoredContinuousActions.Length);
- Assert.IsTrue(0 == manager.StoredDiscreteActions.Length);
+ Assert.IsTrue(12 == manager.StoredActions.ContinuousActions.Length);
+ Assert.IsTrue(0 == manager.StoredActions.DiscreteActions.Length);
}
[Test]
@@ -54,14 +54,14 @@ public void TestEnsureBufferDiscrete()
actuator1ActionSpaceDef.SumOfDiscreteBranchSizes + actuator2ActionSpaceDef.SumOfDiscreteBranchSizes,
actuator1ActionSpaceDef.NumDiscreteActions + actuator2ActionSpaceDef.NumDiscreteActions);
- manager.UpdateActions(Array.Empty<float>(),
- new[] { 0, 1, 2, 3, 4, 5, 6});
+ manager.UpdateActions(new ActionBuffers(Array.Empty<float>(),
+ new[] { 0, 1, 2, 3, 4, 5, 6}));
Assert.IsTrue(0 == manager.NumContinuousActions);
Assert.IsTrue(7 == manager.NumDiscreteActions);
Assert.IsTrue(13 == manager.SumOfDiscreteBranchSizes);
- Assert.IsTrue(0 == manager.StoredContinuousActions.Length);
- Assert.IsTrue(7 == manager.StoredDiscreteActions.Length);
+ Assert.IsTrue(0 == manager.StoredActions.ContinuousActions.Length);
+ Assert.IsTrue(7 == manager.StoredActions.DiscreteActions.Length);
}
[Test]
@@ -98,8 +98,8 @@ public void TestExecuteActionsDiscrete()
manager.Add(actuator2);
var discreteActionBuffer = new[] { 0, 1, 2, 3, 4, 5, 6};
- manager.UpdateActions(Array.Empty<float>(),
- discreteActionBuffer);
+ manager.UpdateActions(new ActionBuffers(Array.Empty<float>(),
+ discreteActionBuffer));
manager.ExecuteActions();
var actuator1Actions = actuator1.LastActionBuffer.DiscreteActions;
@@ -118,8 +118,8 @@ public void TestExecuteActionsContinuous()
manager.Add(actuator2);
var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f};
- manager.UpdateActions(continuousActionBuffer,
- Array.Empty<int>());
+ manager.UpdateActions(new ActionBuffers(continuousActionBuffer,
+ Array.Empty<int>()));
manager.ExecuteActions();
var actuator1Actions = actuator1.LastActionBuffer.ContinuousActions;
@@ -149,10 +149,10 @@ public void TestUpdateActionsContinuous()
manager.Add(actuator1);
manager.Add(actuator2);
var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f};
- manager.UpdateActions(continuousActionBuffer,
- Array.Empty<int>());
+ manager.UpdateActions(new ActionBuffers(continuousActionBuffer,
+ Array.Empty<int>()));
- Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(continuousActionBuffer));
+ Assert.IsTrue(manager.StoredActions.ContinuousActions.SequenceEqual(continuousActionBuffer));
}
[Test]
@@ -165,12 +165,12 @@ public void TestUpdateActionsDiscrete()
manager.Add(actuator1);
manager.Add(actuator2);
var discreteActionBuffer = new[] { 0, 1, 2, 3, 4, 5};
- manager.UpdateActions(Array.Empty<float>(),
- discreteActionBuffer);
+ manager.UpdateActions(new ActionBuffers(Array.Empty<float>(),
+ discreteActionBuffer));
- Debug.Log(manager.StoredDiscreteActions);
+ Debug.Log(manager.StoredActions.DiscreteActions);
Debug.Log(discreteActionBuffer);
- Assert.IsTrue(manager.StoredDiscreteActions.SequenceEqual(discreteActionBuffer));
+ Assert.IsTrue(manager.StoredActions.DiscreteActions.SequenceEqual(discreteActionBuffer));
}
[Test]
@@ -261,14 +261,14 @@ public void TestResetData()
manager.Add(actuator1);
manager.Add(actuator2);
var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f};
- manager.UpdateActions(continuousActionBuffer,
- Array.Empty<int>());
+ manager.UpdateActions(new ActionBuffers(continuousActionBuffer,
+ Array.Empty<int>()));
- Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(continuousActionBuffer));
+ Assert.IsTrue(manager.StoredActions.ContinuousActions.SequenceEqual(continuousActionBuffer));
Assert.IsTrue(manager.NumContinuousActions == 6);
manager.ResetData();
- Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(new[] { 0f, 0f, 0f, 0f, 0f, 0f}));
+ Assert.IsTrue(manager.StoredActions.ContinuousActions.SequenceEqual(new[] { 0f, 0f, 0f, 0f, 0f, 0f}));
}
[Test]
diff --git a/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs b/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs
index 7364e3173a..aa0a87ed32 100644
--- a/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs
+++ b/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs
@@ -1,4 +1,5 @@
using NUnit.Framework;
+using Unity.MLAgents.Actuators;
using UnityEngine;
using Unity.MLAgents.Policies;
@@ -7,7 +8,7 @@ namespace Unity.MLAgents.Tests
[TestFixture]
public class BehaviorParameterTests
{
- static void DummyHeuristic(float[] actionsOut)
+ static void DummyHeuristic(in ActionBuffers actionsOut)
{
// No-op
}
diff --git a/com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs b/com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs
deleted file mode 100644
index 0f5c24edbd..0000000000
--- a/com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs
+++ /dev/null
@@ -1,142 +0,0 @@
-using NUnit.Framework;
-using Unity.MLAgents.Policies;
-
-namespace Unity.MLAgents.Tests
-{
- public class EditModeTestActionMasker
- {
- [Test]
- public void Contruction()
- {
- var bp = new BrainParameters();
- var masker = new DiscreteActionMasker(bp);
- Assert.IsNotNull(masker);
- }
-
- [Test]
- public void FailsWithContinuous()
- {
- var bp = new BrainParameters();
- bp.VectorActionSpaceType = SpaceType.Continuous;
- bp.VectorActionSize = new[] {4};
- var masker = new DiscreteActionMasker(bp);
- masker.SetMask(0, new[] {0});
- Assert.Catch(() => masker.GetMask());
- }
-
- [Test]
- public void NullMask()
- {
- var bp = new BrainParameters();
- bp.VectorActionSpaceType = SpaceType.Discrete;
- var masker = new DiscreteActionMasker(bp);
- var mask = masker.GetMask();
- Assert.IsNull(mask);
- }
-
- [Test]
- public void FirstBranchMask()
- {
- var bp = new BrainParameters();
- bp.VectorActionSpaceType = SpaceType.Discrete;
- bp.VectorActionSize = new[] {4, 5, 6};
- var masker = new DiscreteActionMasker(bp);
- var mask = masker.GetMask();
- Assert.IsNull(mask);
- masker.SetMask(0, new[] {1, 2, 3});
- mask = masker.GetMask();
- Assert.IsFalse(mask[0]);
- Assert.IsTrue(mask[1]);
- Assert.IsTrue(mask[2]);
- Assert.IsTrue(mask[3]);
- Assert.IsFalse(mask[4]);
- Assert.AreEqual(mask.Length, 15);
- }
-
- [Test]
- public void SecondBranchMask()
- {
- var bp = new BrainParameters
- {
- VectorActionSpaceType = SpaceType.Discrete,
- VectorActionSize = new[] { 4, 5, 6 }
- };
- var masker = new DiscreteActionMasker(bp);
- masker.SetMask(1, new[] {1, 2, 3});
- var mask = masker.GetMask();
- Assert.IsFalse(mask[0]);
- Assert.IsFalse(mask[4]);
- Assert.IsTrue(mask[5]);
- Assert.IsTrue(mask[6]);
- Assert.IsTrue(mask[7]);
- Assert.IsFalse(mask[8]);
- Assert.IsFalse(mask[9]);
- }
-
- [Test]
- public void MaskReset()
- {
- var bp = new BrainParameters
- {
- VectorActionSpaceType = SpaceType.Discrete,
- VectorActionSize = new[] { 4, 5, 6 }
- };
- var masker = new DiscreteActionMasker(bp);
- masker.SetMask(1, new[] {1, 2, 3});
- masker.ResetMask();
- var mask = masker.GetMask();
- for (var i = 0; i < 15; i++)
- {
- Assert.IsFalse(mask[i]);
- }
- }
-
- [Test]
- public void ThrowsError()
- {
- var bp = new BrainParameters
- {
- VectorActionSpaceType = SpaceType.Discrete,
- VectorActionSize = new[] { 4, 5, 6 }
- };
- var masker = new DiscreteActionMasker(bp);
-
- Assert.Catch(
- () => masker.SetMask(0, new[] {5}));
- Assert.Catch(
- () => masker.SetMask(1, new[] {5}));
- masker.SetMask(2, new[] {5});
- Assert.Catch(
- () => masker.SetMask(3, new[] {1}));
- masker.GetMask();
- masker.ResetMask();
- masker.SetMask(0, new[] {0, 1, 2, 3});
- Assert.Catch(
- () => masker.GetMask());
- }
-
- [Test]
- public void MultipleMaskEdit()
- {
- var bp = new BrainParameters();
- bp.VectorActionSpaceType = SpaceType.Discrete;
- bp.VectorActionSize = new[] {4, 5, 6};
- var masker = new DiscreteActionMasker(bp);
- masker.SetMask(0, new[] {0, 1});
- masker.SetMask(0, new[] {3});
- masker.SetMask(2, new[] {1});
- var mask = masker.GetMask();
- for (var i = 0; i < 15; i++)
- {
- if ((i == 0) || (i == 1) || (i == 3) || (i == 10))
- {
- Assert.IsTrue(mask[i]);
- }
- else
- {
- Assert.IsFalse(mask[i]);
- }
- }
- }
- }
-}
diff --git a/com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs.meta b/com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs.meta
deleted file mode 100644
index 1325856f78..0000000000
--- a/com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs.meta
+++ /dev/null
@@ -1,11 +0,0 @@
-fileFormatVersion: 2
-guid: 2e2810ee6c8c64fb39abdf04b5d17f50
-MonoImporter:
- externalObjects: {}
- serializedVersion: 2
- defaultReferences: []
- executionOrder: 0
- icon: {instanceID: 0}
- userData:
- assetBundleName:
- assetBundleVariant:
diff --git a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
index 28a7bc3368..b51b8fc2e6 100644
--- a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
+++ b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
@@ -3,6 +3,7 @@
using NUnit.Framework;
using System.Reflection;
using System.Collections.Generic;
+using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
using Unity.MLAgents.Policies;
@@ -14,6 +15,7 @@ internal class TestPolicy : IPolicy
{
public Action OnRequestDecision;
ObservationWriter m_ObsWriter = new ObservationWriter();
+ static ActionBuffers s_EmptyActionBuffers = new ActionBuffers(Array.Empty<float>(), Array.Empty<int>());
public void RequestDecision(AgentInfo info, List<ISensor> sensors)
{
foreach (var sensor in sensors)
@@ -23,7 +25,7 @@ public void RequestDecision(AgentInfo info, List<ISensor> sensors)
OnRequestDecision?.Invoke();
}
- public float[] DecideAction() { return new float[0]; }
+ public ref readonly ActionBuffers DecideAction() { return ref s_EmptyActionBuffers; }
public void Dispose() {}
}