Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public void TestStoreInitalize()
reward = 1f,
actionMasks = new[] { false, true },
done = true,
id = 5,
episodeId = 5,
maxStepReached = true,
storedVectorActions = new[] { 0f, 1f },
};
Expand Down
47 changes: 16 additions & 31 deletions UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@ public AgentInfo _Info
}
}

public bool IsDone()
{
return (bool)typeof(Agent).GetField("m_Done", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this);
}
public int initializeAgentCalls;
public int collectObservationsCalls;
public int agentActionCalls;
Expand Down Expand Up @@ -191,8 +187,6 @@ public void TestAgent()
agentGo2.AddComponent<TestAgent>();
var agent2 = agentGo2.GetComponent<TestAgent>();

Assert.AreEqual(false, agent1.IsDone());
Assert.AreEqual(false, agent2.IsDone());
Assert.AreEqual(0, agent1.agentResetCalls);
Assert.AreEqual(0, agent2.agentResetCalls);
Assert.AreEqual(0, agent1.initializeAgentCalls);
Expand All @@ -206,8 +200,6 @@ public void TestAgent()
agentEnableMethod?.Invoke(agent2, new object[] { });
agentEnableMethod?.Invoke(agent1, new object[] { });

Assert.AreEqual(false, agent1.IsDone());
Assert.AreEqual(false, agent2.IsDone());
// agent1 was not enabled when the academy started
// The agents have been initialized
Assert.AreEqual(0, agent1.agentResetCalls);
Expand Down Expand Up @@ -422,18 +414,14 @@ public void TestAgent()
if (i % 11 == 5)
{
agent1.Done();
numberAgent1Reset += 1;
}
// Resetting agent 2 regularly
if (i % 13 == 3)
{
if (!(agent2.IsDone()))
{
// If the agent was already reset before the request decision
// We should not reset again
agent2.Done();
numberAgent2Reset += 1;
agent2StepSinceReset = 0;
}
agent2.Done();
numberAgent2Reset += 1;
agent2StepSinceReset = 0;
}
// Request a decision for agent 2 regularly
if (i % 3 == 2)
Expand All @@ -445,16 +433,9 @@ public void TestAgent()
// Request an action without decision regularly
agent2.RequestAction();
}
if (agent1.IsDone())
{
numberAgent1Reset += 1;
}

acaStepsSinceReset += 1;
agent2StepSinceReset += 1;
//Agent 1 is only initialized at step 2
if (i < 2)
{ }
aca.EnvironmentStep();
}
}
Expand Down Expand Up @@ -500,19 +481,23 @@ public void TestCumulativeReward()
var j = 0;
for (var i = 0; i < 500; i++)
{
if (i % 20 == 0)
{
j = 0;
}
else
{
j++;
}
agent2.RequestAction();
Assert.LessOrEqual(Mathf.Abs(j * 0.1f + j * 10f - agent1.GetCumulativeReward()), 0.05f);
Assert.LessOrEqual(Mathf.Abs(j * 10.1f - agent1.GetCumulativeReward()), 0.05f);
Assert.LessOrEqual(Mathf.Abs(i * 0.1f - agent2.GetCumulativeReward()), 0.05f);


aca.EnvironmentStep();
agent1.AddReward(10f);
aca.EnvironmentStep();



if ((i % 21 == 0) && (i > 0))
{
j = 0;
}
j++;
}
}
}
Expand Down
10 changes: 0 additions & 10 deletions UnitySDK/Assets/ML-Agents/Scripts/Academy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,6 @@ public bool IsCommunicatorOn
// in addition to aligning on the step count of the global episode.
public event System.Action<int> AgentSetStatus;

// Signals to all the agents at each environment step so they can reset
// if their flag has been set to done (assuming the agent has requested a
// decision).
public event System.Action AgentResetIfDone;

// Signals to all the agents at each environment step so they can send
// their state to their Policy if they have requested a decision.
public event System.Action AgentSendState;
Expand Down Expand Up @@ -314,7 +309,6 @@ void ResetActions()
DecideAction = () => { };
DestroyAction = () => { };
AgentSetStatus = i => { };
AgentResetIfDone = () => { };
AgentSendState = () => { };
AgentAct = () => { };
AgentForceReset = () => { };
Expand Down Expand Up @@ -392,10 +386,6 @@ public void EnvironmentStep()

