Skip to content

Commit 2dcbcb4

Browse files
author
Chris Elion
authored
store hit GameObject in raycast output (#4111)
* store hit GameObject in raycast output * changelog * fix unit tests * formatting
1 parent 05e24b1 commit 2dcbcb4

File tree

3 files changed

+52
-2
lines changed

3 files changed

+52
-2
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ and this project adheres to
1414

1515
### Minor Changes
1616
#### com.unity.ml-agents (C#)
17+
- `RayPerceptionSensor.Perceive()` now additionally store the GameObject that was hit by the ray. (#4111)
1718
#### ml-agents / ml-agents-envs / gym-unity (Python)
1819

1920
### Bug Fixes

com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,12 @@ public struct RayOutput
161161
/// </summary>
162162
public float HitFraction;
163163

164+
/// <summary>
165+
/// The hit GameObject (or null if there was no hit).
166+
/// </summary>
167+
public GameObject HitGameObject;
168+
169+
164170
/// <summary>
165171
/// Writes the ray output information to a subset of the float array. Each element in the rayAngles array
166172
/// determines a sublist of data to the observation. The sublist contains the observation data for a single cast.
@@ -334,7 +340,7 @@ public void Update()
334340
}
335341

336342
/// <inheritdoc/>
337-
public void Reset() { }
343+
public void Reset() {}
338344

339345
/// <inheritdoc/>
340346
public int[] GetObservationShape()
@@ -456,7 +462,8 @@ out DebugDisplayInfo.RayInfo debugRayOut
456462
HasHit = castHit,
457463
HitFraction = hitFraction,
458464
HitTaggedObject = false,
459-
HitTagIndex = -1
465+
HitTagIndex = -1,
466+
HitGameObject = hitObject
460467
};
461468

462469
if (castHit)

com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,16 @@ public class RayPerception3DTests
2727
const string k_CubeTag = "Player";
2828
const string k_SphereTag = "Respawn";
2929

30+
[TearDown]
31+
public void RemoveGameObjects()
32+
{
33+
var objects = GameObject.FindObjectsOfType<GameObject>();
34+
foreach (var o in objects)
35+
{
36+
UnityEngine.Object.DestroyImmediate(o);
37+
}
38+
}
39+
3040
void SetupScene()
3141
{
3242
/* Creates game objects in the world for testing.
@@ -44,18 +54,22 @@ void SetupScene()
4454
var cube = GameObject.CreatePrimitive(PrimitiveType.Cube);
4555
cube.transform.position = new Vector3(0, 0, 10);
4656
cube.tag = k_CubeTag;
57+
cube.name = "cube";
4758

4859
var sphere1 = GameObject.CreatePrimitive(PrimitiveType.Sphere);
4960
sphere1.transform.position = new Vector3(-5, 0, 5);
5061
sphere1.tag = k_SphereTag;
62+
sphere1.name = "sphere1";
5163

5264
var sphere2 = GameObject.CreatePrimitive(PrimitiveType.Sphere);
5365
sphere2.transform.position = new Vector3(5, 0, 5);
5466
// No tag for sphere2
67+
sphere2.name = "sphere2";
5568

5669
var sphere3 = GameObject.CreatePrimitive(PrimitiveType.Sphere);
5770
sphere3.transform.position = new Vector3(0, 0, -10);
5871
sphere3.tag = k_SphereTag;
72+
sphere3.name = "sphere3";
5973

6074
Physics.SyncTransforms();
6175
}
@@ -296,5 +310,33 @@ public void TestRayZeroLength()
296310
Assert.LessOrEqual(outputBuffer[2], 1.0f);
297311
}
298312
}
313+
314+
[Test]
315+
public void TestStaticPerceive()
316+
{
317+
SetupScene();
318+
var obj = new GameObject("agent");
319+
var perception = obj.AddComponent<RayPerceptionSensorComponent3D>();
320+
321+
perception.RaysPerDirection = 0; // single ray
322+
perception.MaxRayDegrees = 45;
323+
perception.RayLength = 20;
324+
perception.DetectableTags = new List<string>();
325+
perception.DetectableTags.Add(k_CubeTag);
326+
perception.DetectableTags.Add(k_SphereTag);
327+
328+
var radii = new[] { 0f, .5f };
329+
foreach (var castRadius in radii)
330+
{
331+
perception.SphereCastRadius = castRadius;
332+
var castInput = perception.GetRayPerceptionInput();
333+
var castOutput = RayPerceptionSensor.Perceive(castInput);
334+
335+
Assert.AreEqual(1, castOutput.RayOutputs.Length);
336+
337+
// Expected to hit the cube
338+
Assert.AreEqual("cube", castOutput.RayOutputs[0].HitGameObject.name);
339+
}
340+
}
299341
}
300342
}

0 commit comments

Comments
 (0)