diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index f907e98d7f..b231999948 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -25,6 +25,10 @@ and this project adheres to ### Bug Fixes #### com.unity.ml-agents (C#) +- Fixed an issue where RayPerceptionSensor would raise an exception when the +list of tags was empty, or a tag in the list was invalid (unknown, null, or +empty string). (#4155) + #### ml-agents / ml-agents-envs / gym-unity (Python) ## [1.1.0-preview] - 2020-06-10 diff --git a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs index 418cb5825f..8996eb38d3 100644 --- a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs @@ -469,9 +469,24 @@ out DebugDisplayInfo.RayInfo debugRayOut if (castHit) { // Find the index of the tag of the object that was hit. - for (var i = 0; i < input.DetectableTags.Count; i++) + var numTags = input.DetectableTags?.Count ?? 0; + for (var i = 0; i < numTags; i++) { - if (hitObject.CompareTag(input.DetectableTags[i])) + var tagsEqual = false; + try + { + var tag = input.DetectableTags[i]; + if (!string.IsNullOrEmpty(tag)) + { + tagsEqual = hitObject.CompareTag(tag); + } + } + catch (UnityException e) + { + // If the tag is null, empty, or not a valid tag, just ignore it. + } + + if (tagsEqual) { rayOutput.HitTaggedObject = true; rayOutput.HitTagIndex = i; diff --git a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs index 5909eb84d0..094a36dfdc 100644 --- a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs +++ b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs @@ -281,6 +281,11 @@ void OnDrawGizmosSelected() else { var rayInput = GetRayPerceptionInput(); + // We don't actually need the tags here, since they don't affect the display of the rays. + // Additionally, the user might be in the middle of typing the tag name when this is called, + // and there's no way to turn off the "Tag ... is not defined" error logs. + // So just don't use any tags here. + rayInput.DetectableTags = null; for (var rayIndex = 0; rayIndex < rayInput.Angles.Count; rayIndex++) { DebugDisplayInfo.RayInfo debugRay; diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs index 5f56fe321e..a15b25e790 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using NUnit.Framework; using UnityEngine; +using UnityEngine.TestTools; using Unity.MLAgents.Sensors; namespace Unity.MLAgents.Tests @@ -336,6 +337,69 @@ public void TestStaticPerceive() // Expected to hit the cube Assert.AreEqual("cube", castOutput.RayOutputs[0].HitGameObject.name); + Assert.AreEqual(0, castOutput.RayOutputs[0].HitTagIndex); + } + } + + [Test] + public void TestStaticPerceiveInvalidTags() + { + SetupScene(); + var obj = new GameObject("agent"); + var perception = obj.AddComponent(); + + perception.RaysPerDirection = 0; // single ray + perception.MaxRayDegrees = 45; + perception.RayLength = 20; + perception.DetectableTags = new List(); + perception.DetectableTags.Add("Bad tag"); + perception.DetectableTags.Add(null); + perception.DetectableTags.Add(""); + perception.DetectableTags.Add(k_CubeTag); + + var radii = new[] { 0f, .5f }; + foreach (var castRadius in radii) + { + perception.SphereCastRadius = castRadius; + var castInput = perception.GetRayPerceptionInput(); + + // There's no clean way that I can find to check for a defined tag without + // logging an error. + LogAssert.Expect(LogType.Error, "Tag: Bad tag is not defined."); + var castOutput = RayPerceptionSensor.Perceive(castInput); + + Assert.AreEqual(1, castOutput.RayOutputs.Length); + + // Expected to hit the cube + Assert.AreEqual("cube", castOutput.RayOutputs[0].HitGameObject.name); + Assert.AreEqual(3, castOutput.RayOutputs[0].HitTagIndex); + } + } + + [Test] + public void TestStaticPerceiveNoTags() + { + SetupScene(); + var obj = new GameObject("agent"); + var perception = obj.AddComponent(); + + perception.RaysPerDirection = 0; // single ray + perception.MaxRayDegrees = 45; + perception.RayLength = 20; + perception.DetectableTags = null; + + var radii = new[] { 0f, .5f }; + foreach (var castRadius in radii) + { + perception.SphereCastRadius = castRadius; + var castInput = perception.GetRayPerceptionInput(); + var castOutput = RayPerceptionSensor.Perceive(castInput); + + Assert.AreEqual(1, castOutput.RayOutputs.Length); + + // Expected to hit the cube + Assert.AreEqual("cube", castOutput.RayOutputs[0].HitGameObject.name); + Assert.AreEqual(-1, castOutput.RayOutputs[0].HitTagIndex); } } }