diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index 4f44533681..1014170f6f 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -6,6 +6,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). +## [Unreleased] +### Bug Fixes +#### com.unity.ml-agents (C#) +- Fixed an issue where RayPerceptionSensor would raise an exception when the +list of tags was empty, or a tag in the list was invalid (unknown, null, or +empty string). (#4155) + ## [1.0.2] - 2020-06-04 ### Minor Changes #### com.unity.ml-agents (C#) diff --git a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs index 036e090c29..523ad01f5f 100644 --- a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs @@ -462,9 +462,24 @@ out DebugDisplayInfo.RayInfo debugRayOut if (castHit) { // Find the index of the tag of the object that was hit. - for (var i = 0; i < input.DetectableTags.Count; i++) + var numTags = input.DetectableTags?.Count ?? 0; + for (var i = 0; i < numTags; i++) { - if (hitObject.CompareTag(input.DetectableTags[i])) + var tagsEqual = false; + try + { + var tag = input.DetectableTags[i]; + if (!string.IsNullOrEmpty(tag)) + { + tagsEqual = hitObject.CompareTag(tag); + } + } + catch (UnityException e) + { + // If the tag is null, empty, or not a valid tag, just ignore it. + } + + if (tagsEqual) { rayOutput.HitTaggedObject = true; rayOutput.HitTagIndex = i; diff --git a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs index 5909eb84d0..094a36dfdc 100644 --- a/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs +++ b/com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs @@ -281,6 +281,11 @@ void OnDrawGizmosSelected() else { var rayInput = GetRayPerceptionInput(); + // We don't actually need the tags here, since they don't affect the display of the rays. + // Additionally, the user might be in the middle of typing the tag name when this is called, + // and there's no way to turn off the "Tag ... is not defined" error logs. + // So just don't use any tags here. + rayInput.DetectableTags = null; for (var rayIndex = 0; rayIndex < rayInput.Angles.Count; rayIndex++) { DebugDisplayInfo.RayInfo debugRay; diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs index 51e86d09ed..2858c42fd1 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using NUnit.Framework; using UnityEngine; +using UnityEngine.TestTools; using Unity.MLAgents.Sensors; namespace Unity.MLAgents.Tests @@ -296,5 +297,93 @@ public void TestRayZeroLength() Assert.LessOrEqual(outputBuffer[2], 1.0f); } } + + [Test] + public void TestStaticPerceive() + { + SetupScene(); + var obj = new GameObject("agent"); + var perception = obj.AddComponent(); + + perception.RaysPerDirection = 0; // single ray + perception.MaxRayDegrees = 45; + perception.RayLength = 20; + perception.DetectableTags = new List(); + perception.DetectableTags.Add(k_CubeTag); + perception.DetectableTags.Add(k_SphereTag); + + var radii = new[] { 0f, .5f }; + foreach (var castRadius in radii) + { + perception.SphereCastRadius = castRadius; + var castInput = perception.GetRayPerceptionInput(); + var castOutput = RayPerceptionSensor.Perceive(castInput); + + Assert.AreEqual(1, castOutput.RayOutputs.Length); + + // Expected to hit the cube + Assert.AreEqual(0, castOutput.RayOutputs[0].HitTagIndex); + } + } + + [Test] + public void TestStaticPerceiveInvalidTags() + { + SetupScene(); + var obj = new GameObject("agent"); + var perception = obj.AddComponent(); + + perception.RaysPerDirection = 0; // single ray + perception.MaxRayDegrees = 45; + perception.RayLength = 20; + perception.DetectableTags = new List(); + perception.DetectableTags.Add("Bad tag"); + perception.DetectableTags.Add(null); + perception.DetectableTags.Add(""); + perception.DetectableTags.Add(k_CubeTag); + + var radii = new[] { 0f, .5f }; + foreach (var castRadius in radii) + { + perception.SphereCastRadius = castRadius; + var castInput = perception.GetRayPerceptionInput(); + + // There's no clean way that I can find to check for a defined tag without + // logging an error. + LogAssert.Expect(LogType.Error, "Tag: Bad tag is not defined."); + var castOutput = RayPerceptionSensor.Perceive(castInput); + + Assert.AreEqual(1, castOutput.RayOutputs.Length); + + // Expected to hit the cube + Assert.AreEqual(3, castOutput.RayOutputs[0].HitTagIndex); + } + } + + [Test] + public void TestStaticPerceiveNoTags() + { + SetupScene(); + var obj = new GameObject("agent"); + var perception = obj.AddComponent(); + + perception.RaysPerDirection = 0; // single ray + perception.MaxRayDegrees = 45; + perception.RayLength = 20; + perception.DetectableTags = null; + + var radii = new[] { 0f, .5f }; + foreach (var castRadius in radii) + { + perception.SphereCastRadius = castRadius; + var castInput = perception.GetRayPerceptionInput(); + var castOutput = RayPerceptionSensor.Perceive(castInput); + + Assert.AreEqual(1, castOutput.RayOutputs.Length); + + // Expected to hit the cube + Assert.AreEqual(-1, castOutput.RayOutputs[0].HitTagIndex); + } + } } }