Skip to content

Throw if Academy.EnvironmentStep() is called recursively #4227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@ and this project adheres to
### Minor Changes

### Bug Fixes
`mlagents-learn` will now raise an error immediately if `--num-envs` is greater than 1 without setting the `--env`
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This got cherry-picked to the release branch, so it should get merged back into the section for 1.2.0

argument. (#4203)
#### com.unity.ml-agents (C#)
- Academy.EnvironmentStep() will now throw an exception if it is called
recursively (for example, by an Agent's CollectObservations method).
Previously, this would result in an infinite loop and cause the editor to hang.
(#4226)

## [1.2.0-preview] - 2020-07-15

Expand Down
69 changes: 48 additions & 21 deletions com.unity.ml-agents/Runtime/Academy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ public bool IsCommunicatorOn
// Flag used to keep track of the first time the Academy is reset.
bool m_HadFirstReset;

// Whether the Academy is in the middle of a step. This is used to detect an Academy
// step that is triggered by user code which is itself being called by the Academy.
bool m_IsStepping;

// Random seed used for inference.
int m_InferenceSeed;

Expand Down Expand Up @@ -486,36 +490,59 @@ void ForcedFullReset()
/// </summary>
public void EnvironmentStep()
{
// NOTE(review): this listing is an unmarked diff rendering, so lines from the previous
// body appear interleaved with the new try/finally version. Read it against the real
// file before relying on statement order.
if (!m_HadFirstReset)
// Check whether we're already in the middle of a step.
// This shouldn't happen generally, but could happen if user code (e.g. CollectObservations)
// that is called by EnvironmentStep() also calls EnvironmentStep(). This would result
// in an infinite loop and/or stack overflow, so stop it before it happens.
if (m_IsStepping)
{
ForcedFullReset();
throw new UnityAgentsException(
"Academy.EnvironmentStep() called recursively. " +
"This might happen if you call EnvironmentStep() from custom code such as " +
"CollectObservations() or OnActionReceived()."
);
}

AgentPreStep?.Invoke(m_StepCount);

m_StepCount += 1;
m_TotalStepCount += 1;
AgentIncrementStep?.Invoke();
// Mark the step as in progress so that a nested call trips the m_IsStepping check above.
m_IsStepping = true;

using (TimerStack.Instance.Scoped("AgentSendState"))
try
{
AgentSendState?.Invoke();
}
// Run a forced full reset if the Academy has not had its first reset yet.
if (!m_HadFirstReset)
{
ForcedFullReset();
}

using (TimerStack.Instance.Scoped("DecideAction"))
{
DecideAction?.Invoke();
}
// AgentPreStep receives the step count as it was before this step's increment.
AgentPreStep?.Invoke(m_StepCount);

// If the communicator is not on, we need to clear the SideChannel sending queue
if (!IsCommunicatorOn)
{
SideChannelManager.GetSideChannelMessage();
}
m_StepCount += 1;
m_TotalStepCount += 1;
AgentIncrementStep?.Invoke();

using (TimerStack.Instance.Scoped("AgentAct"))
using (TimerStack.Instance.Scoped("AgentSendState"))
{
AgentSendState?.Invoke();
}

using (TimerStack.Instance.Scoped("DecideAction"))
{
DecideAction?.Invoke();
}

// If the communicator is not on, we need to clear the SideChannel sending queue
if (!IsCommunicatorOn)
{
SideChannelManager.GetSideChannelMessage();
}

using (TimerStack.Instance.Scoped("AgentAct"))
{
AgentAct?.Invoke();
}
}
finally
{
AgentAct?.Invoke();
// Always clear m_IsStepping, whether the step completed or an exception occurred,
// so that a later call to EnvironmentStep() is not rejected as recursive.
m_IsStepping = false;
}
}

Expand Down
32 changes: 32 additions & 0 deletions com.unity.ml-agents/Tests/Editor/AcademyTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using NUnit.Framework;
using Unity.MLAgents.Sensors;
using UnityEngine;
#if UNITY_2019_3_OR_NEWER
using System.Reflection;
Expand All @@ -22,6 +23,37 @@ public void TestPackageVersion()
#endif
}

/// <summary>
/// Test-only agent that deliberately calls Academy.EnvironmentStep() from inside
/// CollectObservations() the first time observations are collected.
/// </summary>
class RecursiveAgent : Agent
{
    // Number of times CollectObservations() has been invoked on this agent.
    int m_CollectObsCount;

    public override void CollectObservations(VectorSensor sensor)
    {
        m_CollectObsCount += 1;

        // Only the very first collection triggers the nested step.
        if (m_CollectObsCount != 1)
        {
            return;
        }

        // NEVER DO THIS IN REAL CODE!
        Academy.Instance.EnvironmentStep();
    }
}

/// <summary>
/// Checks that stepping the Academy from within an agent callback throws a
/// UnityAgentsException, and that the Academy can still be stepped afterwards.
/// </summary>
[Test]
public void TestRecursiveStepThrows()
{
    // Create an agent whose CollectObservations() calls EnvironmentStep() again.
    var recursiveAgent = new GameObject().AddComponent<RecursiveAgent>();
    recursiveAgent.LazyInitialize();
    recursiveAgent.RequestDecision();

    // The nested step is expected to be rejected with an exception.
    Assert.Throws<UnityAgentsException>(() => Academy.Instance.EnvironmentStep());

    // Make sure the Academy reset to a good state and is still steppable.
    Academy.Instance.EnvironmentStep();
}


}
}