From f8a54c38b7af5a829449d4de5d83ef15c04094ba Mon Sep 17 00:00:00 2001 From: vincentpierre Date: Fri, 7 Feb 2020 10:48:27 -0800 Subject: [PATCH 1/4] Sentencing Action masking the same as observations I am rather unsure about the doubling of the CollectObservation methods (and the copy pasta that comes along) Need to edit the documentation and the migrating doc once we agree we want to do this --- .../Examples/GridWorld/Scripts/GridAgent.cs | 14 +-- com.unity.ml-agents/Runtime/ActionMasker.cs | 39 ++++++++- com.unity.ml-agents/Runtime/Agent.cs | 86 +++++++++---------- 3 files changed, 85 insertions(+), 54 deletions(-) diff --git a/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs b/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs index 7fe523e3b9..aa3b712edd 100644 --- a/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs +++ b/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs @@ -31,7 +31,7 @@ public override void InitializeAgent() { } - public override void CollectObservations(VectorSensor sensor) + public override void CollectObservations(VectorSensor sensor, ActionMasker actionMasker) { // There are no numeric observations to collect as this environment uses visual // observations. @@ -39,14 +39,14 @@ public override void CollectObservations(VectorSensor sensor) // Mask the necessary actions if selected by the user. if (maskActions) { - SetMask(); + SetMask(actionMasker); } } /// /// Applies the mask for the agents action to disallow unnecessary actions. 
/// - void SetMask() + void SetMask(ActionMasker actionMasker) { // Prevents the agent from picking an action that would make it collide with a wall var positionX = (int)transform.position.x; @@ -55,22 +55,22 @@ void SetMask() if (positionX == 0) { - SetActionMask(k_Left); + actionMasker.SetActionMask(k_Left); } if (positionX == maxPosition) { - SetActionMask(k_Right); + actionMasker.SetActionMask(k_Right); } if (positionZ == 0) { - SetActionMask(k_Down); + actionMasker.SetActionMask(k_Down); } if (positionZ == maxPosition) { - SetActionMask(k_Up); + actionMasker.SetActionMask(k_Up); } } diff --git a/com.unity.ml-agents/Runtime/ActionMasker.cs b/com.unity.ml-agents/Runtime/ActionMasker.cs index 25f36bac97..ddbd7d87f8 100644 --- a/com.unity.ml-agents/Runtime/ActionMasker.cs +++ b/com.unity.ml-agents/Runtime/ActionMasker.cs @@ -4,7 +4,7 @@ namespace MLAgents { - internal class ActionMasker + public class ActionMasker { /// When using discrete control, is the starting indices of the actions /// when all the branches are concatenated with each other. @@ -19,6 +19,43 @@ internal ActionMasker(BrainParameters brainParameters) m_BrainParameters = brainParameters; } + /// + /// Sets an action mask for discrete control agents. When used, the agent will not be + /// able to perform the action passed as argument at the next decision. If no branch is + /// specified, the default branch will be 0. The actionIndex or actionIndices correspond + /// to the action the agent will be unable to perform. + /// + /// The indices of the masked actions on branch 0 + protected void SetActionMask(IEnumerable actionIndices) + { + SetActionMask(0, actionIndices); + } + + /// + /// Sets an action mask for discrete control agents. When used, the agent will not be + /// able to perform the action passed as argument at the next decision. If no branch is + /// specified, the default branch will be 0. 
The actionIndex or actionIndices correspond + /// to the action the agent will be unable to perform. + /// + /// The branch for which the actions will be masked + /// The index of the masked action + protected void SetActionMask(int branch, int actionIndex) + { + SetActionMask(branch, new[] { actionIndex }); + } + + /// + /// Sets an action mask for discrete control agents. When used, the agent will not be + /// able to perform the action passed as argument at the next decision. If no branch is + /// specified, the default branch will be 0. The actionIndex or actionIndices correspond + /// to the action the agent will be unable to perform. + /// + /// The index of the masked action on branch 0 + public void SetActionMask(int actionIndex) + { + SetActionMask(0, new[] { actionIndex }); + } + /// /// Modifies an action mask for discrete control agents. When used, the agent will not be /// able to perform the action passed as argument at the next decision. If no branch is diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 65e477b999..3a485e6d1d 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -481,7 +481,7 @@ void SendInfoToBrain() UpdateSensors(); using (TimerStack.Instance.Scoped("CollectObservations")) { - CollectObservations(collectObservationsSensor); + CollectObservations(collectObservationsSensor, m_ActionMasker); } m_Info.actionMasks = m_ActionMasker.GetMask(); @@ -544,53 +544,47 @@ public virtual void CollectObservations(VectorSensor sensor) } /// - /// Sets an action mask for discrete control agents. When used, the agent will not be - /// able to perform the action passed as argument at the next decision. If no branch is - /// specified, the default branch will be 0. The actionIndex or actionIndices correspond - /// to the action the agent will be unable to perform. 
- /// - /// The indices of the masked actions on branch 0 - protected void SetActionMask(IEnumerable actionIndices) - { - m_ActionMasker.SetActionMask(0, actionIndices); - } - - /// - /// Sets an action mask for discrete control agents. When used, the agent will not be - /// able to perform the action passed as argument at the next decision. If no branch is - /// specified, the default branch will be 0. The actionIndex or actionIndices correspond - /// to the action the agent will be unable to perform. - /// - /// The index of the masked action on branch 0 - protected void SetActionMask(int actionIndex) - { - m_ActionMasker.SetActionMask(0, new[] { actionIndex }); - } - - /// - /// Sets an action mask for discrete control agents. When used, the agent will not be - /// able to perform the action passed as argument at the next decision. If no branch is - /// specified, the default branch will be 0. The actionIndex or actionIndices correspond - /// to the action the agent will be unable to perform. - /// - /// The branch for which the actions will be masked - /// The index of the masked action - protected void SetActionMask(int branch, int actionIndex) - { - m_ActionMasker.SetActionMask(branch, new[] { actionIndex }); - } - - /// - /// Modifies an action mask for discrete control agents. When used, the agent will not be - /// able to perform the action passed as argument at the next decision. If no branch is - /// specified, the default branch will be 0. The actionIndex or actionIndices correspond - /// to the action the agent will be unable to perform. + /// Collects the vector observations of the agent. + /// The agent observation describes the current environment from the + /// perspective of the agent. 
/// - /// The branch for which the actions will be masked - /// The indices of the masked actions - protected void SetActionMask(int branch, IEnumerable actionIndices) + /// + /// An agents observation is any environment information that helps + /// the Agent achieve its goal. For example, for a fighting Agent, its + /// observation could include distances to friends or enemies, or the + /// current level of ammunition at its disposal. + /// Recall that an Agent may attach vector or visual observations. + /// Vector observations are added by calling the provided helper methods + /// on the VectorSensor input: + /// - + /// - + /// - + /// - + /// - + /// AddVectorObs(float[]) + /// + /// - + /// AddVectorObs(List{float}) + /// + /// - + /// - + /// - + /// Depending on your environment, any combination of these helpers can + /// be used. They just need to be used in the exact same order each time + /// this method is called and the resulting size of the vector observation + /// needs to match the vectorObservationSize attribute of the linked Brain. + /// Visual observations are implicitly added from the cameras attached to + /// the Agent. + /// When using Discrete Control, you can prevent the Agent from using a certain + /// action by masking it. You can call the following method on the ActionMasker + /// input : + /// - + /// The branch input is the index of the action, actionIndices are the indices of the + /// invalid options for that action. 
+ /// + public virtual void CollectObservations(VectorSensor sensor, ActionMasker actionMasker) { - m_ActionMasker.SetActionMask(branch, actionIndices); + CollectObservations(sensor); } /// From 4fd28531df39fe6829ecb54a4350f5169f188917 Mon Sep 17 00:00:00 2001 From: vincentpierre Date: Fri, 7 Feb 2020 14:24:40 -0800 Subject: [PATCH 2/4] Addressing the comments --- com.unity.ml-agents/Runtime/ActionMasker.cs | 8 ++++---- com.unity.ml-agents/Runtime/Agent.cs | 15 +++------------ 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/com.unity.ml-agents/Runtime/ActionMasker.cs b/com.unity.ml-agents/Runtime/ActionMasker.cs index ddbd7d87f8..bc47685295 100644 --- a/com.unity.ml-agents/Runtime/ActionMasker.cs +++ b/com.unity.ml-agents/Runtime/ActionMasker.cs @@ -26,7 +26,7 @@ internal ActionMasker(BrainParameters brainParameters) /// to the action the agent will be unable to perform. /// /// The indices of the masked actions on branch 0 - protected void SetActionMask(IEnumerable actionIndices) + public void SetActionMask(IEnumerable actionIndices) { SetActionMask(0, actionIndices); } @@ -39,7 +39,7 @@ protected void SetActionMask(IEnumerable actionIndices) /// /// The branch for which the actions will be masked /// The index of the masked action - protected void SetActionMask(int branch, int actionIndex) + public void SetActionMask(int branch, int actionIndex) { SetActionMask(branch, new[] { actionIndex }); } @@ -104,7 +104,7 @@ public void SetActionMask(int branch, IEnumerable actionIndices) /// /// A mask for the agent. A boolean array of length equal to the total number of /// actions. 
- public bool[] GetMask() + internal bool[] GetMask() { if (m_CurrentMask != null) { @@ -140,7 +140,7 @@ void AssertMask() /// /// Resets the current mask for an agent /// - public void ResetMask() + internal void ResetMask() { if (m_CurrentMask != null) { diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 3a485e6d1d..08dc8c447a 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -523,12 +523,6 @@ void UpdateSensors() /// - /// - /// - - /// - - /// AddVectorObs(float[]) - /// - /// - - /// AddVectorObs(List{float}) - /// /// - /// - /// - @@ -560,12 +554,6 @@ public virtual void CollectObservations(VectorSensor sensor) /// - /// - /// - - /// - - /// AddVectorObs(float[]) - /// - /// - - /// AddVectorObs(List{float}) - /// /// - /// - /// - @@ -579,6 +567,9 @@ public virtual void CollectObservations(VectorSensor sensor) /// action by masking it. You can call the following method on the ActionMasker /// input : /// - + /// - + /// - + /// - /// The branch input is the index of the action, actionIndices are the indices of the /// invalid options for that action. /// From 3316c7f8f2fb129e34dc719963235ced56fff639 Mon Sep 17 00:00:00 2001 From: vincentpierre Date: Fri, 7 Feb 2020 15:29:59 -0800 Subject: [PATCH 3/4] Improvements to the documentation --- com.unity.ml-agents/Runtime/ActionMasker.cs | 23 ++++++++++----------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/com.unity.ml-agents/Runtime/ActionMasker.cs b/com.unity.ml-agents/Runtime/ActionMasker.cs index bc47685295..741a385270 100644 --- a/com.unity.ml-agents/Runtime/ActionMasker.cs +++ b/com.unity.ml-agents/Runtime/ActionMasker.cs @@ -21,9 +21,9 @@ internal ActionMasker(BrainParameters brainParameters) /// /// Sets an action mask for discrete control agents. When used, the agent will not be - /// able to perform the action passed as argument at the next decision. 
If no branch is - /// specified, the default branch will be 0. The actionIndex or actionIndices correspond - /// to the action the agent will be unable to perform. + /// able to perform the actions passed as argument at the next decision. + /// The actionIndices correspond to the actions the agent will be unable to perform + /// on branch 0. /// /// The indices of the masked actions on branch 0 public void SetActionMask(IEnumerable actionIndices) @@ -33,9 +33,9 @@ public void SetActionMask(IEnumerable actionIndices) /// /// Sets an action mask for discrete control agents. When used, the agent will not be - /// able to perform the action passed as argument at the next decision. If no branch is - /// specified, the default branch will be 0. The actionIndex or actionIndices correspond - /// to the action the agent will be unable to perform. + /// able to perform the action passed as argument at the next decision for the specified + /// action branch. The actionIndex corresponds to the action the agent will be unable + /// to perform. /// /// The branch for which the actions will be masked /// The index of the masked action @@ -46,9 +46,8 @@ public void SetActionMask(int branch, int actionIndex) /// /// Sets an action mask for discrete control agents. When used, the agent will not be - /// able to perform the action passed as argument at the next decision. If no branch is - /// specified, the default branch will be 0. The actionIndex or actionIndices correspond - /// to the action the agent will be unable to perform. + /// able to perform the action passed as argument at the next decision. The actionIndex + /// corresponds to the action the agent will be unable to perform on branch 0. /// /// The index of the masked action on branch 0 public void SetActionMask(int actionIndex) @@ -58,9 +57,9 @@ public void SetActionMask(int actionIndex) /// /// Modifies an action mask for discrete control agents.
When used, the agent will not be - /// able to perform the action passed as argument at the next decision. If no branch is - /// specified, the default branch will be 0. The actionIndex or actionIndices correspond - /// to the action the agent will be unable to perform. + /// able to perform the actions passed as argument at the next decision for the specified + /// action branch. The actionIndices correspond to the action options the agent will + /// be unable to perform. /// /// The branch for which the actions will be masked /// The indices of the masked actions From 0f4879ec15e2af61897dbe4203b7b43f7a2fb36e Mon Sep 17 00:00:00 2001 From: vincentpierre Date: Mon, 10 Feb 2020 10:32:29 -0800 Subject: [PATCH 4/4] Editing the documentation --- docs/Learning-Environment-Design-Agents.md | 6 ++++-- docs/Migrating.md | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/Learning-Environment-Design-Agents.md b/docs/Learning-Environment-Design-Agents.md index dcc92e44bb..4d221b5fef 100644 --- a/docs/Learning-Environment-Design-Agents.md +++ b/docs/Learning-Environment-Design-Agents.md @@ -391,10 +391,12 @@ impossible for the next decision. When the Agent is controlled by a neural network, the Agent will be unable to perform the specified action. Note that when the Agent is controlled by its Heuristic, the Agent will still be able to decide to perform the masked action. 
In order to mask an -action, call the method `SetActionMask` within the `CollectObservation` method : +action, call the method `SetActionMask` on the optional `ActionMasker` argument of the `CollectObservations` method: ```csharp -SetActionMask(branch, actionIndices) +public override void CollectObservations(VectorSensor sensor, ActionMasker actionMasker){ + actionMasker.SetActionMask(branch, actionIndices); +} ``` Where: diff --git a/docs/Migrating.md b/docs/Migrating.md index 765ae068a8..705c792a78 100644 --- a/docs/Migrating.md +++ b/docs/Migrating.md @@ -13,11 +13,12 @@ The versions can be found in * The `Agent.CollectObservations()` virtual method now takes as input a `VectorSensor` sensor as argument. The `Agent.AddVectorObs()` methods were removed. * The `Monitor` class has been moved to the Examples Project. (It was prone to errors during testing) * The `MLAgents.Sensor` namespace has been removed. All sensors now belong to the `MLAgents` namespace. +* The `SetActionMask` method must now be called on the optional `ActionMasker` argument of the `CollectObservations` method. (We now consider an action mask as a type of observation) ### Steps to Migrate * Replace your Agent's implementation of `CollectObservations()` with `CollectObservations(VectorSensor sensor)`. In addition, replace all calls to `AddVectorObs()` with `sensor.AddObservation()` or `sensor.AddOneHotObservation()` on the `VectorSensor` passed as argument. - +* Replace your calls to `SetActionMask` on your Agent with calls to `ActionMasker.SetActionMask` in `CollectObservations`. ## Migrating from 0.13 to 0.14