Skip to content

Integrate IActuators into ML-Agents core code. #4315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace Unity.MLAgents.Actuators
/// the offset into the original array, and a length.
/// </summary>
/// <typeparam name="T">The type of object stored in the underlying <see cref="Array"/></typeparam>
internal readonly struct ActionSegment<T> : IEnumerable<T>, IEquatable<ActionSegment<T>>
public readonly struct ActionSegment<T> : IEnumerable<T>, IEquatable<ActionSegment<T>>
where T : struct
{
/// <summary>
Expand All @@ -30,18 +30,25 @@ namespace Unity.MLAgents.Actuators
/// </summary>
public static ActionSegment<T> Empty = new ActionSegment<T>(System.Array.Empty<T>(), 0, 0);

static void CheckParameters(T[] actionArray, int offset, int length)
static void CheckParameters(IReadOnlyCollection<T> actionArray, int offset, int length)
{
#if DEBUG
if (offset + length > actionArray.Length)
if (offset + length > actionArray.Count)
{
throw new ArgumentOutOfRangeException(nameof(offset),
$"Arguments offset: {offset} and length: {length} " +
$"are out of bounds of actionArray: {actionArray.Length}.");
$"are out of bounds of actionArray: {actionArray.Count}.");
}
#endif
}

/// <summary>
/// Construct an <see cref="ActionSegment{T}"/> with just an actionArray. The <see cref="Offset"/> will
/// be set to 0 and the <see cref="Length"/> will be set to `actionArray.Length`.
/// </summary>
/// <param name="actionArray">The action array to use for this segment. Must not be null, because its
/// Length is read in order to size the segment.</param>
public ActionSegment(T[] actionArray) : this(actionArray, 0, actionArray.Length) { }

/// <summary>
/// Construct an <see cref="ActionSegment{T}"/> with an underlying array
/// and offset, and a length.
Expand All @@ -51,7 +58,9 @@ static void CheckParameters(T[] actionArray, int offset, int length)
/// <param name="length">The length of the segment.</param>
public ActionSegment(T[] actionArray, int offset, int length)
{
#if DEBUG
CheckParameters(actionArray, offset, length);
#endif
Array = actionArray;
Offset = offset;
Length = length;
Expand All @@ -78,6 +87,22 @@ public T this[int index]
}
return Array[Offset + index];
}
set
{
    // Valid indices are 0 .. Length - 1. Rejecting index == Length keeps the write inside this
    // segment; otherwise it would land on Array[Offset + Length], which is either past the end of
    // the backing array or inside whatever data follows this segment in the backing array.
    if (index < 0 || index >= Length)
    {
        throw new IndexOutOfRangeException($"Index out of bounds, expected a number between 0 and {Length}");
    }
    Array[Offset + index] = value;
}
}

/// <summary>
/// Resets the part of the backing array covered by this segment, i.e. the <see cref="Length"/>
/// elements starting at <see cref="Offset"/>, to the default value of <typeparamref name="T"/>
/// (zero for numeric types).
/// </summary>
public void Clear() => System.Array.Clear(Array, Offset, Length);