AgentSetStatus?.Invoke(m_StepCount);

using (TimerStack.Instance.Scoped("AgentResetIfDone"))
{
AgentResetIfDone?.Invoke();
}

using (TimerStack.Instance.Scoped("AgentSendState"))
{
Expand Down
99 changes: 33 additions & 66 deletions UnitySDK/Assets/ML-Agents/Scripts/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ public struct AgentInfo
public bool maxStepReached;

/// <summary>
/// Unique identifier each agent receives at initialization. It is used
/// Episode identifier each agent receives at every reset. It is used
/// to distinguish between different agents in the environment.
/// </summary>
public int id;
public int episodeId;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update this comment too.

}

/// <summary>
Expand Down Expand Up @@ -148,23 +148,17 @@ public abstract class Agent : MonoBehaviour
/// Whether or not the agent requests a decision.
bool m_RequestDecision;

/// Whether or not the agent has completed the episode. This may be due
/// to either reaching a success or fail state, or reaching the maximum
/// number of steps (i.e. timing out).
bool m_Done;

/// Whether or not the agent reached the maximum number of steps.
bool m_MaxStepReached;

/// Keeps track of the number of steps taken by the agent in this episode.
/// Note that this value is different for each agent, and may not overlap
/// with the step counter in the Academy, since agents reset based on
/// their own experience.
int m_StepCount;

/// Unique identifier each agent receives at initialization. It is used
/// Episode identifier each agent receives. It is used
/// to distinguish between different agents in the environment.
int m_Id;
/// This Id will be changed every time the Agent resets.
int m_EpisodeId;

/// Keeps track of the actions that are masked at each step.
ActionMasker m_ActionMasker;
Expand All @@ -190,7 +184,7 @@ public abstract class Agent : MonoBehaviour
/// becomes enabled or active.
void OnEnable()
{
m_Id = gameObject.GetInstanceID();
m_EpisodeId = EpisodeIdCounter.GetEpisodeId();
OnEnableHelper();

m_Recorder = GetComponent<DemonstrationRecorder>();
Expand All @@ -204,7 +198,6 @@ void OnEnableHelper()
m_Action = new AgentAction();
sensors = new List<ISensor>();

Academy.Instance.AgentResetIfDone += ResetIfDone;
Academy.Instance.AgentSendState += SendInfo;
Academy.Instance.DecideAction += DecideAction;
Academy.Instance.AgentAct += AgentStep;
Expand All @@ -224,7 +217,6 @@ void OnDisable()
// We don't want to even try, because this will lazily create a new Academy!
if (Academy.IsInitialized)
{
Academy.Instance.AgentResetIfDone -= ResetIfDone;
Academy.Instance.AgentSendState -= SendInfo;
Academy.Instance.DecideAction -= DecideAction;
Academy.Instance.AgentAct -= AgentStep;
Expand All @@ -234,12 +226,20 @@ void OnDisable()
m_Brain?.Dispose();
}

void NotifyAgentDone()
void NotifyAgentDone(bool maxStepReached = false)
{
m_Info.reward = m_Reward;
m_Info.done = true;
m_Info.maxStepReached = maxStepReached;
// Request the last decision with no callbacks
// We request a decision so Python knows the Agent is disabled
// We request a decision so Python knows the Agent is done immediately
m_Brain?.RequestDecision(m_Info, sensors, (a) => { });
// The Agent is done, so we give it a new episode Id
m_EpisodeId = EpisodeIdCounter.GetEpisodeId();
m_Reward = 0f;
m_CumulativeReward = 0f;
m_RequestAction = false;
m_RequestDecision = false;
}

/// <summary>
Expand Down Expand Up @@ -322,7 +322,9 @@ public float GetCumulativeReward()
/// </summary>
public void Done()
{
m_Done = true;
NotifyAgentDone();
_AgentReset();

}

/// <summary>
Expand All @@ -342,28 +344,6 @@ public void RequestAction()
m_RequestAction = true;
}

/// <summary>
/// Indicates if the agent has reached his maximum number of steps.
/// </summary>
/// <returns>
/// <c>true</c>, if max step reached was reached, <c>false</c> otherwise.
/// </returns>
public bool IsMaxStepReached()
{
return m_MaxStepReached;
}

/// <summary>
/// Indicates if the agent is done
/// </summary>
/// <returns>
/// <c>true</c>, if the agent is done, <c>false</c> otherwise.
/// </returns>
public bool IsDone()
{
return m_Done;
}

/// Helper function that resets all the data structures associated with
/// the agent. Typically used when the agent is being initialized or reset
/// at the end of an episode.
Expand Down Expand Up @@ -489,9 +469,9 @@ void SendInfoToBrain()
m_Info.actionMasks = m_ActionMasker.GetMask();

m_Info.reward = m_Reward;
m_Info.done = m_Done;
m_Info.maxStepReached = m_MaxStepReached;
m_Info.id = m_Id;
m_Info.done = false;
m_Info.maxStepReached = false;
m_Info.episodeId = m_EpisodeId;

m_Brain.RequestDecision(m_Info, sensors, UpdateAgentAction);

Expand Down Expand Up @@ -742,51 +722,38 @@ protected float ScaleAction(float rawAction, float min, float max)
}


/// Signals the agent that it must reset if its done flag is set to true.
void ResetIfDone()
{
if (m_Done)
{
_AgentReset();
}
}

/// <summary>
/// Signals the agent that it must sent its decision to the brain.
/// </summary>
void SendInfo()
{
// If the Agent is done, it has just reset and thus requires a new decision
if (m_RequestDecision || m_Done)
if (m_RequestDecision)
{
SendInfoToBrain();
m_Reward = 0f;
if (m_Done)
{
m_CumulativeReward = 0f;
}
m_Done = false;
m_MaxStepReached = false;
m_RequestDecision = false;
}
}

/// Used by the brain to make the agent perform a step.
void AgentStep()
{
if ((m_StepCount >= maxStep - 1) && (maxStep > 0))
{
NotifyAgentDone(true);
_AgentReset();

}
else
{
m_StepCount += 1;
}
if ((m_RequestAction) && (m_Brain != null))
{
m_RequestAction = false;
AgentAction(m_Action.vectorActions);
}

if ((m_StepCount >= maxStep) && (maxStep > 0))
{
m_MaxStepReached = true;
Done();
}

m_StepCount += 1;
}

void DecideAction()
Expand Down
11 changes: 11 additions & 0 deletions UnitySDK/Assets/ML-Agents/Scripts/EpisodeIdCounter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
namespace MLAgents
{
    /// <summary>
    /// Hands out monotonically increasing episode identifiers. A single
    /// static counter is shared by every caller in the process, so each
    /// call yields a value one greater than the previous one.
    /// </summary>
    public static class EpisodeIdCounter
    {
        // Next id to hand out; advanced on every request.
        private static int s_NextId;

        /// <summary>
        /// Returns the next unused episode id and advances the counter.
        /// </summary>
        public static int GetEpisodeId() => s_NextId++;
    }
}
11 changes: 11 additions & 0 deletions UnitySDK/Assets/ML-Agents/Scripts/EpisodeIdCounter.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai)
Reward = ai.reward,
MaxStepReached = ai.maxStepReached,
Done = ai.done,
Id = ai.id,
Id = ai.episodeId,
};

if (ai.actionMasks != null)
Expand Down
2 changes: 1 addition & 1 deletion UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ public void PutObservations(string brainKey, AgentInfo info, List<ISensor> senso
{
m_ActionCallbacks[brainKey] = new List<IdCallbackPair>();
}
m_ActionCallbacks[brainKey].Add(new IdCallbackPair { AgentId = info.id, Callback = action });
m_ActionCallbacks[brainKey].Add(new IdCallbackPair { AgentId = info.episodeId, Callback = action });
}

/// <summary>
Expand Down
Loading