/// <inheritdoc cref="IEnumerable{T}.GetEnumerator"/>
Expand Down
2 changes: 1 addition & 1 deletion com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Unity.MLAgents.Actuators
/// <summary>
/// Defines the structure of an Action Space to be used by the Actuator system.
/// </summary>
internal readonly struct ActionSpec
public readonly struct ActionSpec
{

/// <summary>
Expand Down
51 changes: 32 additions & 19 deletions com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,14 @@ internal class ActuatorManager : IList<IActuator>
/// <summary>
/// Returns the previously stored actions for the actuators in this list.
/// </summary>
public float[] StoredContinuousActions { get; private set; }
// public float[] StoredContinuousActions { get; private set; }

/// <summary>
/// Returns the previously stored actions for the actuators in this list.
/// </summary>
public int[] StoredDiscreteActions { get; private set; }
// public int[] StoredDiscreteActions { get; private set; }

public ActionBuffers StoredActions { get; private set; }

/// <summary>
/// Create an ActuatorList with a preset capacity.
Expand Down Expand Up @@ -99,8 +101,11 @@ internal void ReadyActuatorsForExecution(IList<IActuator> actuators, int numCont

// Sort the Actuators by name to ensure determinism
SortActuators();
StoredContinuousActions = numContinuousActions == 0 ? Array.Empty<float>() : new float[numContinuousActions];
StoredDiscreteActions = numDiscreteBranches == 0 ? Array.Empty<int>() : new int[numDiscreteBranches];
var continuousActions = numContinuousActions == 0 ? ActionSegment<float>.Empty :
new ActionSegment<float>(new float[numContinuousActions]);
var discreteActions = numDiscreteBranches == 0 ? ActionSegment<int>.Empty : new ActionSegment<int>(new int[numDiscreteBranches]);

StoredActions = new ActionBuffers(continuousActions, discreteActions);
m_DiscreteActionMask = new ActuatorDiscreteActionMask(actuators, sumOfDiscreteBranches, numDiscreteBranches);
m_ReadyForExecution = true;
}
Expand All @@ -113,26 +118,31 @@ internal void ReadyActuatorsForExecution(IList<IActuator> actuators, int numCont
/// continuous actions for the IActuators in this list.</param>
/// <param name="discreteActionBuffer">The action buffer which contains all of the
/// discrete actions for the IActuators in this list.</param>
public void UpdateActions(float[] continuousActionBuffer, int[] discreteActionBuffer)
public void UpdateActions(ActionBuffers actions)
{
ReadyActuatorsForExecution();
UpdateActionArray(continuousActionBuffer, StoredContinuousActions);
UpdateActionArray(discreteActionBuffer, StoredDiscreteActions);
UpdateActionArray(actions.ContinuousActions, StoredActions.ContinuousActions);
UpdateActionArray(actions.DiscreteActions, StoredActions.DiscreteActions);
}

static void UpdateActionArray<T>(T[] sourceActionBuffer, T[] destination)
static void UpdateActionArray<T>(ActionSegment<T> sourceActionBuffer, ActionSegment<T> destination)
where T : struct
{
if (sourceActionBuffer == null || sourceActionBuffer.Length == 0)
if (sourceActionBuffer.Length <= 0)
{
Array.Clear(destination, 0, destination.Length);
destination.Clear();
}
else
{
Debug.Assert(sourceActionBuffer.Length == destination.Length,
$"sourceActionBuffer:{sourceActionBuffer.Length} is a different" +
$" size than destination: {destination.Length}.");

Array.Copy(sourceActionBuffer, destination, destination.Length);
Array.Copy(sourceActionBuffer.Array,
sourceActionBuffer.Offset,
destination.Array,
destination.Offset,
destination.Length);
}
}

Expand All @@ -148,9 +158,12 @@ public void WriteActionMask()
for (var i = 0; i < m_Actuators.Count; i++)
{
var actuator = m_Actuators[i];
m_DiscreteActionMask.CurrentBranchOffset = offset;
actuator.WriteDiscreteActionMask(m_DiscreteActionMask);
offset += actuator.ActionSpec.NumDiscreteActions;
if (actuator.ActionSpec.NumDiscreteActions > 0)
{
m_DiscreteActionMask.CurrentBranchOffset = offset;
actuator.WriteDiscreteActionMask(m_DiscreteActionMask);
offset += actuator.ActionSpec.NumDiscreteActions;
}
}
}

Expand All @@ -173,15 +186,15 @@ public void ExecuteActions()
var continuousActions = ActionSegment<float>.Empty;
if (numContinuousActions > 0)
{
continuousActions = new ActionSegment<float>(StoredContinuousActions,
continuousActions = new ActionSegment<float>(StoredActions.ContinuousActions.Array,
continuousStart,
numContinuousActions);
}

var discreteActions = ActionSegment<int>.Empty;
if (numDiscreteActions > 0)
{
discreteActions = new ActionSegment<int>(StoredDiscreteActions,
discreteActions = new ActionSegment<int>(StoredActions.DiscreteActions.Array,
discreteStart,
numDiscreteActions);
}
Expand All @@ -193,7 +206,7 @@ public void ExecuteActions()
}

/// <summary>
/// Resets the <see cref="StoredContinuousActions"/> and <see cref="StoredDiscreteActions"/> buffers to be all
/// Resets the <see cref="ActionBuffers"/> to be all
/// zeros and calls <see cref="IActuator.ResetData"/> on each <see cref="IActuator"/> managed by this object.
/// </summary>
public void ResetData()
Expand All @@ -202,12 +215,12 @@ public void ResetData()
{
return;
}
Array.Clear(StoredContinuousActions, 0, StoredContinuousActions.Length);
Array.Clear(StoredDiscreteActions, 0, StoredDiscreteActions.Length);
StoredActions.Clear();
for (var i = 0; i < m_Actuators.Count; i++)
{
m_Actuators[i].ResetData();
}
m_DiscreteActionMask.ResetMask();
}


Expand Down
72 changes: 70 additions & 2 deletions com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
using System;
using System.Linq;
using UnityEngine;

namespace Unity.MLAgents.Actuators
{
/// <summary>
/// A structure that wraps the <see cref="ActionSegment{T}"/>s for a particular <see cref="IActionReceiver"/> and is
/// used when <see cref="IActionReceiver.OnActionReceived"/> is called.
/// </summary>
internal readonly struct ActionBuffers
public readonly struct ActionBuffers
{
/// <summary>
/// An empty action buffer.
Expand All @@ -24,6 +25,24 @@ internal readonly struct ActionBuffers
/// </summary>
public ActionSegment<int> DiscreteActions { get; }

/// <summary>
/// Builds an <see cref="ActionBuffers"/> whose discrete actions are supplied as a float array. This is a
/// backward-compatibility helper for the former Agent methods, which passed both continuous and discrete
/// actions around as floats.
/// </summary>
/// <param name="discreteActions">The discrete actions encoded as floats. May be null, in which case the
/// resulting discrete segment is empty.</param>
/// <returns>An <see cref="ActionBuffers"/> with an empty continuous segment and a
/// <see cref="DiscreteActions"/> segment holding each input value cast to int.</returns>
public static ActionBuffers FromDiscreteActions(float[] discreteActions)
{
    var discreteSegment = ActionSegment<int>.Empty;
    if (discreteActions != null)
    {
        // Each float is converted with a plain (int) cast.
        var asInts = Array.ConvertAll(discreteActions, value => (int)value);
        discreteSegment = new ActionSegment<int>(asInts);
    }

    return new ActionBuffers(ActionSegment<float>.Empty, discreteSegment);
}

/// <summary>
/// Construct an <see cref="ActionBuffers"/> instance from raw arrays. Each array is wrapped in a new
/// <see cref="ActionSegment{T}"/> and the two segments are forwarded to the segment-based constructor.
/// </summary>
/// <param name="continuousActions">The float array to wrap as the continuous actions segment.</param>
/// <param name="discreteActions">The int array to wrap as the discrete actions segment.</param>
public ActionBuffers(float[] continuousActions, int[] discreteActions)
: this(new ActionSegment<float>(continuousActions), new ActionSegment<int>(discreteActions)) { }

/// <summary>
/// Construct an <see cref="ActionBuffers"/> instance with the continuous and discrete actions that will
/// be used.
Expand All @@ -36,6 +55,15 @@ public ActionBuffers(ActionSegment<float> continuousActions, ActionSegment<int>
DiscreteActions = discreteActions;
}

/// <summary>
/// Clears both action segments held by this instance: first <see cref="ContinuousActions"/>, then
/// <see cref="DiscreteActions"/>.
/// </summary>
public void Clear()
{
    var continuous = ContinuousActions;
    var discrete = DiscreteActions;

    continuous.Clear();
    discrete.Clear();
}

/// <inheritdoc cref="ValueType.Equals(object)"/>
public override bool Equals(object obj)
{
Expand All @@ -57,12 +85,52 @@ public override int GetHashCode()
return (ContinuousActions.GetHashCode() * 397) ^ DiscreteActions.GetHashCode();
}
}

/// <summary>
/// Copies the continuous actions followed by the discrete actions into a single float array. The
/// continuous values are written starting at index 0 and the discrete values are written directly
/// after them.
/// </summary>
/// <param name="destination">The float array to pack the actions into. Its length must be greater than
/// or equal to the sum of the Lengths of this object's <see cref="ContinuousActions"/> and
/// <see cref="DiscreteActions"/> segments.</param>
public void PackActions(in float[] destination)
{
    var continuous = ContinuousActions;
    var discrete = DiscreteActions;
    var continuousCount = continuous.Length;
    var discreteCount = discrete.Length;

    Debug.Assert(destination.Length >= continuousCount + discreteCount,
        $"argument '{nameof(destination)}' is not large enough to pack the actions into.\n" +
        $"{nameof(destination)}.Length: {destination.Length}\n" +
        $"{nameof(ContinuousActions)}.Length + {nameof(DiscreteActions)}.Length: {continuousCount + discreteCount}");

    // Index in destination at which the next value will be written.
    var writeIndex = 0;
    if (continuousCount > 0)
    {
        Array.Copy(continuous.Array, continuous.Offset, destination, writeIndex, continuousCount);
        writeIndex += continuousCount;
    }

    // Discrete actions are only copied while there is still room left in destination.
    if (writeIndex < destination.Length && discreteCount > 0)
    {
        Array.Copy(discrete.Array, discrete.Offset, destination, writeIndex, discreteCount);
    }
}
}

/// <summary>
/// An interface that describes an object that can receive actions from a Reinforcement Learning network.
/// </summary>
internal interface IActionReceiver
public interface IActionReceiver
{

/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion com.unity.ml-agents/Runtime/Actuators/IActuator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace Unity.MLAgents.Actuators
/// <summary>
/// Abstraction that facilitates the execution of actions.
/// </summary>
internal interface IActuator : IActionReceiver
public interface IActuator : IActionReceiver
{
int TotalNumberOfActions { get; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace Unity.MLAgents.Actuators
/// <summary>
/// Interface for writing a mask to disable discrete actions for agents for the next decision.
/// </summary>
internal interface IDiscreteActionMask
public interface IDiscreteActionMask
{
/// <summary>
/// Modifies an action mask for discrete control agents.
Expand Down
2 changes: 1 addition & 1 deletion com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace Unity.MLAgents.Actuators
{
internal class VectorActuator : IActuator
public class VectorActuator : IActuator
{
IActionReceiver m_ActionReceiver;

Expand Down
Loading