From 3d9ea936cc8ef099d9adc9367765c91626612bc4 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Mon, 4 May 2020 16:32:49 -0700 Subject: [PATCH 01/23] ObservableAttribute proof-of-concept --- .../Scripts/FoodCollectorAgent.cs | 16 +++- com.unity.ml-agents/Runtime/Agent.cs | 23 ++++++ .../Runtime/Sensors/AttributeFieldSensor.cs | 72 +++++++++++++++++ .../Sensors/AttributePropertySensor.cs | 78 +++++++++++++++++++ .../Runtime/Sensors/ObservableAttribute.cs | 10 +++ 5 files changed, 198 insertions(+), 1 deletion(-) create mode 100644 com.unity.ml-agents/Runtime/Sensors/AttributeFieldSensor.cs create mode 100644 com.unity.ml-agents/Runtime/Sensors/AttributePropertySensor.cs create mode 100644 com.unity.ml-agents/Runtime/Sensors/ObservableAttribute.cs diff --git a/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs b/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs index 87c9d316ae..4d21ea4e1b 100644 --- a/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs +++ b/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs @@ -7,10 +7,16 @@ public class FoodCollectorAgent : Agent FoodCollectorSettings m_FoodCollecterSettings; public GameObject area; FoodCollectorArea m_MyArea; + + [Observable] bool m_Frozen; + bool m_Poisoned; bool m_Satiated; + + [Observable] bool m_Shoot; + float m_FrozenTime; float m_EffectTime; Rigidbody m_AgentRb; @@ -39,13 +45,21 @@ public override void Initialize() SetResetParameters(); } + [Observable] + Vector3 localRbVelocity + { + get { return transform.InverseTransformDirection(m_AgentRb.velocity); } + } + public override void CollectObservations(VectorSensor sensor) { if (useVectorObs) { - var localVelocity = transform.InverseTransformDirection(m_AgentRb.velocity); + // TODO use Observable with localRbVelocity instead + var localVelocity = localRbVelocity; sensor.AddObservation(localVelocity.x); sensor.AddObservation(localVelocity.z); + // TODO replace with Observables sensor.AddObservation(System.Convert.ToInt32(m_Frozen)); sensor.AddObservation(System.Convert.ToInt32(m_Shoot)); } diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index c142abbd52..889f9b99ef 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Collections.ObjectModel; +using System.Reflection; using UnityEngine; using Unity.Barracuda; using Unity.MLAgents.Sensors; @@ -816,6 +817,28 @@ public virtual void Heuristic(float[] actionsOut) /// internal void InitializeSensors() { + // Iterate over Observables + var fields = this.GetType().GetFields(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); + foreach (var field in fields) + { + var attr = (ObservableAttribute)Attribute.GetCustomAttribute(field, typeof(ObservableAttribute)); + if (attr != null) + { + sensors.Add(new AttributeFieldSensor(this, field, attr)); + } + } + + var properties = this.GetType().GetProperties(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); + foreach (var prop in properties) + { + var attr = (ObservableAttribute)Attribute.GetCustomAttribute(prop, typeof(ObservableAttribute)); + if (attr != null) + { + sensors.Add(new AttributePropertySensor(this, prop, attr)); + } + } + + // Get all attached sensor components SensorComponent[] attachedSensorComponents; if (m_PolicyFactory.UseChildSensors) diff --git a/com.unity.ml-agents/Runtime/Sensors/AttributeFieldSensor.cs b/com.unity.ml-agents/Runtime/Sensors/AttributeFieldSensor.cs new file mode 100644 index 0000000000..655c0226f9 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/AttributeFieldSensor.cs @@ -0,0 +1,72 @@ +using System.Reflection; + +namespace Unity.MLAgents.Sensors +{ + internal class AttributeFieldSensor : ISensor + { + object m_Object; + FieldInfo m_FieldInfo; + // Not currently used, but might want later. + ObservableAttribute m_ObservableAttribute; + + string m_SensorName; + int[] m_Shape; + + public AttributeFieldSensor(object o, FieldInfo fieldInfo, ObservableAttribute observableAttribute) + { + m_Object = o; + m_FieldInfo = fieldInfo; + m_ObservableAttribute = observableAttribute; + + m_SensorName = $"ObservableAttribute:{fieldInfo.DeclaringType.Name}.{fieldInfo.Name}"; + // TODO handle Vector3, quaternion, blittable(?) + m_Shape = new [] {1}; + } + + /// + public int[] GetObservationShape() + { + return m_Shape; + } + + /// + public int Write(ObservationWriter writer) + { + var val = m_FieldInfo.GetValue(m_Object); + if (m_FieldInfo.FieldType == typeof(System.Boolean)) + { + var boolVal = (System.Boolean)val; + writer[0] = boolVal ? 1.0f : 0.0f; + } + else + { + writer[0] = 0.0f; + } + return 1; + } + + /// + public byte[] GetCompressedObservation() + { + return null; + } + + /// + public void Update() { } + + /// + public void Reset() { } + + /// + public SensorCompressionType GetCompressionType() + { + return SensorCompressionType.None; + } + + /// + public string GetName() + { + return m_SensorName; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/AttributePropertySensor.cs b/com.unity.ml-agents/Runtime/Sensors/AttributePropertySensor.cs new file mode 100644 index 0000000000..346f4cd9ef --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/AttributePropertySensor.cs @@ -0,0 +1,78 @@ +using System.Reflection; + +namespace Unity.MLAgents.Sensors +{ + internal class AttributePropertySensor : ISensor + { + object m_Object; + PropertyInfo m_PropertyInfo; + // Not currently used, but might want later. + ObservableAttribute m_ObservableAttribute; + + string m_SensorName; + int[] m_Shape; + + public AttributePropertySensor(object o, PropertyInfo propertyInfo, ObservableAttribute observableAttribute) + { + m_Object = o; + m_PropertyInfo = propertyInfo; + m_ObservableAttribute = observableAttribute; + + m_SensorName = $"ObservableAttribute:{propertyInfo.DeclaringType.Name}.{propertyInfo.Name}"; + // TODO handle scalar, quaternion, blittable(?) + m_Shape = new [] {3}; + } + + /// + public int[] GetObservationShape() + { + return m_Shape; + } + + /// + public int Write(ObservationWriter writer) + { + // TODO make Delegate in ctor instead? + var val = m_PropertyInfo.GetMethod.Invoke(m_Object, null); + + if (m_PropertyInfo.PropertyType == typeof(UnityEngine.Vector3)) + { + var vec3Val = (UnityEngine.Vector3)val; + writer[0] = vec3Val.x; + writer[1] = vec3Val.y; + writer[2] = vec3Val.z; + } + else + { + writer[0] = 0.0f; + writer[1] = 0.0f; + writer[2] = 0.0f; + } + return 1; + } + + /// + public byte[] GetCompressedObservation() + { + return null; + } + + /// + public void Update() { } + + /// + public void Reset() { } + + /// + public SensorCompressionType GetCompressionType() + { + return SensorCompressionType.None; + } + + /// + public string GetName() + { + return m_SensorName; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservableAttribute.cs b/com.unity.ml-agents/Runtime/Sensors/ObservableAttribute.cs new file mode 100644 index 0000000000..5df69e0336 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/ObservableAttribute.cs @@ -0,0 +1,10 @@ +namespace Unity.MLAgents.Sensors +{ + [System.AttributeUsage(System.AttributeTargets.Field | System.AttributeTargets.Property)] + public class ObservableAttribute : System.Attribute + { + // Currently nothing here + // Could possible add "mask" flags for vector fields/properties + // E.g. MaskX | MaskZ to get the only the X and Z properties of a Vector3 field. + } +} From 245fc5abbed80fbd980e191710d3e3e16c4d1f71 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Fri, 8 May 2020 10:45:20 -0700 Subject: [PATCH 02/23] restructue sensors, add int impl, unit test --- .../Scripts/FoodCollectorAgent.cs | 1 + com.unity.ml-agents/Runtime/Agent.cs | 45 +++++----- .../Runtime/Sensors/ObservableAttribute.cs | 10 --- .../Runtime/Sensors/Reflection.meta | 8 ++ .../{ => Reflection}/AttributeFieldSensor.cs | 2 +- .../Reflection/AttributeFieldSensor.cs.meta | 11 +++ .../AttributePropertySensor.cs | 2 +- .../AttributePropertySensor.cs.meta | 11 +++ .../Sensors/Reflection/IntReflectionSensor.cs | 28 +++++++ .../Reflection/IntReflectionSensor.cs.meta | 11 +++ .../Sensors/Reflection/ObservableAttribute.cs | 84 +++++++++++++++++++ .../Reflection/ObservableAttribute.cs.meta | 11 +++ .../Reflection/ReflectionSensorBase.cs | 68 +++++++++++++++ .../Reflection/ReflectionSensorBase.cs.meta | 11 +++ .../Editor/Sensor/ObservableAttributeTests.cs | 50 +++++++++++ .../Sensor/ObservableAttributeTests.cs.meta | 11 +++ 16 files changed, 331 insertions(+), 33 deletions(-) delete mode 100644 com.unity.ml-agents/Runtime/Sensors/ObservableAttribute.cs create mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection.meta rename com.unity.ml-agents/Runtime/Sensors/{ => Reflection}/AttributeFieldSensor.cs (97%) create mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/AttributeFieldSensor.cs.meta rename com.unity.ml-agents/Runtime/Sensors/{ => Reflection}/AttributePropertySensor.cs (97%) create mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/AttributePropertySensor.cs.meta create mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs create mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs.meta create mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs create mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs.meta create mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs create mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs.meta create mode 100644 com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs create mode 100644 com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs.meta diff --git a/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs b/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs index 4d21ea4e1b..0ee610c1d1 100644 --- a/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs +++ b/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs @@ -1,6 +1,7 @@ using UnityEngine; using Unity.MLAgents; using Unity.MLAgents.Sensors; +using Unity.MLAgents.Sensors.Reflection; public class FoodCollectorAgent : Agent { diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 889f9b99ef..61950e6c0a 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -5,6 +5,7 @@ using UnityEngine; using Unity.Barracuda; using Unity.MLAgents.Sensors; +using Unity.MLAgents.Sensors.Reflection; using Unity.MLAgents.Demonstrations; using Unity.MLAgents.Policies; using UnityEngine.Serialization; @@ -817,27 +818,29 @@ public virtual void Heuristic(float[] actionsOut) /// internal void InitializeSensors() { - // Iterate over Observables - var fields = this.GetType().GetFields(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); - foreach (var field in fields) - { - var attr = (ObservableAttribute)Attribute.GetCustomAttribute(field, typeof(ObservableAttribute)); - if (attr != null) - { - sensors.Add(new AttributeFieldSensor(this, field, attr)); - } - } - - var properties = this.GetType().GetProperties(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); - foreach (var prop in properties) - { - var attr = (ObservableAttribute)Attribute.GetCustomAttribute(prop, typeof(ObservableAttribute)); - if (attr != null) - { - sensors.Add(new AttributePropertySensor(this, prop, attr)); - } - } - +// // Iterate over Observables +// var fields = this.GetType().GetFields(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); +// foreach (var field in fields) +// { +// var attr = (ObservableAttribute)Attribute.GetCustomAttribute(field, typeof(ObservableAttribute)); +// if (attr != null) +// { +// sensors.Add(new AttributeFieldSensor(this, field, attr)); +// } +// } +// +// var properties = this.GetType().GetProperties(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); +// foreach (var prop in properties) +// { +// var attr = (ObservableAttribute)Attribute.GetCustomAttribute(prop, typeof(ObservableAttribute)); +// if (attr != null) +// { +// sensors.Add(new AttributePropertySensor(this, prop, attr)); +// } +// } + + var observableSensors = ObservableAttribute.GetObservableSensors(this); + sensors.AddRange(observableSensors); // Get all attached sensor components SensorComponent[] attachedSensorComponents; diff --git a/com.unity.ml-agents/Runtime/Sensors/ObservableAttribute.cs b/com.unity.ml-agents/Runtime/Sensors/ObservableAttribute.cs deleted file mode 100644 index 5df69e0336..0000000000 --- a/com.unity.ml-agents/Runtime/Sensors/ObservableAttribute.cs +++ /dev/null @@ -1,10 +0,0 @@ -namespace Unity.MLAgents.Sensors -{ - [System.AttributeUsage(System.AttributeTargets.Field | System.AttributeTargets.Property)] - public class ObservableAttribute : System.Attribute - { - // Currently nothing here - // Could possible add "mask" flags for vector fields/properties - // E.g. MaskX | MaskZ to get the only the X and Z properties of a Vector3 field. - } -} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection.meta new file mode 100644 index 0000000000..fb7288f717 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 08ece3d7e9bb94089a9d59c6f269ab0a +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/AttributeFieldSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributeFieldSensor.cs similarity index 97% rename from com.unity.ml-agents/Runtime/Sensors/AttributeFieldSensor.cs rename to com.unity.ml-agents/Runtime/Sensors/Reflection/AttributeFieldSensor.cs index 655c0226f9..1e82044c85 100644 --- a/com.unity.ml-agents/Runtime/Sensors/AttributeFieldSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributeFieldSensor.cs @@ -1,6 +1,6 @@ using System.Reflection; -namespace Unity.MLAgents.Sensors +namespace Unity.MLAgents.Sensors.Reflection { internal class AttributeFieldSensor : ISensor { diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributeFieldSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributeFieldSensor.cs.meta new file mode 100644 index 0000000000..32e133d57a --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributeFieldSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: a477ee7a6cdf849b18de49c82b8f9c2e +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/AttributePropertySensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributePropertySensor.cs similarity index 97% rename from com.unity.ml-agents/Runtime/Sensors/AttributePropertySensor.cs rename to com.unity.ml-agents/Runtime/Sensors/Reflection/AttributePropertySensor.cs index 346f4cd9ef..f187e7fad0 100644 --- a/com.unity.ml-agents/Runtime/Sensors/AttributePropertySensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributePropertySensor.cs @@ -1,6 +1,6 @@ using System.Reflection; -namespace Unity.MLAgents.Sensors +namespace Unity.MLAgents.Sensors.Reflection { internal class AttributePropertySensor : ISensor { diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributePropertySensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributePropertySensor.cs.meta new file mode 100644 index 0000000000..c7d0b0fcf9 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributePropertySensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 4a0bc726219994fbe95fbec261df482b +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs new file mode 100644 index 0000000000..c90c8e8033 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs @@ -0,0 +1,28 @@ +using System.Reflection; + +namespace Unity.MLAgents.Sensors.Reflection +{ + internal class IntReflectionSensor : ReflectionSensorBase + { + internal IntReflectionSensor(object o, FieldInfo fieldInfo, PropertyInfo propertyInfo, ObservableAttribute observableAttribute, string sensorName) + : base(o, fieldInfo, propertyInfo, observableAttribute, 1, sensorName) + {} + + internal override void WriteReflectedField(ObservationWriter writer) + { + if (m_FieldInfo != null) + { + var val = m_FieldInfo.GetValue(m_Object); + var intVal = (System.Int32)val; + writer[0] = intVal; + } + else + { + // TODO form delegate in ctor + var val = m_PropertyInfo.GetMethod.Invoke(m_Object, null); + var intVal = (System.Int32)val; + writer[0] = intVal; + } + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs.meta new file mode 100644 index 0000000000..a07726937f --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 5cae4c843cc074d11a549aaa3904c898 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs new file mode 100644 index 0000000000..4be762b301 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs @@ -0,0 +1,84 @@ +using System; +using System.Collections.Generic; +using System.Reflection; + +namespace Unity.MLAgents.Sensors.Reflection +{ + [System.AttributeUsage(System.AttributeTargets.Field | System.AttributeTargets.Property)] + public class ObservableAttribute : System.Attribute + { + // Currently nothing here + string m_Name; + + public ObservableAttribute(string name=null) + { + m_Name = name; + } + + internal static List GetObservableSensors(object o) + { + var sensorsOut = new List(); + + var fields = o.GetType().GetFields(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); + foreach (var field in fields) + { + var attr = (ObservableAttribute)Attribute.GetCustomAttribute(field, typeof(ObservableAttribute)); + if (attr != null) + { + sensorsOut.Add(CreateReflectionSensor(o, field, null, attr)); + } + } + + var properties = o.GetType().GetProperties(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); + foreach (var prop in properties) + { + var attr = (ObservableAttribute)Attribute.GetCustomAttribute(prop, typeof(ObservableAttribute)); + if (attr != null) + { + sensorsOut.Add(CreateReflectionSensor(o, null, prop, attr)); + } + } + return sensorsOut; + } + + internal static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, PropertyInfo propertyInfo, ObservableAttribute observableAttribute) + { + MemberInfo memberInfo = fieldInfo != null ? (MemberInfo) fieldInfo : propertyInfo; + string memberName; + string declaringTypeName; + Type memberType; + if (fieldInfo != null) + { + declaringTypeName = fieldInfo.DeclaringType.Name; + memberName = fieldInfo.Name; + memberType = fieldInfo.FieldType; + } + else + { + declaringTypeName = propertyInfo.DeclaringType.Name; + memberName = propertyInfo.Name; + memberType = propertyInfo.PropertyType; + } + + string sensorName; + if (string.IsNullOrEmpty(observableAttribute.m_Name)) + { + sensorName = $"ObservableAttribute:{declaringTypeName}.{memberName}"; + } + else + { + sensorName = observableAttribute.m_Name; + } + + if (memberType == typeof(System.Int32)) + { + return new IntReflectionSensor(o, fieldInfo, propertyInfo, observableAttribute, sensorName); + } + + throw new UnityAgentsException($"Unsupported Observable type: {memberType.Name}"); + + } + + } + +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs.meta new file mode 100644 index 0000000000..41659283da --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: a75086dc66a594baea6b8b2935f5dacf +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs new file mode 100644 index 0000000000..1c42328b7a --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs @@ -0,0 +1,68 @@ +using System.Reflection; + +namespace Unity.MLAgents.Sensors.Reflection +{ + internal abstract class ReflectionSensorBase : ISensor + { + protected object m_Object; + + protected FieldInfo m_FieldInfo; + protected PropertyInfo m_PropertyInfo; + // Not currently used, but might want later. + protected ObservableAttribute m_ObservableAttribute; + + string m_SensorName; + int[] m_Shape; + + public ReflectionSensorBase(object o, FieldInfo fieldInfo, PropertyInfo propertyInfo, ObservableAttribute observableAttribute, int size, string sensorName) + { + // TODO 2 constructors? + + m_Object = o; + m_FieldInfo = fieldInfo; + m_PropertyInfo = propertyInfo; + m_ObservableAttribute = observableAttribute; + m_SensorName = sensorName; + m_Shape = new [] {size}; + } + + /// + public int[] GetObservationShape() + { + return m_Shape; + } + + /// + public int Write(ObservationWriter writer) + { + WriteReflectedField(writer); + return m_Shape[0]; + } + + internal abstract void WriteReflectedField(ObservationWriter writer); + + /// + public byte[] GetCompressedObservation() + { + return null; + } + + /// + public void Update() { } + + /// + public void Reset() { } + + /// + public SensorCompressionType GetCompressionType() + { + return SensorCompressionType.None; + } + + /// + public string GetName() + { + return m_SensorName; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs.meta new file mode 100644 index 0000000000..cef19bb598 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 6b68d855fb94a45fbbeb0dbe968a35f8 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs new file mode 100644 index 0000000000..0b16627f18 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs @@ -0,0 +1,50 @@ +using System; +using System.Collections.Generic; +using NUnit.Framework; +using UnityEngine; +using Unity.MLAgents.Sensors; +using Unity.MLAgents.Sensors.Reflection; + +namespace Unity.MLAgents.Tests +{ + + [TestFixture] + public class ObservableAttributeTests + { + class TestClass + { + [Observable] + public int m_IntMember; + + int m_IntProperty; + + [Observable] + public int IntProperty + { + get => m_IntProperty; + set => m_IntProperty = value; + } + } + + [Test] + public void TestGetObservableSensors() + { + var testClass = new TestClass(); + testClass.m_IntMember = 1; + testClass.IntProperty = 2; + + var sensors = ObservableAttribute.GetObservableSensors(testClass); + Assert.AreEqual(sensors.Count, 2); + + var sensorsByName = new Dictionary(); + foreach (var sensor in sensors) + { + sensorsByName[sensor.GetName()] = sensor; + } + + SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.m_IntMember"], new[] {1.0f}); + SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.IntProperty"], new[] {2.0f}); + + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs.meta new file mode 100644 index 0000000000..611fdcfa12 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 33d7912e6b3504412bd261b40e46df32 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: From 351be8139cbf6f1cc683af0ef450da725c767005 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Mon, 11 May 2020 13:21:18 -0700 Subject: [PATCH 03/23] add vector3 sensor, cleanup constructors --- .../Reflection/AttributeFieldSensor.cs | 72 ----------------- .../Reflection/AttributePropertySensor.cs | 78 ------------------- .../AttributePropertySensor.cs.meta | 11 --- .../Sensors/Reflection/IntReflectionSensor.cs | 21 +---- .../Sensors/Reflection/ObservableAttribute.cs | 22 +++++- .../Reflection/ReflectionSensorBase.cs | 30 +++++-- .../Reflection/Vector3ReflectionSensor.cs | 17 ++++ ...s.meta => Vector3ReflectionSensor.cs.meta} | 2 +- .../Editor/Sensor/ObservableAttributeTests.cs | 18 ++++- 9 files changed, 83 insertions(+), 188 deletions(-) delete mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/AttributeFieldSensor.cs delete mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/AttributePropertySensor.cs delete mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/AttributePropertySensor.cs.meta create mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs rename com.unity.ml-agents/Runtime/Sensors/Reflection/{AttributeFieldSensor.cs.meta => Vector3ReflectionSensor.cs.meta} (83%) diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributeFieldSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributeFieldSensor.cs deleted file mode 100644 index 1e82044c85..0000000000 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributeFieldSensor.cs +++ /dev/null @@ -1,72 +0,0 @@ -using System.Reflection; - -namespace Unity.MLAgents.Sensors.Reflection -{ - internal class AttributeFieldSensor : ISensor - { - object m_Object; - FieldInfo m_FieldInfo; - // Not currently used, but might want later. - ObservableAttribute m_ObservableAttribute; - - string m_SensorName; - int[] m_Shape; - - public AttributeFieldSensor(object o, FieldInfo fieldInfo, ObservableAttribute observableAttribute) - { - m_Object = o; - m_FieldInfo = fieldInfo; - m_ObservableAttribute = observableAttribute; - - m_SensorName = $"ObservableAttribute:{fieldInfo.DeclaringType.Name}.{fieldInfo.Name}"; - // TODO handle Vector3, quaternion, blittable(?) - m_Shape = new [] {1}; - } - - /// - public int[] GetObservationShape() - { - return m_Shape; - } - - /// - public int Write(ObservationWriter writer) - { - var val = m_FieldInfo.GetValue(m_Object); - if (m_FieldInfo.FieldType == typeof(System.Boolean)) - { - var boolVal = (System.Boolean)val; - writer[0] = boolVal ? 1.0f : 0.0f; - } - else - { - writer[0] = 0.0f; - } - return 1; - } - - /// - public byte[] GetCompressedObservation() - { - return null; - } - - /// - public void Update() { } - - /// - public void Reset() { } - - /// - public SensorCompressionType GetCompressionType() - { - return SensorCompressionType.None; - } - - /// - public string GetName() - { - return m_SensorName; - } - } -} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributePropertySensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributePropertySensor.cs deleted file mode 100644 index f187e7fad0..0000000000 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributePropertySensor.cs +++ /dev/null @@ -1,78 +0,0 @@ -using System.Reflection; - -namespace Unity.MLAgents.Sensors.Reflection -{ - internal class AttributePropertySensor : ISensor - { - object m_Object; - PropertyInfo m_PropertyInfo; - // Not currently used, but might want later. - ObservableAttribute m_ObservableAttribute; - - string m_SensorName; - int[] m_Shape; - - public AttributePropertySensor(object o, PropertyInfo propertyInfo, ObservableAttribute observableAttribute) - { - m_Object = o; - m_PropertyInfo = propertyInfo; - m_ObservableAttribute = observableAttribute; - - m_SensorName = $"ObservableAttribute:{propertyInfo.DeclaringType.Name}.{propertyInfo.Name}"; - // TODO handle scalar, quaternion, blittable(?) - m_Shape = new [] {3}; - } - - /// - public int[] GetObservationShape() - { - return m_Shape; - } - - /// - public int Write(ObservationWriter writer) - { - // TODO make Delegate in ctor instead? - var val = m_PropertyInfo.GetMethod.Invoke(m_Object, null); - - if (m_PropertyInfo.PropertyType == typeof(UnityEngine.Vector3)) - { - var vec3Val = (UnityEngine.Vector3)val; - writer[0] = vec3Val.x; - writer[1] = vec3Val.y; - writer[2] = vec3Val.z; - } - else - { - writer[0] = 0.0f; - writer[1] = 0.0f; - writer[2] = 0.0f; - } - return 1; - } - - /// - public byte[] GetCompressedObservation() - { - return null; - } - - /// - public void Update() { } - - /// - public void Reset() { } - - /// - public SensorCompressionType GetCompressionType() - { - return SensorCompressionType.None; - } - - /// - public string GetName() - { - return m_SensorName; - } - } -} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributePropertySensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributePropertySensor.cs.meta deleted file mode 100644 index c7d0b0fcf9..0000000000 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributePropertySensor.cs.meta +++ /dev/null @@ -1,11 +0,0 @@ -fileFormatVersion: 2 -guid: 4a0bc726219994fbe95fbec261df482b -MonoImporter: - externalObjects: {} - serializedVersion: 2 - defaultReferences: [] - executionOrder: 0 - icon: {instanceID: 0} - userData: - assetBundleName: - assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs index c90c8e8033..748762cbba 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs @@ -1,28 +1,15 @@ -using System.Reflection; - namespace Unity.MLAgents.Sensors.Reflection { internal class IntReflectionSensor : ReflectionSensorBase { - internal IntReflectionSensor(object o, FieldInfo fieldInfo, PropertyInfo propertyInfo, ObservableAttribute observableAttribute, string sensorName) - : base(o, fieldInfo, propertyInfo, observableAttribute, 1, sensorName) + internal IntReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, 1) {} internal override void WriteReflectedField(ObservationWriter writer) { - if (m_FieldInfo != null) - { - var val = m_FieldInfo.GetValue(m_Object); - var intVal = (System.Int32)val; - writer[0] = intVal; - } - else - { - // TODO form delegate in ctor - var val = m_PropertyInfo.GetMethod.Invoke(m_Object, null); - var intVal = (System.Int32)val; - writer[0] = intVal; - } + var intVal = (System.Int32)GetReflectedValue(); + writer[0] = intVal; } } } diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs index 4be762b301..79910d318b 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs @@ -70,10 +70,28 @@ internal static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, Pr sensorName = observableAttribute.m_Name; } - if (memberType == typeof(System.Int32)) + var reflectionSensorInfo = new ReflectionSensorInfo { - return new IntReflectionSensor(o, fieldInfo, propertyInfo, observableAttribute, sensorName); + Object = o, + FieldInfo = fieldInfo, + PropertyInfo = propertyInfo, + ObservableAttribute = observableAttribute, + SensorName = sensorName + }; + + if (memberType == typeof(Int32)) + { + return new IntReflectionSensor(reflectionSensorInfo); + } + if (memberType == typeof(UnityEngine.Vector3)) + { + return new Vector3ReflectionSensor(reflectionSensorInfo); } + // Int + // Bool + // Vector2 + // Vector4 + // Quaternion throw new UnityAgentsException($"Unsupported Observable type: {memberType.Name}"); diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs index 1c42328b7a..6419735087 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs @@ -2,6 +2,17 @@ namespace Unity.MLAgents.Sensors.Reflection { + // Construction info for a ReflectionSensorBase. + internal struct ReflectionSensorInfo + { + public object Object; + + public FieldInfo FieldInfo; + public PropertyInfo PropertyInfo; + public ObservableAttribute ObservableAttribute; + public string SensorName; + } + internal abstract class ReflectionSensorBase : ISensor { protected object m_Object; @@ -14,15 +25,15 @@ internal abstract class ReflectionSensorBase : ISensor string m_SensorName; int[] m_Shape; - public ReflectionSensorBase(object o, FieldInfo fieldInfo, PropertyInfo propertyInfo, ObservableAttribute observableAttribute, int size, string sensorName) + public ReflectionSensorBase(ReflectionSensorInfo reflectionSensorInfo, int size) { // TODO 2 constructors? - m_Object = o; - m_FieldInfo = fieldInfo; - m_PropertyInfo = propertyInfo; - m_ObservableAttribute = observableAttribute; - m_SensorName = sensorName; + m_Object = reflectionSensorInfo.Object; + m_FieldInfo = reflectionSensorInfo.FieldInfo; + m_PropertyInfo = reflectionSensorInfo.PropertyInfo; + m_ObservableAttribute = reflectionSensorInfo.ObservableAttribute; + m_SensorName = reflectionSensorInfo.SensorName; m_Shape = new [] {size}; } @@ -41,6 +52,13 @@ public int Write(ObservationWriter writer) internal abstract void WriteReflectedField(ObservationWriter writer); + protected object GetReflectedValue() + { + return m_FieldInfo != null ? + m_FieldInfo.GetValue(m_Object) : + m_PropertyInfo.GetMethod.Invoke(m_Object, null); + } + /// public byte[] GetCompressedObservation() { diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs new file mode 100644 index 0000000000..1823277eec --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs @@ -0,0 +1,17 @@ +namespace Unity.MLAgents.Sensors.Reflection +{ + internal class Vector3ReflectionSensor : ReflectionSensorBase + { + internal Vector3ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, 3) + {} + + internal override void WriteReflectedField(ObservationWriter writer) + { + var vecVal = (UnityEngine.Vector3)GetReflectedValue(); + writer[0] = vecVal.x; + writer[1] = vecVal.y; + writer[2] = vecVal.z; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributeFieldSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs.meta similarity index 83% rename from com.unity.ml-agents/Runtime/Sensors/Reflection/AttributeFieldSensor.cs.meta rename to com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs.meta index 32e133d57a..771b690b07 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/AttributeFieldSensor.cs.meta +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs.meta @@ -1,5 +1,5 @@ fileFormatVersion: 2 -guid: a477ee7a6cdf849b18de49c82b8f9c2e +guid: e756976ec2a0943cfbc0f97a6550a85b MonoImporter: externalObjects: {} serializedVersion: 2 diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs index 0b16627f18..19a40ed54a 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs @@ -24,6 +24,18 @@ public int IntProperty get => m_IntProperty; set => m_IntProperty = value; } + + [Observable("vector3member")] + public Vector3 m_Vector3Member; + + Vector3 m_VectorProperty; + + [Observable("vector3property")] + public Vector3 VectorProperty + { + get => m_VectorProperty; + set => m_VectorProperty = value; + } } [Test] @@ -32,9 +44,11 @@ public void TestGetObservableSensors() var testClass = new TestClass(); testClass.m_IntMember = 1; testClass.IntProperty = 2; + testClass.m_Vector3Member = new Vector3(30,31,32); + testClass.VectorProperty = new Vector3(33,34,35); var sensors = ObservableAttribute.GetObservableSensors(testClass); - Assert.AreEqual(sensors.Count, 2); + Assert.AreEqual(sensors.Count, 4); var sensorsByName = new Dictionary(); foreach (var sensor in sensors) @@ -44,6 +58,8 @@ public void TestGetObservableSensors() SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.m_IntMember"], new[] {1.0f}); SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.IntProperty"], new[] {2.0f}); + SensorTestHelper.CompareObservation(sensorsByName["vector3member"], new[] {30.0f, 31.0f, 32.0f}); + SensorTestHelper.CompareObservation(sensorsByName["vector3property"], new[] {33.0f, 34.0f, 35.0f}); } } From f5f529a797492263ed4f590a584e8e85e7f8d1eb Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Mon, 11 May 2020 14:37:02 -0700 Subject: [PATCH 04/23] add more types --- .../Reflection/BoolReflectionSensor.cs | 15 ++ .../Reflection/BoolReflectionSensor.cs.meta | 11 ++ .../Reflection/FloatReflectionSensor.cs | 15 ++ .../Reflection/FloatReflectionSensor.cs.meta | 11 ++ .../Sensors/Reflection/ObservableAttribute.cs | 41 ++++-- .../Reflection/QuaternionReflectionSensor.cs | 18 +++ .../QuaternionReflectionSensor.cs.meta | 11 ++ .../Reflection/Vector2ReflectionSensor.cs | 16 +++ .../Vector2ReflectionSensor.cs.meta | 11 ++ .../Reflection/Vector4ReflectionSensor.cs | 18 +++ .../Vector4ReflectionSensor.cs.meta | 11 ++ .../Editor/Sensor/ObservableAttributeTests.cs | 136 ++++++++++++++++-- 12 files changed, 291 insertions(+), 23 deletions(-) create mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs create mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs.meta create mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs create mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs.meta create mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs create mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs.meta create mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs create mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs.meta create mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs create mode 100644 com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs.meta diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs new file mode 100644 index 0000000000..b71ebf4aa6 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs @@ -0,0 +1,15 @@ +namespace Unity.MLAgents.Sensors.Reflection +{ + internal class BoolReflectionSensor : ReflectionSensorBase + { + internal BoolReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, 1) + {} + + internal override void WriteReflectedField(ObservationWriter writer) + { + var boolVal = (System.Boolean)GetReflectedValue(); + writer[0] = boolVal ? 1.0f : 0.0f; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs.meta new file mode 100644 index 0000000000..5cac420f11 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: be795c90750a6420d93f569b69ddc1ba +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs new file mode 100644 index 0000000000..d34a7faeb3 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs @@ -0,0 +1,15 @@ +namespace Unity.MLAgents.Sensors.Reflection +{ + internal class FloatReflectionSensor : ReflectionSensorBase + { + internal FloatReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, 1) + {} + + internal override void WriteReflectedField(ObservationWriter writer) + { + var floatVal = (System.Single)GetReflectedValue(); + writer[0] = floatVal; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs.meta new file mode 100644 index 0000000000..2de8b18c7c --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 51ed837d5b7cd44349287ac8066120fc +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs index 79910d318b..74c0970a5a 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs @@ -1,15 +1,17 @@ using System; using System.Collections.Generic; using System.Reflection; +using UnityEngine; namespace Unity.MLAgents.Sensors.Reflection { - [System.AttributeUsage(System.AttributeTargets.Field | System.AttributeTargets.Property)] - public class ObservableAttribute : System.Attribute + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)] + public class ObservableAttribute : Attribute { - // Currently nothing here string m_Name; + const BindingFlags k_BindingFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic + public ObservableAttribute(string name=null) { m_Name = name; @@ -19,20 +21,20 @@ internal static List GetObservableSensors(object o) { var sensorsOut = new List(); - var fields = o.GetType().GetFields(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); + var fields = o.GetType().GetFields(k_BindingFlags); foreach (var field in fields) { - var attr = (ObservableAttribute)Attribute.GetCustomAttribute(field, typeof(ObservableAttribute)); + var attr = (ObservableAttribute)GetCustomAttribute(field, typeof(ObservableAttribute)); if (attr != null) { sensorsOut.Add(CreateReflectionSensor(o, field, null, attr)); } } - var properties = o.GetType().GetProperties(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); + var properties = o.GetType().GetProperties(k_BindingFlags); foreach (var prop in properties) { - var attr = (ObservableAttribute)Attribute.GetCustomAttribute(prop, typeof(ObservableAttribute)); + var attr = (ObservableAttribute)GetCustomAttribute(prop, typeof(ObservableAttribute)); if (attr != null) { sensorsOut.Add(CreateReflectionSensor(o, null, prop, attr)); @@ -83,15 +85,30 @@ internal static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, Pr { return new IntReflectionSensor(reflectionSensorInfo); } + if (memberType == typeof(float)) + { + return new FloatReflectionSensor(reflectionSensorInfo); + } + if (memberType == typeof(bool)) + { + return new BoolReflectionSensor(reflectionSensorInfo); + } + if (memberType == typeof(UnityEngine.Vector2)) + { + return new Vector2ReflectionSensor(reflectionSensorInfo); + } if (memberType == typeof(UnityEngine.Vector3)) { return new Vector3ReflectionSensor(reflectionSensorInfo); } - // Int - // Bool - // Vector2 - // Vector4 - // Quaternion + if (memberType == typeof(UnityEngine.Vector4)) + { + return new Vector4ReflectionSensor(reflectionSensorInfo); + } + if (memberType == typeof(UnityEngine.Quaternion)) + { + return new QuaternionReflectionSensor(reflectionSensorInfo); + } throw new UnityAgentsException($"Unsupported Observable type: {memberType.Name}"); diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs new file mode 100644 index 0000000000..6100cf427e --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs @@ -0,0 +1,18 @@ +namespace Unity.MLAgents.Sensors.Reflection +{ + internal class QuaternionReflectionSensor : ReflectionSensorBase + { + internal QuaternionReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, 4) + {} + + internal override void WriteReflectedField(ObservationWriter writer) + { + var quatVal = (UnityEngine.Quaternion)GetReflectedValue(); + writer[0] = quatVal.x; + writer[1] = quatVal.y; + writer[2] = quatVal.z; + writer[3] = quatVal.w; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs.meta new file mode 100644 index 0000000000..f3970e6b51 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: d38241d74074d459bb4590f7f5d16c80 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs new file mode 100644 index 0000000000..b5a142ee4d --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs @@ -0,0 +1,16 @@ +namespace Unity.MLAgents.Sensors.Reflection +{ + internal class Vector2ReflectionSensor : ReflectionSensorBase + { + internal Vector2ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, 2) + {} + + internal override void WriteReflectedField(ObservationWriter writer) + { + var vecVal = (UnityEngine.Vector2)GetReflectedValue(); + writer[0] = vecVal.x; + writer[1] = vecVal.y; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs.meta new file mode 100644 index 0000000000..2b78c25ffe --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: da06ff33f6f2d409cbf240cffa2ba0be +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs new file mode 100644 index 0000000000..f5adb657c3 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs @@ -0,0 +1,18 @@ +namespace Unity.MLAgents.Sensors.Reflection +{ + internal class Vector4ReflectionSensor : ReflectionSensorBase + { + internal Vector4ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, 4) + {} + + internal override void WriteReflectedField(ObservationWriter writer) + { + var vecVal = (UnityEngine.Vector4)GetReflectedValue(); + writer[0] = vecVal.x; + writer[1] = vecVal.y; + writer[2] = vecVal.z; + writer[3] = vecVal.w; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs.meta new file mode 100644 index 0000000000..3d938af6c8 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 01d93aaa1b42b47b8960d303d7c498d3 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs index 19a40ed54a..5f3da79233 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs @@ -13,6 +13,9 @@ public class ObservableAttributeTests { class TestClass { + // + // Int + // [Observable] public int m_IntMember; @@ -25,16 +28,94 @@ public int IntProperty set => m_IntProperty = value; } - [Observable("vector3member")] + // + // Float + // + [Observable("floatMember")] + public float m_FloatMember; + + float m_FloatProperty; + [Observable("floatProperty")] + public float FloatProperty + { + get => m_FloatProperty; + set => m_FloatProperty = value; + } + + // + // Bool + // + [Observable("boolMember")] + public bool m_BoolMember; + + bool m_BoolProperty; + [Observable("boolProperty")] + public bool BoolProperty + { + get => m_BoolProperty; + set => m_BoolProperty = value; + } + + // + // Vector2 + // + + [Observable("vector2Member")] + public Vector2 m_Vector2Member; + + Vector2 m_Vector2Property; + + [Observable("vector2Property")] + public Vector2 Vector2Property + { + get => m_Vector2Property; + set => m_Vector2Property = value; + } + + // + // Vector3 + // + [Observable("vector3Member")] public Vector3 m_Vector3Member; - Vector3 m_VectorProperty; + Vector3 m_Vector3Property; - [Observable("vector3property")] - public Vector3 VectorProperty + [Observable("vector3Property")] + public Vector3 Vector3Property { - get => m_VectorProperty; - set => m_VectorProperty = value; + get => m_Vector3Property; + set => m_Vector3Property = value; + } + + // + // Vector4 + // + + [Observable("vector4Member")] + public Vector4 m_Vector4Member; + + Vector4 m_Vector4Property; + + [Observable("vector4Property")] + public Vector4 Vector4Property + { + get => m_Vector4Property; + set => m_Vector4Property = value; + } + + // + // Quaternion + // + [Observable("quaternionMember")] + public Quaternion m_QuaternionMember; + + Quaternion m_QuaternionProperty; + + [Observable("quaternionProperty")] + public Quaternion QuaternionProperty + { + get => m_QuaternionProperty; + set => m_QuaternionProperty = value; } } @@ -44,11 +125,29 @@ public void TestGetObservableSensors() var testClass = new TestClass(); testClass.m_IntMember = 1; testClass.IntProperty = 2; - testClass.m_Vector3Member = new Vector3(30,31,32); - testClass.VectorProperty = new Vector3(33,34,35); + + testClass.m_FloatMember = 1.1f; + testClass.FloatProperty = 1.2f; + + testClass.m_BoolMember = true; + testClass.BoolProperty = true; + + testClass.m_Vector2Member = new Vector2(2.0f, 2.1f); + testClass.Vector2Property = new Vector2(2.2f, 2.3f); + + testClass.m_Vector3Member = new Vector3(3.0f, 3.1f, 3.2f); + testClass.Vector3Property = new Vector3(3.3f, 3.4f, 3.5f); + + testClass.m_Vector4Member = new Vector4(4.0f, 4.1f, 4.2f, 4.3f); + testClass.Vector4Property = new Vector4(4.4f, 4.5f, 4.5f, 4.7f); + + testClass.m_Vector4Member = new Vector4(4.0f, 4.1f, 4.2f, 4.3f); + testClass.Vector4Property = new Vector4(4.4f, 4.5f, 4.5f, 4.7f); + + testClass.m_QuaternionMember = new Quaternion(5.0f, 5.1f, 5.2f, 5.3f); + testClass.QuaternionProperty = new Quaternion(5.4f, 5.5f, 5.5f, 5.7f); var sensors = ObservableAttribute.GetObservableSensors(testClass); - Assert.AreEqual(sensors.Count, 4); var sensorsByName = new Dictionary(); foreach (var sensor in sensors) @@ -58,9 +157,24 @@ public void TestGetObservableSensors() SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.m_IntMember"], new[] {1.0f}); SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.IntProperty"], new[] {2.0f}); - SensorTestHelper.CompareObservation(sensorsByName["vector3member"], new[] {30.0f, 31.0f, 32.0f}); - SensorTestHelper.CompareObservation(sensorsByName["vector3property"], new[] {33.0f, 34.0f, 35.0f}); + SensorTestHelper.CompareObservation(sensorsByName["floatMember"], new[] {1.1f}); + SensorTestHelper.CompareObservation(sensorsByName["floatProperty"], new[] {1.2f}); + + SensorTestHelper.CompareObservation(sensorsByName["boolMember"], new[] {1.0f}); + SensorTestHelper.CompareObservation(sensorsByName["boolProperty"], new[] {1.0f}); + + SensorTestHelper.CompareObservation(sensorsByName["vector2Member"], new[] {2.0f, 2.1f}); + SensorTestHelper.CompareObservation(sensorsByName["vector2Property"], new[] {2.2f, 2.3f}); + + SensorTestHelper.CompareObservation(sensorsByName["vector3Member"], new[] {3.0f, 3.1f, 3.2f}); + SensorTestHelper.CompareObservation(sensorsByName["vector3Property"], new[] {3.3f, 3.4f, 3.5f}); + + SensorTestHelper.CompareObservation(sensorsByName["vector4Member"], new[] {4.0f, 4.1f, 4.2f, 4.3f}); + SensorTestHelper.CompareObservation(sensorsByName["vector4Property"], new[] {4.4f, 4.5f, 4.5f, 4.7f}); + + SensorTestHelper.CompareObservation(sensorsByName["quaternionMember"], new[] {5.0f, 5.1f, 5.2f, 5.3f}); + SensorTestHelper.CompareObservation(sensorsByName["quaternionProperty"], new[] {5.4f, 5.5f, 5.5f, 5.7f}); } } } From 19168fb9a5e379747c9e50f1307636bfbffe34cc Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Mon, 11 May 2020 16:04:50 -0700 Subject: [PATCH 05/23] account for observables in barracuda checks --- .../Scripts/FoodCollectorAgent.cs | 20 ++------ .../Editor/BehaviorParametersEditor.cs | 14 +++++- .../Inference/BarracudaModelParamLoader.cs | 39 +++++++++------ .../Sensors/Reflection/ObservableAttribute.cs | 47 ++++++++++++++++--- 4 files changed, 83 insertions(+), 37 deletions(-) diff --git a/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs b/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs index 0ee610c1d1..9e11ad2c77 100644 --- a/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs +++ b/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs @@ -1,6 +1,5 @@ using UnityEngine; using Unity.MLAgents; -using Unity.MLAgents.Sensors; using Unity.MLAgents.Sensors.Reflection; public class FoodCollectorAgent : Agent @@ -33,7 +32,6 @@ public class FoodCollectorAgent : Agent public Material frozenMaterial; public GameObject myLaser; public bool contribute; - public bool useVectorObs; EnvironmentParameters m_ResetParams; @@ -47,22 +45,12 @@ public override void Initialize() } [Observable] - Vector3 localRbVelocity + Vector2 localRbVelocity { - get { return transform.InverseTransformDirection(m_AgentRb.velocity); } - } - - public override void CollectObservations(VectorSensor sensor) - { - if (useVectorObs) + get { - // TODO use Observable with localRbVelocity instead - var localVelocity = localRbVelocity; - sensor.AddObservation(localVelocity.x); - sensor.AddObservation(localVelocity.z); - // TODO replace with Observables - sensor.AddObservation(System.Convert.ToInt32(m_Frozen)); - sensor.AddObservation(System.Convert.ToInt32(m_Shoot)); + var rbVel = transform.InverseTransformDirection(m_AgentRb.velocity); + return new Vector2(rbVel.x, rbVel.z); } } diff --git a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs index 08b49a1826..7d71d109c1 100644 --- a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs +++ b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs @@ -2,6 +2,7 @@ using Unity.Barracuda; using Unity.MLAgents.Policies; using Unity.MLAgents.Sensors; +using Unity.MLAgents.Sensors.Reflection; using UnityEngine; namespace Unity.MLAgents.Editor @@ -89,6 +90,7 @@ void DisplayFailedModelChecks() Model barracudaModel = null; var model = (NNModel)serializedObject.FindProperty("m_Model").objectReferenceValue; var behaviorParameters = (BehaviorParameters)target; + SensorComponent[] sensorComponents; if (behaviorParameters.UseChildSensors) { @@ -98,6 +100,15 @@ void DisplayFailedModelChecks() { sensorComponents = behaviorParameters.GetComponents(); } + + int observableSensorSizes = 0; + var agent = behaviorParameters.GetComponent(); + if (agent != null) + { + // TODO check for invalid types and add HelpBox's + observableSensorSizes = ObservableAttribute.GetTotalObservationSize(agent); + } + var brainParameters = behaviorParameters.BrainParameters; if (model != null) { @@ -106,7 +117,8 @@ void DisplayFailedModelChecks() if (brainParameters != null) { var failedChecks = Inference.BarracudaModelParamLoader.CheckModel( - barracudaModel, brainParameters, sensorComponents, behaviorParameters.BehaviorType + barracudaModel, brainParameters, sensorComponents, observableSensorSizes, + behaviorParameters.BehaviorType ); foreach (var check in failedChecks) { diff --git a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs index b83ed2a9a6..51ece14095 100644 --- a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs +++ b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs @@ -126,10 +126,12 @@ public static string[] GetOutputNames(Model model) /// The BrainParameters that are used verify the compatibility with the InferenceEngine /// /// Attached sensor components + /// Sum of the sizes of all ObservableAttributes. /// BehaviorType or the Agent to check. /// The list the error messages of the checks that failed public static IEnumerable CheckModel(Model model, BrainParameters brainParameters, - SensorComponent[] sensorComponents, BehaviorType behaviorType = BehaviorType.Default) + SensorComponent[] sensorComponents, int observableAttributeTotalSize = 0, + BehaviorType behaviorType = BehaviorType.Default) { List failedModelChecks = new List(); if (model == null) @@ -182,7 +184,7 @@ public static IEnumerable CheckModel(Model model, BrainParameters brainP CheckOutputTensorPresence(model, memorySize)) ; failedModelChecks.AddRange( - CheckInputTensorShape(model, brainParameters, sensorComponents) + CheckInputTensorShape(model, brainParameters, sensorComponents, observableAttributeTotalSize) ); failedModelChecks.AddRange( CheckOutputTensorShape(model, brainParameters, isContinuous, actionSize) @@ -253,6 +255,7 @@ static IEnumerable CheckIntScalarPresenceHelper( /// Whether the model is expecting continuous or discrete control. /// /// Array of attached sensor components + /// Total size of ObservableAttributes /// /// A IEnumerable of string corresponding to the failed input presence checks. /// @@ -404,25 +407,27 @@ static string CheckVisualObsShape( /// The BrainParameters that are used verify the compatibility with the InferenceEngine /// /// Attached sensors + /// Sum of the sizes of all ObservableAttributes. /// The list the error messages of the checks that failed static IEnumerable CheckInputTensorShape( - Model model, BrainParameters brainParameters, SensorComponent[] sensorComponents) + Model model, BrainParameters brainParameters, SensorComponent[] sensorComponents, + int observableAttributeTotalSize) { var failedModelChecks = new List(); var tensorTester = - new Dictionary>() + new Dictionary>() { {TensorNames.VectorObservationPlaceholder, CheckVectorObsShape}, {TensorNames.PreviousActionPlaceholder, CheckPreviousActionShape}, - {TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs) => null)}, - {TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs) => null)}, - {TensorNames.SequenceLengthPlaceholder, ((bp, tensor, scs) => null)}, - {TensorNames.RecurrentInPlaceholder, ((bp, tensor, scs) => null)}, + {TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs, i) => null)}, + {TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs, i) => null)}, + {TensorNames.SequenceLengthPlaceholder, ((bp, tensor, scs, i) => null)}, + {TensorNames.RecurrentInPlaceholder, ((bp, tensor, scs, i) => null)}, }; foreach (var mem in model.memories) { - tensorTester[mem.input] = ((bp, tensor, scs) => null); + tensorTester[mem.input] = ((bp, tensor, scs, i) => null); } var visObsIndex = 0; @@ -434,7 +439,7 @@ static IEnumerable CheckInputTensorShape( continue; } tensorTester[TensorNames.VisualObservationPlaceholderPrefix + visObsIndex] = - (bp, tensor, scs) => CheckVisualObsShape(tensor, sensorComponent); + (bp, tensor, scs, i) => CheckVisualObsShape(tensor, sensorComponent); visObsIndex++; } @@ -452,7 +457,7 @@ static IEnumerable CheckInputTensorShape( else { var tester = tensorTester[tensor.name]; - var error = tester.Invoke(brainParameters, tensor, sensorComponents); + var error = tester.Invoke(brainParameters, tensor, sensorComponents, observableAttributeTotalSize); if (error != null) { failedModelChecks.Add(error); @@ -471,12 +476,14 @@ static IEnumerable CheckInputTensorShape( /// /// The tensor that is expected by the model /// Array of attached sensor components + /// Sum of the sizes of all ObservableAttributes. /// /// If the Check failed, returns a string containing information about why the /// check failed. If the check passed, returns null. /// static string CheckVectorObsShape( - BrainParameters brainParameters, TensorProxy tensorProxy, SensorComponent[] sensorComponents) + BrainParameters brainParameters, TensorProxy tensorProxy, SensorComponent[] sensorComponents, + int observableAttributeTotalSize) { var vecObsSizeBp = brainParameters.VectorObservationSize; var numStackedVector = brainParameters.NumStackedVectorObservations; @@ -491,6 +498,8 @@ static string CheckVectorObsShape( } } + totalVectorSensorSize += observableAttributeTotalSize; + if (vecObsSizeBp * numStackedVector + totalVectorSensorSize != totalVecObsSizeT) { var sensorSizes = ""; @@ -526,11 +535,13 @@ static string CheckVectorObsShape( /// The BrainParameters that are used verify the compatibility with the InferenceEngine /// /// The tensor that is expected by the model - /// Array of attached sensor components + /// Array of attached sensor components (unused). + /// /// Sum of the sizes of all ObservableAttributes (unused). /// If the Check failed, returns a string containing information about why the /// check failed. If the check passed, returns null. static string CheckPreviousActionShape( - BrainParameters brainParameters, TensorProxy tensorProxy, SensorComponent[] sensorComponents) + BrainParameters brainParameters, TensorProxy tensorProxy, + SensorComponent[] sensorComponents, int observableAttributeTotalSize) { var numberActionsBp = brainParameters.VectorActionSize.Length; var numberActionsT = tensorProxy.shape[tensorProxy.shape.Length - 1]; diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs index 74c0970a5a..d731088982 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs @@ -10,7 +10,17 @@ public class ObservableAttribute : Attribute { string m_Name; - const BindingFlags k_BindingFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic + const BindingFlags k_BindingFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic; + static Dictionary s_TypeSizes = new Dictionary() + { + {typeof(int), 1}, + {typeof(bool), 1}, + {typeof(float), 1}, + {typeof(Vector2), 2}, + {typeof(Vector3), 3}, + {typeof(Vector4), 4}, + {typeof(Quaternion), 4}, + }; public ObservableAttribute(string name=null) { @@ -45,7 +55,6 @@ internal static List GetObservableSensors(object o) internal static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, PropertyInfo propertyInfo, ObservableAttribute observableAttribute) { - MemberInfo memberInfo = fieldInfo != null ? (MemberInfo) fieldInfo : propertyInfo; string memberName; string declaringTypeName; Type memberType; @@ -93,19 +102,19 @@ internal static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, Pr { return new BoolReflectionSensor(reflectionSensorInfo); } - if (memberType == typeof(UnityEngine.Vector2)) + if (memberType == typeof(Vector2)) { return new Vector2ReflectionSensor(reflectionSensorInfo); } - if (memberType == typeof(UnityEngine.Vector3)) + if (memberType == typeof(Vector3)) { return new Vector3ReflectionSensor(reflectionSensorInfo); } - if (memberType == typeof(UnityEngine.Vector4)) + if (memberType == typeof(Vector4)) { return new Vector4ReflectionSensor(reflectionSensorInfo); } - if (memberType == typeof(UnityEngine.Quaternion)) + if (memberType == typeof(Quaternion)) { return new QuaternionReflectionSensor(reflectionSensorInfo); } @@ -114,6 +123,32 @@ internal static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, Pr } + internal static int GetTotalObservationSize(object o) + { + int sizeOut = 0; + + var fields = o.GetType().GetFields(k_BindingFlags); + foreach (var field in fields) + { + var attr = (ObservableAttribute)GetCustomAttribute(field, typeof(ObservableAttribute)); + if (attr != null) + { + sizeOut += s_TypeSizes[field.FieldType]; + } + } + + var properties = o.GetType().GetProperties(k_BindingFlags); + foreach (var prop in properties) + { + var attr = (ObservableAttribute)GetCustomAttribute(prop, typeof(ObservableAttribute)); + if (attr != null) + { + sizeOut += s_TypeSizes[prop.PropertyType]; + } + } + return sizeOut; + } + } } From 38cae69f0f62a21ad6142a35be6d65fffe1136fb Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Mon, 11 May 2020 17:50:31 -0700 Subject: [PATCH 06/23] iterators for observable fields/props --- .../Editor/BehaviorParametersEditor.cs | 9 ++- .../Inference/BarracudaModelParamLoader.cs | 1 + .../Sensors/Reflection/ObservableAttribute.cs | 51 ++++++++++----- .../Editor/Sensor/ObservableAttributeTests.cs | 64 +++++++++++++++---- 4 files changed, 94 insertions(+), 31 deletions(-) diff --git a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs index 7d71d109c1..22d48828a2 100644 --- a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs +++ b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs @@ -1,3 +1,4 @@ +using System.Collections.Generic; using UnityEditor; using Unity.Barracuda; using Unity.MLAgents.Policies; @@ -105,8 +106,12 @@ void DisplayFailedModelChecks() var agent = behaviorParameters.GetComponent(); if (agent != null) { - // TODO check for invalid types and add HelpBox's - observableSensorSizes = ObservableAttribute.GetTotalObservationSize(agent); + List observableErrors = new List(); + observableSensorSizes = ObservableAttribute.GetTotalObservationSize(agent, observableErrors); + foreach (var check in observableErrors) + { + EditorGUILayout.HelpBox(check, MessageType.Warning); + } } var brainParameters = behaviorParameters.BrainParameters; diff --git a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs index 51ece14095..78b15ce62e 100644 --- a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs +++ b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs @@ -498,6 +498,7 @@ static string CheckVectorObsShape( } } + // TODO account for totalVectorSensorSize in error message totalVectorSensorSize += observableAttributeTotalSize; if (vecObsSizeBp * numStackedVector + totalVectorSensorSize != totalVecObsSizeT) diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs index d731088982..5b91b52436 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs @@ -27,29 +27,45 @@ public ObservableAttribute(string name=null) m_Name = name; } - internal static List GetObservableSensors(object o) + internal static IEnumerable<(FieldInfo, ObservableAttribute)> GetObservableFields(object o) { - var sensorsOut = new List(); - var fields = o.GetType().GetFields(k_BindingFlags); foreach (var field in fields) { var attr = (ObservableAttribute)GetCustomAttribute(field, typeof(ObservableAttribute)); if (attr != null) { - sensorsOut.Add(CreateReflectionSensor(o, field, null, attr)); + yield return (field, attr); } } + } + internal static IEnumerable<(PropertyInfo, ObservableAttribute)> GetObservableProperties(object o) + { var properties = o.GetType().GetProperties(k_BindingFlags); foreach (var prop in properties) { var attr = (ObservableAttribute)GetCustomAttribute(prop, typeof(ObservableAttribute)); if (attr != null) { - sensorsOut.Add(CreateReflectionSensor(o, null, prop, attr)); + yield return (prop, attr); } } + } + + internal static List GetObservableSensors(object o) + { + var sensorsOut = new List(); + foreach (var (field, attr) in GetObservableFields(o)) + { + sensorsOut.Add(CreateReflectionSensor(o, field, null, attr)); + } + + foreach (var (prop, attr) in GetObservableProperties(o)) + { + sensorsOut.Add(CreateReflectionSensor(o, null, prop, attr)); + } + return sensorsOut; } @@ -123,29 +139,34 @@ internal static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, Pr } - internal static int GetTotalObservationSize(object o) + internal static int GetTotalObservationSize(object o, List errorsOut) { int sizeOut = 0; - - var fields = o.GetType().GetFields(k_BindingFlags); - foreach (var field in fields) + foreach (var (field, _) in GetObservableFields(o)) { - var attr = (ObservableAttribute)GetCustomAttribute(field, typeof(ObservableAttribute)); - if (attr != null) + if (s_TypeSizes.ContainsKey(field.FieldType)) { sizeOut += s_TypeSizes[field.FieldType]; } + else + { + errorsOut.Add($"Unsupported Observable type {field.FieldType.Name} on field {field.Name}"); + } } - var properties = o.GetType().GetProperties(k_BindingFlags); - foreach (var prop in properties) + foreach (var (prop, _) in GetObservableProperties(o)) { - var attr = (ObservableAttribute)GetCustomAttribute(prop, typeof(ObservableAttribute)); - if (attr != null) + + if (s_TypeSizes.ContainsKey(prop.PropertyType)) { sizeOut += s_TypeSizes[prop.PropertyType]; } + else + { + errorsOut.Add($"Unsupported Observable type {prop.PropertyType.Name} on field {prop.Name}"); + } } + return sizeOut; } diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs index 5f3da79233..cd32d4fcf4 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs @@ -13,6 +13,10 @@ public class ObservableAttributeTests { class TestClass { + // Non-observables + int m_NonObservableInt; + float m_NonObservableFloat; + // // Int // @@ -155,26 +159,58 @@ public void TestGetObservableSensors() sensorsByName[sensor.GetName()] = sensor; } - SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.m_IntMember"], new[] {1.0f}); - SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.IntProperty"], new[] {2.0f}); + SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.m_IntMember"], new[] { 1.0f }); + SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.IntProperty"], new[] { 2.0f }); + + SensorTestHelper.CompareObservation(sensorsByName["floatMember"], new[] { 1.1f }); + SensorTestHelper.CompareObservation(sensorsByName["floatProperty"], new[] { 1.2f }); + + SensorTestHelper.CompareObservation(sensorsByName["boolMember"], new[] { 1.0f }); + SensorTestHelper.CompareObservation(sensorsByName["boolProperty"], new[] { 1.0f }); + + SensorTestHelper.CompareObservation(sensorsByName["vector2Member"], new[] { 2.0f, 2.1f }); + SensorTestHelper.CompareObservation(sensorsByName["vector2Property"], new[] { 2.2f, 2.3f }); + + SensorTestHelper.CompareObservation(sensorsByName["vector3Member"], new[] { 3.0f, 3.1f, 3.2f }); + SensorTestHelper.CompareObservation(sensorsByName["vector3Property"], new[] { 3.3f, 3.4f, 3.5f }); - SensorTestHelper.CompareObservation(sensorsByName["floatMember"], new[] {1.1f}); - SensorTestHelper.CompareObservation(sensorsByName["floatProperty"], new[] {1.2f}); + SensorTestHelper.CompareObservation(sensorsByName["vector4Member"], new[] { 4.0f, 4.1f, 4.2f, 4.3f }); + SensorTestHelper.CompareObservation(sensorsByName["vector4Property"], new[] { 4.4f, 4.5f, 4.5f, 4.7f }); - SensorTestHelper.CompareObservation(sensorsByName["boolMember"], new[] {1.0f}); - SensorTestHelper.CompareObservation(sensorsByName["boolProperty"], new[] {1.0f}); + SensorTestHelper.CompareObservation(sensorsByName["quaternionMember"], new[] { 5.0f, 5.1f, 5.2f, 5.3f }); + SensorTestHelper.CompareObservation(sensorsByName["quaternionProperty"], new[] { 5.4f, 5.5f, 5.5f, 5.7f }); - SensorTestHelper.CompareObservation(sensorsByName["vector2Member"], new[] {2.0f, 2.1f}); - SensorTestHelper.CompareObservation(sensorsByName["vector2Property"], new[] {2.2f, 2.3f}); + } + [Test] + public void TestGetTotalObservationSize() + { + var testClass = new TestClass(); + var errors = new List(); + var expectedObsSize = 2 * (1 + 1 + 1 + 2 + 3 + 4 + 4); + Assert.AreEqual(expectedObsSize, ObservableAttribute.GetTotalObservationSize(testClass, errors)); + Assert.AreEqual(0, errors.Count); + } - SensorTestHelper.CompareObservation(sensorsByName["vector3Member"], new[] {3.0f, 3.1f, 3.2f}); - SensorTestHelper.CompareObservation(sensorsByName["vector3Property"], new[] {3.3f, 3.4f, 3.5f}); + class BadClass + { + [Observable] + double m_Double; - SensorTestHelper.CompareObservation(sensorsByName["vector4Member"], new[] {4.0f, 4.1f, 4.2f, 4.3f}); - SensorTestHelper.CompareObservation(sensorsByName["vector4Property"], new[] {4.4f, 4.5f, 4.5f, 4.7f}); + [Observable] + double DoubleProperty + { + get => m_Double; + set => m_Double = value; + } + } - SensorTestHelper.CompareObservation(sensorsByName["quaternionMember"], new[] {5.0f, 5.1f, 5.2f, 5.3f}); - SensorTestHelper.CompareObservation(sensorsByName["quaternionProperty"], new[] {5.4f, 5.5f, 5.5f, 5.7f}); + [Test] + public void TestGetTotalObservationSizeErrors() + { + var bad = new BadClass(); + var errors = new List(); + Assert.AreEqual(0, ObservableAttribute.GetTotalObservationSize(bad, errors)); + Assert.AreEqual(2, errors.Count); } } } From 38e0f54e02550a7bf9b14882f258a43c798510e3 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Mon, 11 May 2020 18:28:17 -0700 Subject: [PATCH 07/23] stacking, fix obs size in prefab --- .../Prefabs/FoodCollectorArea.prefab | 70 ++++++++----------- .../Sensors/Reflection/ObservableAttribute.cs | 49 +++++++++---- .../Editor/Sensor/ObservableAttributeTests.cs | 25 +++++++ 3 files changed, 91 insertions(+), 53 deletions(-) diff --git a/Project/Assets/ML-Agents/Examples/FoodCollector/Prefabs/FoodCollectorArea.prefab b/Project/Assets/ML-Agents/Examples/FoodCollector/Prefabs/FoodCollectorArea.prefab index 34bf31e657..a02f2f4f5a 100644 --- a/Project/Assets/ML-Agents/Examples/FoodCollector/Prefabs/FoodCollectorArea.prefab +++ b/Project/Assets/ML-Agents/Examples/FoodCollector/Prefabs/FoodCollectorArea.prefab @@ -2178,11 +2178,11 @@ MonoBehaviour: m_Name: m_EditorClassIdentifier: m_BrainParameters: - vectorObservationSize: 4 - numStackedVectorObservations: 1 - vectorActionSize: 03000000030000000300000002000000 - vectorActionDescriptions: [] - vectorActionSpaceType: 0 + VectorObservationSize: 0 + NumStackedVectorObservations: 1 + VectorActionSize: 03000000030000000300000002000000 + VectorActionDescriptions: [] + VectorActionSpaceType: 0 m_Model: {fileID: 11400000, guid: 36ab3e93020504f48858d0856f939685, type: 3} m_InferenceDevice: 0 m_BehaviorType: 0 @@ -2204,7 +2204,7 @@ MonoBehaviour: agentParameters: maxStep: 0 hasUpgradedFromAgentParameters: 1 - maxStep: 5000 + MaxStep: 5000 area: {fileID: 1819751139121548} turnSpeed: 300 moveSpeed: 2 @@ -2214,7 +2214,6 @@ MonoBehaviour: frozenMaterial: {fileID: 2100000, guid: 66163cf35956a4be08e801b750c26f33, type: 2} myLaser: {fileID: 1081721624670010} contribute: 0 - useVectorObs: 1 --- !u!114 &114725457980523372 MonoBehaviour: m_ObjectHideFlags: 0 @@ -2260,7 +2259,6 @@ MonoBehaviour: m_EditorClassIdentifier: DecisionPeriod: 5 TakeActionsBetweenDecisions: 1 - offsetStep: 0 --- !u!114 &1222199865870203693 MonoBehaviour: m_ObjectHideFlags: 0 @@ -2517,11 +2515,11 @@ MonoBehaviour: m_Name: m_EditorClassIdentifier: m_BrainParameters: - vectorObservationSize: 4 - numStackedVectorObservations: 1 - vectorActionSize: 03000000030000000300000002000000 - vectorActionDescriptions: [] - vectorActionSpaceType: 0 + VectorObservationSize: 0 + NumStackedVectorObservations: 1 + VectorActionSize: 03000000030000000300000002000000 + VectorActionDescriptions: [] + VectorActionSpaceType: 0 m_Model: {fileID: 11400000, guid: 36ab3e93020504f48858d0856f939685, type: 3} m_InferenceDevice: 0 m_BehaviorType: 0 @@ -2543,7 +2541,7 @@ MonoBehaviour: agentParameters: maxStep: 0 hasUpgradedFromAgentParameters: 1 - maxStep: 5000 + MaxStep: 5000 area: {fileID: 1819751139121548} turnSpeed: 300 moveSpeed: 2 @@ -2553,7 +2551,6 @@ MonoBehaviour: frozenMaterial: {fileID: 2100000, guid: 66163cf35956a4be08e801b750c26f33, type: 2} myLaser: {fileID: 1941433838307300} contribute: 0 - useVectorObs: 1 --- !u!114 &114443152683847924 MonoBehaviour: m_ObjectHideFlags: 0 @@ -2599,7 +2596,6 @@ MonoBehaviour: m_EditorClassIdentifier: DecisionPeriod: 5 TakeActionsBetweenDecisions: 1 - offsetStep: 0 --- !u!1 &1528397385587768 GameObject: m_ObjectHideFlags: 0 @@ -2848,11 +2844,11 @@ MonoBehaviour: m_Name: m_EditorClassIdentifier: m_BrainParameters: - vectorObservationSize: 4 - numStackedVectorObservations: 1 - vectorActionSize: 03000000030000000300000002000000 - vectorActionDescriptions: [] - vectorActionSpaceType: 0 + VectorObservationSize: 0 + NumStackedVectorObservations: 1 + VectorActionSize: 03000000030000000300000002000000 + VectorActionDescriptions: [] + VectorActionSpaceType: 0 m_Model: {fileID: 11400000, guid: 36ab3e93020504f48858d0856f939685, type: 3} m_InferenceDevice: 0 m_BehaviorType: 0 @@ -2874,7 +2870,7 @@ MonoBehaviour: agentParameters: maxStep: 0 hasUpgradedFromAgentParameters: 1 - maxStep: 5000 + MaxStep: 5000 area: {fileID: 1819751139121548} turnSpeed: 300 moveSpeed: 2 @@ -2884,7 +2880,6 @@ MonoBehaviour: frozenMaterial: {fileID: 2100000, guid: 66163cf35956a4be08e801b750c26f33, type: 2} myLaser: {fileID: 1421240237750412} contribute: 0 - useVectorObs: 1 --- !u!114 &114986980423924774 MonoBehaviour: m_ObjectHideFlags: 0 @@ -2930,7 +2925,6 @@ MonoBehaviour: m_EditorClassIdentifier: DecisionPeriod: 5 TakeActionsBetweenDecisions: 1 - offsetStep: 0 --- !u!1 &1617924810425504 GameObject: m_ObjectHideFlags: 0 @@ -3442,11 +3436,11 @@ MonoBehaviour: m_Name: m_EditorClassIdentifier: m_BrainParameters: - vectorObservationSize: 4 - numStackedVectorObservations: 1 - vectorActionSize: 03000000030000000300000002000000 - vectorActionDescriptions: [] - vectorActionSpaceType: 0 + VectorObservationSize: 0 + NumStackedVectorObservations: 1 + VectorActionSize: 03000000030000000300000002000000 + VectorActionDescriptions: [] + VectorActionSpaceType: 0 m_Model: {fileID: 11400000, guid: 36ab3e93020504f48858d0856f939685, type: 3} m_InferenceDevice: 0 m_BehaviorType: 0 @@ -3468,7 +3462,7 @@ MonoBehaviour: agentParameters: maxStep: 0 hasUpgradedFromAgentParameters: 1 - maxStep: 5000 + MaxStep: 5000 area: {fileID: 1819751139121548} turnSpeed: 300 moveSpeed: 2 @@ -3478,7 +3472,6 @@ MonoBehaviour: frozenMaterial: {fileID: 2100000, guid: 66163cf35956a4be08e801b750c26f33, type: 2} myLaser: {fileID: 1617924810425504} contribute: 0 - useVectorObs: 1 --- !u!114 &114644889237473510 MonoBehaviour: m_ObjectHideFlags: 0 @@ -3524,7 +3517,6 @@ MonoBehaviour: m_EditorClassIdentifier: DecisionPeriod: 5 TakeActionsBetweenDecisions: 1 - offsetStep: 0 --- !u!1 &1688105343773098 GameObject: m_ObjectHideFlags: 0 @@ -3759,11 +3751,11 @@ MonoBehaviour: m_Name: m_EditorClassIdentifier: m_BrainParameters: - vectorObservationSize: 4 - numStackedVectorObservations: 1 - vectorActionSize: 03000000030000000300000002000000 - vectorActionDescriptions: [] - vectorActionSpaceType: 0 + VectorObservationSize: 0 + NumStackedVectorObservations: 1 + VectorActionSize: 03000000030000000300000002000000 + VectorActionDescriptions: [] + VectorActionSpaceType: 0 m_Model: {fileID: 11400000, guid: 36ab3e93020504f48858d0856f939685, type: 3} m_InferenceDevice: 0 m_BehaviorType: 0 @@ -3785,7 +3777,7 @@ MonoBehaviour: agentParameters: maxStep: 0 hasUpgradedFromAgentParameters: 1 - maxStep: 5000 + MaxStep: 5000 area: {fileID: 1819751139121548} turnSpeed: 300 moveSpeed: 2 @@ -3795,7 +3787,6 @@ MonoBehaviour: frozenMaterial: {fileID: 2100000, guid: 66163cf35956a4be08e801b750c26f33, type: 2} myLaser: {fileID: 1045923826166930} contribute: 0 - useVectorObs: 1 --- !u!114 &114276061479012222 MonoBehaviour: m_ObjectHideFlags: 0 @@ -3841,7 +3832,6 @@ MonoBehaviour: m_EditorClassIdentifier: DecisionPeriod: 5 TakeActionsBetweenDecisions: 1 - offsetStep: 0 --- !u!1 &1729825611722018 GameObject: m_ObjectHideFlags: 0 diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs index 5b91b52436..d379aeb511 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs @@ -9,6 +9,7 @@ namespace Unity.MLAgents.Sensors.Reflection public class ObservableAttribute : Attribute { string m_Name; + int m_NumStackedObservations; const BindingFlags k_BindingFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic; static Dictionary s_TypeSizes = new Dictionary() @@ -22,9 +23,16 @@ public class ObservableAttribute : Attribute {typeof(Quaternion), 4}, }; - public ObservableAttribute(string name=null) + /// + /// ObservableAttribute constructor. + /// + /// Optional override for the sensor name. Note that all sensors for an Agent + /// must have a unique name. + /// Number of frames to concatenate observations from. + public ObservableAttribute(string name=null, int numStackedObservations=1) { m_Name = name; + m_NumStackedObservations = numStackedObservations; } internal static IEnumerable<(FieldInfo, ObservableAttribute)> GetObservableFields(object o) @@ -106,47 +114,62 @@ internal static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, Pr SensorName = sensorName }; + ISensor sensor = null; if (memberType == typeof(Int32)) { - return new IntReflectionSensor(reflectionSensorInfo); + sensor = new IntReflectionSensor(reflectionSensorInfo); } if (memberType == typeof(float)) { - return new FloatReflectionSensor(reflectionSensorInfo); + sensor = new FloatReflectionSensor(reflectionSensorInfo); } if (memberType == typeof(bool)) { - return new BoolReflectionSensor(reflectionSensorInfo); + sensor = new BoolReflectionSensor(reflectionSensorInfo); } if (memberType == typeof(Vector2)) { - return new Vector2ReflectionSensor(reflectionSensorInfo); + sensor = new Vector2ReflectionSensor(reflectionSensorInfo); } if (memberType == typeof(Vector3)) { - return new Vector3ReflectionSensor(reflectionSensorInfo); + sensor = new Vector3ReflectionSensor(reflectionSensorInfo); } if (memberType == typeof(Vector4)) { - return new Vector4ReflectionSensor(reflectionSensorInfo); + sensor = new Vector4ReflectionSensor(reflectionSensorInfo); } if (memberType == typeof(Quaternion)) { - return new QuaternionReflectionSensor(reflectionSensorInfo); + sensor = new QuaternionReflectionSensor(reflectionSensorInfo); } - throw new UnityAgentsException($"Unsupported Observable type: {memberType.Name}"); + if (sensor == null) + { + throw new UnityAgentsException($"Unsupported Observable type: {memberType.Name}"); + } + + // Wrap the base sensor in a StackingSensor if we're using stacking. + if (observableAttribute.m_NumStackedObservations > 1) + { + return new StackingSensor(sensor, observableAttribute.m_NumStackedObservations); + } + + return sensor; + + + } internal static int GetTotalObservationSize(object o, List errorsOut) { int sizeOut = 0; - foreach (var (field, _) in GetObservableFields(o)) + foreach (var (field, attr) in GetObservableFields(o)) { if (s_TypeSizes.ContainsKey(field.FieldType)) { - sizeOut += s_TypeSizes[field.FieldType]; + sizeOut += s_TypeSizes[field.FieldType] * attr.m_NumStackedObservations; } else { @@ -154,12 +177,12 @@ internal static int GetTotalObservationSize(object o, List errorsOut) } } - foreach (var (prop, _) in GetObservableProperties(o)) + foreach (var (prop, attr) in GetObservableProperties(o)) { if (s_TypeSizes.ContainsKey(prop.PropertyType)) { - sizeOut += s_TypeSizes[prop.PropertyType]; + sizeOut += s_TypeSizes[prop.PropertyType] * attr.m_NumStackedObservations; } else { diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs index cd32d4fcf4..0fcc48963e 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs @@ -212,5 +212,30 @@ public void TestGetTotalObservationSizeErrors() Assert.AreEqual(0, ObservableAttribute.GetTotalObservationSize(bad, errors)); Assert.AreEqual(2, errors.Count); } + + class StackingClass + { + [Observable(numStackedObservations: 2)] + public float FloatVal; + } + + [Test] + public void TestObservableAttributeStacking() + { + var c = new StackingClass(); + c.FloatVal = 1.0f; + var sensors = ObservableAttribute.GetObservableSensors(c); + var sensor = sensors[0]; + Assert.AreEqual(typeof(StackingSensor), sensor.GetType()); + SensorTestHelper.CompareObservation(sensor, new[] { 0.0f, 1.0f }); + + sensor.Update(); + c.FloatVal = 3.0f; + SensorTestHelper.CompareObservation(sensor, new[] { 1.0f, 3.0f }); + + var errors = new List(); + Assert.AreEqual(2, ObservableAttribute.GetTotalObservationSize(c, errors)); + Assert.AreEqual(0, errors.Count); + } } } From 6748abb1186699f447b9f554fff95477418625b6 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Tue, 12 May 2020 17:28:57 -0700 Subject: [PATCH 08/23] use DeclaredOnly to filter members --- .../Editor/BehaviorParametersEditor.cs | 2 +- com.unity.ml-agents/Runtime/Agent.cs | 35 +++++--------- .../Sensors/Reflection/ObservableAttribute.cs | 27 +++++------ .../Editor/Sensor/ObservableAttributeTests.cs | 48 +++++++++++++++++-- 4 files changed, 68 insertions(+), 44 deletions(-) diff --git a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs index 22d48828a2..c6edb45187 100644 --- a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs +++ b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs @@ -107,7 +107,7 @@ void DisplayFailedModelChecks() if (agent != null) { List observableErrors = new List(); - observableSensorSizes = ObservableAttribute.GetTotalObservationSize(agent, observableErrors); + observableSensorSizes = ObservableAttribute.GetTotalObservationSize(agent, false, observableErrors); foreach (var check in observableErrors) { EditorGUILayout.HelpBox(check, MessageType.Warning); diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 61950e6c0a..b166de585a 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -397,7 +397,11 @@ public void LazyInitialize() m_Brain = m_PolicyFactory.GeneratePolicy(Heuristic); ResetData(); Initialize(); - InitializeSensors(); + + using (TimerStack.Instance.Scoped("InitializeSensors")) + { + InitializeSensors(); + } // The first time the Academy resets, all Agents in the scene will be // forced to reset through the event. @@ -818,29 +822,12 @@ public virtual void Heuristic(float[] actionsOut) /// internal void InitializeSensors() { -// // Iterate over Observables -// var fields = this.GetType().GetFields(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); -// foreach (var field in fields) -// { -// var attr = (ObservableAttribute)Attribute.GetCustomAttribute(field, typeof(ObservableAttribute)); -// if (attr != null) -// { -// sensors.Add(new AttributeFieldSensor(this, field, attr)); -// } -// } -// -// var properties = this.GetType().GetProperties(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); -// foreach (var prop in properties) -// { -// var attr = (ObservableAttribute)Attribute.GetCustomAttribute(prop, typeof(ObservableAttribute)); -// if (attr != null) -// { -// sensors.Add(new AttributePropertySensor(this, prop, attr)); -// } -// } - - var observableSensors = ObservableAttribute.GetObservableSensors(this); - sensors.AddRange(observableSensors); + using (TimerStack.Instance.Scoped("GetObservableSensors")) + { + // TODO enum for whether or not to search full hierarchy + var observableSensors = ObservableAttribute.GetObservableSensors(this, false); + sensors.AddRange(observableSensors); + } // Get all attached sensor components SensorComponent[] attachedSensorComponents; diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs index d379aeb511..c7335b1fb1 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs @@ -35,9 +35,10 @@ public ObservableAttribute(string name=null, int numStackedObservations=1) m_NumStackedObservations = numStackedObservations; } - internal static IEnumerable<(FieldInfo, ObservableAttribute)> GetObservableFields(object o) + internal static IEnumerable<(FieldInfo, ObservableAttribute)> GetObservableFields(object o, bool declaredOnly) { - var fields = o.GetType().GetFields(k_BindingFlags); + var bindingFlags = k_BindingFlags | (declaredOnly ? BindingFlags.DeclaredOnly : 0); + var fields = o.GetType().GetFields(bindingFlags); foreach (var field in fields) { var attr = (ObservableAttribute)GetCustomAttribute(field, typeof(ObservableAttribute)); @@ -48,9 +49,11 @@ public ObservableAttribute(string name=null, int numStackedObservations=1) } } - internal static IEnumerable<(PropertyInfo, ObservableAttribute)> GetObservableProperties(object o) + internal static IEnumerable<(PropertyInfo, ObservableAttribute)> GetObservableProperties(object o, bool declaredOnly) { - var properties = o.GetType().GetProperties(k_BindingFlags); + var bindingFlags = k_BindingFlags | (declaredOnly ? BindingFlags.DeclaredOnly : 0); + // TODO check PropertyInfo.CanRead or filter via bindingFlag + var properties = o.GetType().GetProperties(bindingFlags); foreach (var prop in properties) { var attr = (ObservableAttribute)GetCustomAttribute(prop, typeof(ObservableAttribute)); @@ -61,15 +64,15 @@ public ObservableAttribute(string name=null, int numStackedObservations=1) } } - internal static List GetObservableSensors(object o) + internal static List GetObservableSensors(object o, bool declaredOnly) { var sensorsOut = new List(); - foreach (var (field, attr) in GetObservableFields(o)) + foreach (var (field, attr) in GetObservableFields(o, declaredOnly)) { sensorsOut.Add(CreateReflectionSensor(o, field, null, attr)); } - foreach (var (prop, attr) in GetObservableProperties(o)) + foreach (var (prop, attr) in GetObservableProperties(o, declaredOnly)) { sensorsOut.Add(CreateReflectionSensor(o, null, prop, attr)); } @@ -156,16 +159,12 @@ internal static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, Pr } return sensor; - - - - } - internal static int GetTotalObservationSize(object o, List errorsOut) + internal static int GetTotalObservationSize(object o, bool declaredOnly, List errorsOut) { int sizeOut = 0; - foreach (var (field, attr) in GetObservableFields(o)) + foreach (var (field, attr) in GetObservableFields(o, declaredOnly)) { if (s_TypeSizes.ContainsKey(field.FieldType)) { @@ -177,7 +176,7 @@ internal static int GetTotalObservationSize(object o, List errorsOut) } } - foreach (var (prop, attr) in GetObservableProperties(o)) + foreach (var (prop, attr) in GetObservableProperties(o, declaredOnly)) { if (s_TypeSizes.ContainsKey(prop.PropertyType)) diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs index 0fcc48963e..31d7395280 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs @@ -151,7 +151,7 @@ public void TestGetObservableSensors() testClass.m_QuaternionMember = new Quaternion(5.0f, 5.1f, 5.2f, 5.3f); testClass.QuaternionProperty = new Quaternion(5.4f, 5.5f, 5.5f, 5.7f); - var sensors = ObservableAttribute.GetObservableSensors(testClass); + var sensors = ObservableAttribute.GetObservableSensors(testClass, false); var sensorsByName = new Dictionary(); foreach (var sensor in sensors) @@ -187,7 +187,7 @@ public void TestGetTotalObservationSize() var testClass = new TestClass(); var errors = new List(); var expectedObsSize = 2 * (1 + 1 + 1 + 2 + 3 + 4 + 4); - Assert.AreEqual(expectedObsSize, ObservableAttribute.GetTotalObservationSize(testClass, errors)); + Assert.AreEqual(expectedObsSize, ObservableAttribute.GetTotalObservationSize(testClass, false, errors)); Assert.AreEqual(0, errors.Count); } @@ -202,6 +202,7 @@ double DoubleProperty get => m_Double; set => m_Double = value; } + // TODO handle a set-only [Observable] property } [Test] @@ -209,7 +210,7 @@ public void TestGetTotalObservationSizeErrors() { var bad = new BadClass(); var errors = new List(); - Assert.AreEqual(0, ObservableAttribute.GetTotalObservationSize(bad, errors)); + Assert.AreEqual(0, ObservableAttribute.GetTotalObservationSize(bad, false, errors)); Assert.AreEqual(2, errors.Count); } @@ -224,7 +225,7 @@ public void TestObservableAttributeStacking() { var c = new StackingClass(); c.FloatVal = 1.0f; - var sensors = ObservableAttribute.GetObservableSensors(c); + var sensors = ObservableAttribute.GetObservableSensors(c, false); var sensor = sensors[0]; Assert.AreEqual(typeof(StackingSensor), sensor.GetType()); SensorTestHelper.CompareObservation(sensor, new[] { 0.0f, 1.0f }); @@ -234,8 +235,45 @@ public void TestObservableAttributeStacking() SensorTestHelper.CompareObservation(sensor, new[] { 1.0f, 3.0f }); var errors = new List(); - Assert.AreEqual(2, ObservableAttribute.GetTotalObservationSize(c, errors)); + Assert.AreEqual(2, ObservableAttribute.GetTotalObservationSize(c, false, errors)); Assert.AreEqual(0, errors.Count); } + + class BaseClass + { + [Observable("base")] + protected float m_BaseField; + + [Observable("private")] + float m_PrivateField; + } + + class DerivedClass : BaseClass + { + [Observable("derived")] + float m_DerivedField; + } + + [Test] + public void TestObservableAttributeDeclaredOnly() + { + var d = new DerivedClass(); + + // declaredOnly=false will get fields in the derived class, plus public and protected inherited fields + var sensorAll = ObservableAttribute.GetObservableSensors(d, false); + Assert.AreEqual(2, sensorAll.Count); + // Note - actual order doesn't matter here, we can change this to use a HashSet if neeed. + Assert.AreEqual("derived", sensorAll[0].GetName()); + Assert.AreEqual("base", sensorAll[1].GetName()); + + // declaredOnly=true will only get fields in the derived class + var sensorsDerivedOnly = ObservableAttribute.GetObservableSensors(d, true); + Assert.AreEqual(1, sensorsDerivedOnly.Count); + Assert.AreEqual("derived", sensorsDerivedOnly[0].GetName()); + + var b = new BaseClass(); + var baseSensors = ObservableAttribute.GetObservableSensors(b, false); + Assert.AreEqual(2, baseSensors.Count); + } } } From ce98062ff5d227fe57efc84cbb03c1c4d70a3a9c Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Tue, 12 May 2020 17:59:23 -0700 Subject: [PATCH 09/23] ignore write-only properties --- .../Runtime/Sensors/Reflection/ObservableAttribute.cs | 7 ++++++- .../Tests/Editor/Sensor/ObservableAttributeTests.cs | 11 ++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs index c7335b1fb1..7057094c6d 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs @@ -37,6 +37,7 @@ public ObservableAttribute(string name=null, int numStackedObservations=1) internal static IEnumerable<(FieldInfo, ObservableAttribute)> GetObservableFields(object o, bool declaredOnly) { + // TODO cache these (and properties) by type, so that we only have to reflect once. var bindingFlags = k_BindingFlags | (declaredOnly ? BindingFlags.DeclaredOnly : 0); var fields = o.GetType().GetFields(bindingFlags); foreach (var field in fields) @@ -52,10 +53,14 @@ public ObservableAttribute(string name=null, int numStackedObservations=1) internal static IEnumerable<(PropertyInfo, ObservableAttribute)> GetObservableProperties(object o, bool declaredOnly) { var bindingFlags = k_BindingFlags | (declaredOnly ? BindingFlags.DeclaredOnly : 0); - // TODO check PropertyInfo.CanRead or filter via bindingFlag var properties = o.GetType().GetProperties(bindingFlags); foreach (var prop in properties) { + if (!prop.CanRead) + { + // Ignore write-only properties. + continue; + } var attr = (ObservableAttribute)GetCustomAttribute(prop, typeof(ObservableAttribute)); if (attr != null) { diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs index 31d7395280..5868b69e62 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs @@ -202,13 +202,22 @@ double DoubleProperty get => m_Double; set => m_Double = value; } - // TODO handle a set-only [Observable] property + + float m_WriteOnlyProperty; + + [Observable] + // No get property, so we shouldn't be able to make a sensor out of this. + public float WriteOnlyProperty + { + set => m_WriteOnlyProperty = value; + } } [Test] public void TestGetTotalObservationSizeErrors() { var bad = new BadClass(); + bad.WriteOnlyProperty = 1.0f; var errors = new List(); Assert.AreEqual(0, ObservableAttribute.GetTotalObservationSize(bad, false, errors)); Assert.AreEqual(2, errors.Count); From c203926c5a1e4d64708a97ed85f7a4e464717881 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Tue, 12 May 2020 18:05:58 -0700 Subject: [PATCH 10/23] fix error message --- .../Runtime/Inference/BarracudaModelParamLoader.cs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs index 78b15ce62e..8d8a4bc092 100644 --- a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs +++ b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs @@ -498,7 +498,6 @@ static string CheckVectorObsShape( } } - // TODO account for totalVectorSensorSize in error message totalVectorSensorSize += observableAttributeTotalSize; if (vecObsSizeBp * numStackedVector + totalVectorSensorSize != totalVecObsSizeT) @@ -522,7 +521,9 @@ static string CheckVectorObsShape( sensorSizes += "]"; return $"Vector Observation Size of the model does not match. Was expecting {totalVecObsSizeT} " + - $"but received {vecObsSizeBp} x {numStackedVector} vector observations and " + + $"but received: \n" + + $"Vector observations: {vecObsSizeBp} x {numStackedVector}\n" + + $"Total [Observable] attributes: {observableAttributeTotalSize}\n"+ $"SensorComponent sizes: {sensorSizes}."; } return null; From 51c72721939aaf2dbfbe5474796892591fae112a Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Tue, 12 May 2020 18:44:51 -0700 Subject: [PATCH 11/23] docstrings --- .../Runtime/Policies/BrainParameters.cs | 6 +-- .../Sensors/Reflection/ObservableAttribute.cs | 53 +++++++++++++++---- .../Editor/Sensor/ObservableAttributeTests.cs | 3 +- 3 files changed, 48 insertions(+), 14 deletions(-) diff --git a/com.unity.ml-agents/Runtime/Policies/BrainParameters.cs b/com.unity.ml-agents/Runtime/Policies/BrainParameters.cs index e427ead794..2141cc0e49 100644 --- a/com.unity.ml-agents/Runtime/Policies/BrainParameters.cs +++ b/com.unity.ml-agents/Runtime/Policies/BrainParameters.cs @@ -34,11 +34,9 @@ public enum SpaceType public class BrainParameters { /// - /// The size of the observation space. - /// - /// An agent creates the observation vector in its + /// The number of the observations that are added in /// - /// implementation. + /// /// /// The length of the vector containing observation values. /// diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs index 7057094c6d..0d0dd5de98 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs @@ -5,6 +5,44 @@ namespace Unity.MLAgents.Sensors.Reflection { + /// + /// Specify that a field or property should be used to generate observations for an Agent. + /// For each field or property that uses ObservableAttribute, a corresponding + /// will be created during Agent initialization, and this + /// sensor will read the values during training and inference. + /// + /// + /// ObservableAttribute is intended to make initial setup of an Agent easier. Because it + /// uses reflection to read the values of fields and properties at runtime, this may + /// be much slower than reading the values directly. If the performance of + /// ObservableAttribute is an issue, you can get the same functionality by overriding + /// or creating a custom + /// implementation to read the values without reflection. + /// + /// Note that you do not need to adjust the VectorObservationSize in + /// when adding ObservableAttribute + /// to fields or properties. + /// + /// + /// This sample class will produce two observations, one for the m_Health field, and one + /// for the HealthPercent property. + /// + /// using Unity.MLAgents; + /// using Unity.MLAgents.Sensors.Reflection; + /// + /// public class MyAgent : Agent + /// { + /// [Observable] + /// int m_Health; + /// + /// [Observable] + /// float HealthPercent + /// { + /// get => return 100.0f * m_Health / float(m_MaxHealth); + /// } + /// } + /// + /// [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)] public class ObservableAttribute : Attribute { @@ -29,7 +67,7 @@ public class ObservableAttribute : Attribute /// Optional override for the sensor name. Note that all sensors for an Agent /// must have a unique name. /// Number of frames to concatenate observations from. - public ObservableAttribute(string name=null, int numStackedObservations=1) + public ObservableAttribute(string name = null, int numStackedObservations = 1) { m_Name = name; m_NumStackedObservations = numStackedObservations; @@ -72,12 +110,12 @@ public ObservableAttribute(string name=null, int numStackedObservations=1) internal static List GetObservableSensors(object o, bool declaredOnly) { var sensorsOut = new List(); - foreach (var (field, attr) in GetObservableFields(o, declaredOnly)) + foreach (var(field, attr) in GetObservableFields(o, declaredOnly)) { sensorsOut.Add(CreateReflectionSensor(o, field, null, attr)); } - foreach (var (prop, attr) in GetObservableProperties(o, declaredOnly)) + foreach (var(prop, attr) in GetObservableProperties(o, declaredOnly)) { sensorsOut.Add(CreateReflectionSensor(o, null, prop, attr)); } @@ -129,7 +167,7 @@ internal static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, Pr } if (memberType == typeof(float)) { - sensor = new FloatReflectionSensor(reflectionSensorInfo); + sensor = new FloatReflectionSensor(reflectionSensorInfo); } if (memberType == typeof(bool)) { @@ -169,7 +207,7 @@ internal static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, Pr internal static int GetTotalObservationSize(object o, bool declaredOnly, List errorsOut) { int sizeOut = 0; - foreach (var (field, attr) in GetObservableFields(o, declaredOnly)) + foreach (var(field, attr) in GetObservableFields(o, declaredOnly)) { if (s_TypeSizes.ContainsKey(field.FieldType)) { @@ -181,9 +219,8 @@ internal static int GetTotalObservationSize(object o, bool declaredOnly, List Date: Wed, 13 May 2020 10:17:23 -0700 Subject: [PATCH 12/23] agent enum (WIP) --- com.unity.ml-agents/Runtime/Agent.cs | 27 +++++++++++++++++++ .../Communicator/GrpcExtensionsTests.cs.meta | 11 ++++++++ .../Tests/Editor/DemonstrationTests.cs | 2 +- ...ditModeTestInternalBrainTensorGenerator.cs | 17 ++++++------ .../Tests/Editor/MLAgentsEditModeTest.cs | 9 +++++-- 5 files changed, 55 insertions(+), 11 deletions(-) create mode 100644 com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs.meta diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index b166de585a..c0279d5377 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -220,6 +220,33 @@ internal struct AgentParameters [FormerlySerializedAs("maxStep")] [HideInInspector] public int MaxStep; + public enum ObservableAttributeBehavior + { + /// + /// All ObservableAttributes on the Agent will be ignored. If there are no + /// ObservableAttributes on the Agent, this will result in the fastest + /// initialization time. + /// + Ignore, + + /// + /// Only members on the declared class will be examined; members that are + /// inherited are ignored. This is the default behavior, and a reasonable + /// tradeoff between performance and flexibility. + /// + /// This corresponds to setting the + /// [BindingFlags.DeclaredOnly](https://docs.microsoft.com/en-us/dotnet/api/system.reflection.bindingflags?view=netcore-3.1) + /// when examining the fields and properties of the Agent class instance. + /// + SkipInherited, + + /// + /// All members on the class will be examined. This can lead to slower + /// startup times + /// + ExamineAll + } + /// Current Agent information (message sent to Brain). AgentInfo m_Info; diff --git a/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs.meta new file mode 100644 index 0000000000..73dd8b25ac --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: e5e4df2934c014aa3b835b9eb9ad20b3 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs b/com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs index ef0e760471..f55ad7a5b6 100644 --- a/com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs +++ b/com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs @@ -136,7 +136,7 @@ public void TestAgentWrite() BrainParametersProto.Parser.ParseDelimitedFrom(reader); var agentInfoProto = AgentInfoActionPairProto.Parser.ParseDelimitedFrom(reader).AgentInfo; - var obs = agentInfoProto.Observations[2]; // skip dummy sensors + var obs = agentInfoProto.Observations[3]; // skip dummy sensors { var vecObs = obs.FloatData.Data; Assert.AreEqual(bpA.BrainParameters.VectorObservationSize, vecObs.Count); diff --git a/com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs b/com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs index c5434b842f..3b33ffc5a3 100644 --- a/com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs +++ b/com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs @@ -100,15 +100,16 @@ public void GenerateVectorObservation() { var inputTensor = new TensorProxy { - shape = new long[] { 2, 3 } + shape = new long[] { 2, 4 } }; const int batchSize = 4; var agentInfos = GetFakeAgents(); var alloc = new TensorCachingAllocator(); var generator = new VectorObservationGenerator(alloc); - generator.AddSensorIndex(0); - generator.AddSensorIndex(1); - generator.AddSensorIndex(2); + generator.AddSensorIndex(0); // ObservableAttribute (size 1) + generator.AddSensorIndex(1); // TestSensor (size 0) + generator.AddSensorIndex(2); // TestSensor (size 0) + generator.AddSensorIndex(3); // VectorSensor (size 3) var agent0 = agentInfos[0]; var agent1 = agentInfos[1]; var inputs = new List @@ -118,10 +119,10 @@ public void GenerateVectorObservation() }; generator.Generate(inputTensor, batchSize, inputs); Assert.IsNotNull(inputTensor.data); - Assert.AreEqual(inputTensor.data[0, 0], 1); - Assert.AreEqual(inputTensor.data[0, 2], 3); - Assert.AreEqual(inputTensor.data[1, 0], 4); - Assert.AreEqual(inputTensor.data[1, 2], 6); + Assert.AreEqual(inputTensor.data[0, 1], 1); + Assert.AreEqual(inputTensor.data[0, 3], 3); + Assert.AreEqual(inputTensor.data[1, 1], 4); + Assert.AreEqual(inputTensor.data[1, 3], 6); alloc.Dispose(); } diff --git a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs index 3b3e6ca1d4..d75bb765fe 100644 --- a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs +++ b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs @@ -4,6 +4,7 @@ using System.Reflection; using System.Collections.Generic; using Unity.MLAgents.Sensors; +using Unity.MLAgents.Sensors.Reflection; using Unity.MLAgents.Policies; using Unity.MLAgents.SideChannels; @@ -61,6 +62,9 @@ internal IPolicy GetPolicy() public TestSensor sensor1; public TestSensor sensor2; + [Observable("observableFloat")] + public float observableFloat; + public override void Initialize() { initializeAgentCalls += 1; @@ -271,8 +275,9 @@ public void TestAgent() Assert.AreEqual(0, agent2.agentActionCalls); // Make sure the Sensors were sorted - Assert.AreEqual(agent1.sensors[0].GetName(), "testsensor1"); - Assert.AreEqual(agent1.sensors[1].GetName(), "testsensor2"); + Assert.AreEqual(agent1.sensors[0].GetName(), "observableFloat"); + Assert.AreEqual(agent1.sensors[1].GetName(), "testsensor1"); + Assert.AreEqual(agent1.sensors[2].GetName(), "testsensor2"); } } From d53756a0bab30781ef9a56be9114d82fbf44e8aa Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Wed, 13 May 2020 13:01:46 -0700 Subject: [PATCH 13/23] agent enum and unit tests --- com.unity.ml-agents/Editor/AgentEditor.cs | 3 + com.unity.ml-agents/Runtime/Agent.cs | 29 ++++++++-- .../Tests/Editor/DemonstrationTests.cs | 2 +- .../Tests/Editor/MLAgentsEditModeTest.cs | 55 +++++++++++++++++++ 4 files changed, 83 insertions(+), 6 deletions(-) diff --git a/com.unity.ml-agents/Editor/AgentEditor.cs b/com.unity.ml-agents/Editor/AgentEditor.cs index fdebf1cb15..122f841b74 100644 --- a/com.unity.ml-agents/Editor/AgentEditor.cs +++ b/com.unity.ml-agents/Editor/AgentEditor.cs @@ -22,6 +22,9 @@ public override void OnInspectorGUI() new GUIContent("Max Step", "The per-agent maximum number of steps.") ); + var observableAttributeBehavior = serializedAgent.FindProperty("m_observableAttributeBehavior"); + EditorGUILayout.PropertyField(observableAttributeBehavior); + serializedAgent.ApplyModifiedProperties(); EditorGUILayout.LabelField("", GUI.skin.horizontalSlider); diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index c0279d5377..0628c6b92c 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -220,7 +220,10 @@ internal struct AgentParameters [FormerlySerializedAs("maxStep")] [HideInInspector] public int MaxStep; - public enum ObservableAttributeBehavior + /// + /// Options for controlling how the Agent class is searched for s. + /// + public enum ObservableAttributeHandling { /// /// All ObservableAttributes on the Agent will be ignored. If there are no @@ -247,6 +250,19 @@ public enum ObservableAttributeBehavior ExamineAll } + [HideInInspector, SerializeField] + ObservableAttributeHandling m_observableAttributeBehavior = ObservableAttributeHandling.SkipInherited; + + /// + /// Determines how the Agent class is searched for s. + /// + public ObservableAttributeHandling ObservableAttributeBehavior + { + get { return m_observableAttributeBehavior; } + set { m_observableAttributeBehavior = value; } + } + + /// Current Agent information (message sent to Brain). AgentInfo m_Info; @@ -849,11 +865,14 @@ public virtual void Heuristic(float[] actionsOut) /// internal void InitializeSensors() { - using (TimerStack.Instance.Scoped("GetObservableSensors")) + if (ObservableAttributeBehavior != ObservableAttributeHandling.Ignore) { - // TODO enum for whether or not to search full hierarchy - var observableSensors = ObservableAttribute.GetObservableSensors(this, false); - sensors.AddRange(observableSensors); + var declaredOnly = (ObservableAttributeBehavior == ObservableAttributeHandling.SkipInherited); + using (TimerStack.Instance.Scoped("GetObservableSensors")) + { + var observableSensors = ObservableAttribute.GetObservableSensors(this, declaredOnly); + sensors.AddRange(observableSensors); + } } // Get all attached sensor components diff --git a/com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs b/com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs index f55ad7a5b6..ef0e760471 100644 --- a/com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs +++ b/com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs @@ -136,7 +136,7 @@ public void TestAgentWrite() BrainParametersProto.Parser.ParseDelimitedFrom(reader); var agentInfoProto = AgentInfoActionPairProto.Parser.ParseDelimitedFrom(reader).AgentInfo; - var obs = agentInfoProto.Observations[3]; // skip dummy sensors + var obs = agentInfoProto.Observations[2]; // skip dummy sensors { var vecObs = obs.FloatData.Data; Assert.AreEqual(bpA.BrainParameters.VectorObservationSize, vecObs.Count); diff --git a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs index d75bb765fe..7fb0f9c1e0 100644 --- a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs +++ b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs @@ -250,9 +250,11 @@ public void TestAgent() var agentGo1 = new GameObject("TestAgent"); agentGo1.AddComponent(); var agent1 = agentGo1.GetComponent(); + var agentGo2 = new GameObject("TestAgent"); agentGo2.AddComponent(); var agent2 = agentGo2.GetComponent(); + agent2.ObservableAttributeBehavior = Agent.ObservableAttributeHandling.Ignore; Assert.AreEqual(0, agent1.agentOnEpisodeBeginCalls); Assert.AreEqual(0, agent2.agentOnEpisodeBeginCalls); @@ -278,6 +280,10 @@ public void TestAgent() Assert.AreEqual(agent1.sensors[0].GetName(), "observableFloat"); Assert.AreEqual(agent1.sensors[1].GetName(), "testsensor1"); Assert.AreEqual(agent1.sensors[2].GetName(), "testsensor2"); + + // agent2 should only have two sensors (no observableFloat) + Assert.AreEqual(agent2.sensors[0].GetName(), "testsensor1"); + Assert.AreEqual(agent2.sensors[1].GetName(), "testsensor2"); } } @@ -746,4 +752,53 @@ public void TestAgentDontCallBaseOnEnable() _InnerAgentTestOnEnableOverride(); } } + + [TestFixture] + public class ObservableAttributeBehaviorTests + { + public class BaseObservableAgent : Agent + { + [Observable] + public float BaseField; + } + + public class DerivedObservableAgent : BaseObservableAgent + { + [Observable] + public float DerivedField; + } + + + [Test] + public void TestObservableAttributeBehaviorIgnore() + { + var variants = new[] + { + // No observables found + (Agent.ObservableAttributeHandling.Ignore, 0), + // Only DerivedField found + (Agent.ObservableAttributeHandling.SkipInherited, 1), + // DerivedField and BaseField found + (Agent.ObservableAttributeHandling.ExamineAll, 2) + }; + + foreach (var (behavior, expectedNumSensors) in variants) + { + var go = new GameObject(); + var agent = go.AddComponent(); + agent.ObservableAttributeBehavior = behavior; + agent.LazyInitialize(); + int numAttributeSensors = 0; + foreach (var sensor in agent.sensors) + { + if (sensor.GetType() != typeof(VectorSensor)) + { + numAttributeSensors++; + } + } + Assert.AreEqual(expectedNumSensors, numAttributeSensors); + } + + } + } } From 06239c9f35d678467e01cf3d145fd1eb1ba56ff5 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Wed, 13 May 2020 13:10:35 -0700 Subject: [PATCH 14/23] fix comment --- .../Runtime/Inference/BarracudaModelParamLoader.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs index 8d8a4bc092..cd63c51bf6 100644 --- a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs +++ b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs @@ -538,7 +538,7 @@ static string CheckVectorObsShape( /// /// The tensor that is expected by the model /// Array of attached sensor components (unused). - /// /// Sum of the sizes of all ObservableAttributes (unused). + /// Sum of the sizes of all ObservableAttributes (unused). /// If the Check failed, returns a string containing information about why the /// check failed. If the check passed, returns null. static string CheckPreviousActionShape( From 8189e1d9c4769d4722482db89b2d2f70eaa49aa2 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Wed, 13 May 2020 13:16:39 -0700 Subject: [PATCH 15/23] cleanup TODO --- .../Runtime/Sensors/Reflection/ReflectionSensorBase.cs | 2 -- 1 file changed, 2 deletions(-) diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs index 6419735087..10dbad0f18 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs @@ -27,8 +27,6 @@ internal abstract class ReflectionSensorBase : ISensor public ReflectionSensorBase(ReflectionSensorInfo reflectionSensorInfo, int size) { - // TODO 2 constructors? - m_Object = reflectionSensorInfo.Object; m_FieldInfo = reflectionSensorInfo.FieldInfo; m_PropertyInfo = reflectionSensorInfo.PropertyInfo; From 19cf075d00d374f657fbabbbdcf8df811e4cbd05 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Thu, 14 May 2020 10:53:38 -0700 Subject: [PATCH 16/23] ignore by default, rename declaredOnly param, docstrings --- .../Editor/BehaviorParametersEditor.cs | 2 +- com.unity.ml-agents/Runtime/Agent.cs | 16 ++--- .../Inference/BarracudaModelParamLoader.cs | 2 +- .../Reflection/BoolReflectionSensor.cs | 4 ++ .../Reflection/FloatReflectionSensor.cs | 4 ++ .../Sensors/Reflection/IntReflectionSensor.cs | 4 ++ .../Sensors/Reflection/ObservableAttribute.cs | 64 +++++++++++++++---- .../Reflection/QuaternionReflectionSensor.cs | 4 ++ .../Reflection/ReflectionSensorBase.cs | 21 ++++-- .../Reflection/Vector2ReflectionSensor.cs | 4 ++ .../Reflection/Vector3ReflectionSensor.cs | 4 ++ .../Reflection/Vector4ReflectionSensor.cs | 4 ++ .../Tests/Editor/MLAgentsEditModeTest.cs | 11 ++-- .../Editor/Sensor/ObservableAttributeTests.cs | 16 ++--- 14 files changed, 121 insertions(+), 39 deletions(-) diff --git a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs index c6edb45187..a3840392fe 100644 --- a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs +++ b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs @@ -104,7 +104,7 @@ void DisplayFailedModelChecks() int observableSensorSizes = 0; var agent = behaviorParameters.GetComponent(); - if (agent != null) + if (agent != null && agent.ObservableAttributeBehavior != Agent.ObservableAttributeOptions.Ignore) { List observableErrors = new List(); observableSensorSizes = ObservableAttribute.GetTotalObservationSize(agent, false, observableErrors); diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 0628c6b92c..972cab5f3e 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -223,7 +223,7 @@ internal struct AgentParameters /// /// Options for controlling how the Agent class is searched for s. /// - public enum ObservableAttributeHandling + public enum ObservableAttributeOptions { /// /// All ObservableAttributes on the Agent will be ignored. If there are no @@ -241,7 +241,7 @@ public enum ObservableAttributeHandling /// [BindingFlags.DeclaredOnly](https://docs.microsoft.com/en-us/dotnet/api/system.reflection.bindingflags?view=netcore-3.1) /// when examining the fields and properties of the Agent class instance. /// - SkipInherited, + ExcludeInherited, /// /// All members on the class will be examined. This can lead to slower @@ -251,12 +251,12 @@ public enum ObservableAttributeHandling } [HideInInspector, SerializeField] - ObservableAttributeHandling m_observableAttributeBehavior = ObservableAttributeHandling.SkipInherited; + ObservableAttributeOptions m_observableAttributeBehavior = ObservableAttributeOptions.Ignore; /// /// Determines how the Agent class is searched for s. /// - public ObservableAttributeHandling ObservableAttributeBehavior + public ObservableAttributeOptions ObservableAttributeBehavior { get { return m_observableAttributeBehavior; } set { m_observableAttributeBehavior = value; } @@ -865,12 +865,12 @@ public virtual void Heuristic(float[] actionsOut) /// internal void InitializeSensors() { - if (ObservableAttributeBehavior != ObservableAttributeHandling.Ignore) + if (ObservableAttributeBehavior != ObservableAttributeOptions.Ignore) { - var declaredOnly = (ObservableAttributeBehavior == ObservableAttributeHandling.SkipInherited); - using (TimerStack.Instance.Scoped("GetObservableSensors")) + var excludeInherited = (ObservableAttributeBehavior == ObservableAttributeOptions.ExcludeInherited); + using (TimerStack.Instance.Scoped("CreateObservableSensors")) { - var observableSensors = ObservableAttribute.GetObservableSensors(this, declaredOnly); + var observableSensors = ObservableAttribute.CreateObservableSensors(this, excludeInherited); sensors.AddRange(observableSensors); } } diff --git a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs index cd63c51bf6..41efa93653 100644 --- a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs +++ b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs @@ -523,7 +523,7 @@ static string CheckVectorObsShape( return $"Vector Observation Size of the model does not match. Was expecting {totalVecObsSizeT} " + $"but received: \n" + $"Vector observations: {vecObsSizeBp} x {numStackedVector}\n" + - $"Total [Observable] attributes: {observableAttributeTotalSize}\n"+ + $"Total [Observable] attributes: {observableAttributeTotalSize}\n" + $"SensorComponent sizes: {sensorSizes}."; } return null; diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs index b71ebf4aa6..0bd8e60b88 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs @@ -1,5 +1,9 @@ namespace Unity.MLAgents.Sensors.Reflection { + /// + /// Sensor that wraps a boolean field or property of an object, and returns + /// that as an observation. + /// internal class BoolReflectionSensor : ReflectionSensorBase { internal BoolReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs index d34a7faeb3..47daa282d3 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs @@ -1,5 +1,9 @@ namespace Unity.MLAgents.Sensors.Reflection { + /// + /// Sensor that wraps a float field or property of an object, and returns + /// that as an observation. + /// internal class FloatReflectionSensor : ReflectionSensorBase { internal FloatReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs index 748762cbba..645341dc19 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs @@ -1,5 +1,9 @@ namespace Unity.MLAgents.Sensors.Reflection { + /// + /// Sensor that wraps an integer field or property of an object, and returns + /// that as an observation. + /// internal class IntReflectionSensor : ReflectionSensorBase { internal IntReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs index 0d0dd5de98..caeaf40ca8 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs @@ -49,7 +49,14 @@ public class ObservableAttribute : Attribute string m_Name; int m_NumStackedObservations; + /// + /// Default binding flags used for reflection of members and properties. + /// const BindingFlags k_BindingFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic; + + /// + /// Supported types and their observation sizes. + /// static Dictionary s_TypeSizes = new Dictionary() { {typeof(int), 1}, @@ -73,10 +80,16 @@ public ObservableAttribute(string name = null, int numStackedObservations = 1) m_NumStackedObservations = numStackedObservations; } - internal static IEnumerable<(FieldInfo, ObservableAttribute)> GetObservableFields(object o, bool declaredOnly) + /// + /// Returns a FieldInfo for all fields that have an ObservableAttribute + /// + /// Object being reflected + /// Whether to exclude inherited properties or not. + /// + static IEnumerable<(FieldInfo, ObservableAttribute)> GetObservableFields(object o, bool excludeInherited) { // TODO cache these (and properties) by type, so that we only have to reflect once. - var bindingFlags = k_BindingFlags | (declaredOnly ? BindingFlags.DeclaredOnly : 0); + var bindingFlags = k_BindingFlags | (excludeInherited ? BindingFlags.DeclaredOnly : 0); var fields = o.GetType().GetFields(bindingFlags); foreach (var field in fields) { @@ -88,9 +101,15 @@ public ObservableAttribute(string name = null, int numStackedObservations = 1) } } - internal static IEnumerable<(PropertyInfo, ObservableAttribute)> GetObservableProperties(object o, bool declaredOnly) + /// + /// Returns a PropertyInfo for all fields that have an ObservableAttribute + /// + /// Object being reflected + /// Whether to exclude inherited properties or not. + /// + static IEnumerable<(PropertyInfo, ObservableAttribute)> GetObservableProperties(object o, bool excludeInherited) { - var bindingFlags = k_BindingFlags | (declaredOnly ? BindingFlags.DeclaredOnly : 0); + var bindingFlags = k_BindingFlags | (excludeInherited ? BindingFlags.DeclaredOnly : 0); var properties = o.GetType().GetProperties(bindingFlags); foreach (var prop in properties) { @@ -107,15 +126,21 @@ public ObservableAttribute(string name = null, int numStackedObservations = 1) } } - internal static List GetObservableSensors(object o, bool declaredOnly) + /// + /// Creates sensors for each field and property with ObservableAttribute. + /// + /// Object being reflected + /// Whether to exclude inherited properties or not. + /// + internal static List CreateObservableSensors(object o, bool excludeInherited) { var sensorsOut = new List(); - foreach (var(field, attr) in GetObservableFields(o, declaredOnly)) + foreach (var(field, attr) in GetObservableFields(o, excludeInherited)) { sensorsOut.Add(CreateReflectionSensor(o, field, null, attr)); } - foreach (var(prop, attr) in GetObservableProperties(o, declaredOnly)) + foreach (var(prop, attr) in GetObservableProperties(o, excludeInherited)) { sensorsOut.Add(CreateReflectionSensor(o, null, prop, attr)); } @@ -123,7 +148,16 @@ internal static List GetObservableSensors(object o, bool declaredOnly) return sensorsOut; } - internal static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, PropertyInfo propertyInfo, ObservableAttribute observableAttribute) + /// + /// Create the ISensor for either the field or property on the provided object. + /// + /// + /// + /// + /// + /// + /// + static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, PropertyInfo propertyInfo, ObservableAttribute observableAttribute) { string memberName; string declaringTypeName; @@ -204,10 +238,18 @@ internal static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, Pr return sensor; } - internal static int GetTotalObservationSize(object o, bool declaredOnly, List errorsOut) + /// + /// Gets the sum of the observation sizes of the Observable fields and properties on an object. + /// Also appends errors to the errorsOut array. + /// + /// + /// + /// + /// + internal static int GetTotalObservationSize(object o, bool excludeInherited, List errorsOut) { int sizeOut = 0; - foreach (var(field, attr) in GetObservableFields(o, declaredOnly)) + foreach (var(field, attr) in GetObservableFields(o, excludeInherited)) { if (s_TypeSizes.ContainsKey(field.FieldType)) { @@ -219,7 +261,7 @@ internal static int GetTotalObservationSize(object o, bool declaredOnly, List + /// Sensor that wraps a quaternion field or property of an object, and returns + /// that as an observation. + /// internal class QuaternionReflectionSensor : ReflectionSensorBase { internal QuaternionReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs index 10dbad0f18..410a89b2bf 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs @@ -2,7 +2,9 @@ namespace Unity.MLAgents.Sensors.Reflection { - // Construction info for a ReflectionSensorBase. + /// + /// Construction info for a ReflectionSensorBase. + /// internal struct ReflectionSensorInfo { public object Object; @@ -13,15 +15,21 @@ internal struct ReflectionSensorInfo public string SensorName; } + /// + /// Abstract base class for reflection-based sensors. + /// internal abstract class ReflectionSensorBase : ISensor { protected object m_Object; + // Exactly one of m_FieldInfo and m_PropertyInfo should be non-null. protected FieldInfo m_FieldInfo; protected PropertyInfo m_PropertyInfo; + // Not currently used, but might want later. protected ObservableAttribute m_ObservableAttribute; + // Cached sensor names and shapes. string m_SensorName; int[] m_Shape; @@ -32,7 +40,7 @@ public ReflectionSensorBase(ReflectionSensorInfo reflectionSensorInfo, int size) m_PropertyInfo = reflectionSensorInfo.PropertyInfo; m_ObservableAttribute = reflectionSensorInfo.ObservableAttribute; m_SensorName = reflectionSensorInfo.SensorName; - m_Shape = new [] {size}; + m_Shape = new[] {size}; } /// @@ -50,6 +58,11 @@ public int Write(ObservationWriter writer) internal abstract void WriteReflectedField(ObservationWriter writer); + /// + /// Get either the reflected field, or return the reflected property. + /// This should be used by implementations in their WriteReflectedField() method. + /// + /// protected object GetReflectedValue() { return m_FieldInfo != null ? @@ -64,10 +77,10 @@ public byte[] GetCompressedObservation() } /// - public void Update() { } + public void Update() {} /// - public void Reset() { } + public void Reset() {} /// public SensorCompressionType GetCompressionType() diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs index b5a142ee4d..5523c89ba4 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs @@ -1,5 +1,9 @@ namespace Unity.MLAgents.Sensors.Reflection { + /// + /// Sensor that wraps a Vector2 field or property of an object, and returns + /// that as an observation. + /// internal class Vector2ReflectionSensor : ReflectionSensorBase { internal Vector2ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs index 1823277eec..d7268084c8 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs @@ -1,5 +1,9 @@ namespace Unity.MLAgents.Sensors.Reflection { + /// + /// Sensor that wraps a Vector3 field or property of an object, and returns + /// that as an observation. + /// internal class Vector3ReflectionSensor : ReflectionSensorBase { internal Vector3ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs index f5adb657c3..4994d4dbbb 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs @@ -1,5 +1,9 @@ namespace Unity.MLAgents.Sensors.Reflection { + /// + /// Sensor that wraps a Vector4 field or property of an object, and returns + /// that as an observation. + /// internal class Vector4ReflectionSensor : ReflectionSensorBase { internal Vector4ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) diff --git a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs index 7fb0f9c1e0..69cfc2b06f 100644 --- a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs +++ b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs @@ -254,7 +254,7 @@ public void TestAgent() var agentGo2 = new GameObject("TestAgent"); agentGo2.AddComponent(); var agent2 = agentGo2.GetComponent(); - agent2.ObservableAttributeBehavior = Agent.ObservableAttributeHandling.Ignore; + agent2.ObservableAttributeBehavior = Agent.ObservableAttributeOptions.Ignore; Assert.AreEqual(0, agent1.agentOnEpisodeBeginCalls); Assert.AreEqual(0, agent2.agentOnEpisodeBeginCalls); @@ -775,14 +775,14 @@ public void TestObservableAttributeBehaviorIgnore() var variants = new[] { // No observables found - (Agent.ObservableAttributeHandling.Ignore, 0), + (Agent.ObservableAttributeOptions.Ignore, 0), // Only DerivedField found - (Agent.ObservableAttributeHandling.SkipInherited, 1), + (SkipInherited: Agent.ObservableAttributeOptions.ExcludeInherited, 1), // DerivedField and BaseField found - (Agent.ObservableAttributeHandling.ExamineAll, 2) + (Agent.ObservableAttributeOptions.ExamineAll, 2) }; - foreach (var (behavior, expectedNumSensors) in variants) + foreach (var(behavior, expectedNumSensors) in variants) { var go = new GameObject(); var agent = go.AddComponent(); @@ -798,7 +798,6 @@ public void TestObservableAttributeBehaviorIgnore() } Assert.AreEqual(expectedNumSensors, numAttributeSensors); } - } } } diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs index 74b4685b6f..3ebf98f8a5 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs @@ -151,7 +151,7 @@ public void TestGetObservableSensors() testClass.m_QuaternionMember = new Quaternion(5.0f, 5.1f, 5.2f, 5.3f); testClass.QuaternionProperty = new Quaternion(5.4f, 5.5f, 5.5f, 5.7f); - var sensors = ObservableAttribute.GetObservableSensors(testClass, false); + var sensors = ObservableAttribute.CreateObservableSensors(testClass, false); var sensorsByName = new Dictionary(); foreach (var sensor in sensors) @@ -234,7 +234,7 @@ public void TestObservableAttributeStacking() { var c = new StackingClass(); c.FloatVal = 1.0f; - var sensors = ObservableAttribute.GetObservableSensors(c, false); + var sensors = ObservableAttribute.CreateObservableSensors(c, false); var sensor = sensors[0]; Assert.AreEqual(typeof(StackingSensor), sensor.GetType()); SensorTestHelper.CompareObservation(sensor, new[] { 0.0f, 1.0f }); @@ -264,25 +264,25 @@ class DerivedClass : BaseClass } [Test] - public void TestObservableAttributeDeclaredOnly() + public void TestObservableAttributeExcludeInherited() { var d = new DerivedClass(); d.m_BaseField = 1.0f; - // declaredOnly=false will get fields in the derived class, plus public and protected inherited fields - var sensorAll = ObservableAttribute.GetObservableSensors(d, false); + // excludeInherited=false will get fields in the derived class, plus public and protected inherited fields + var sensorAll = ObservableAttribute.CreateObservableSensors(d, false); Assert.AreEqual(2, sensorAll.Count); // Note - actual order doesn't matter here, we can change this to use a HashSet if neeed. Assert.AreEqual("derived", sensorAll[0].GetName()); Assert.AreEqual("base", sensorAll[1].GetName()); - // declaredOnly=true will only get fields in the derived class - var sensorsDerivedOnly = ObservableAttribute.GetObservableSensors(d, true); + // excludeInherited=true will only get fields in the derived class + var sensorsDerivedOnly = ObservableAttribute.CreateObservableSensors(d, true); Assert.AreEqual(1, sensorsDerivedOnly.Count); Assert.AreEqual("derived", sensorsDerivedOnly[0].GetName()); var b = new BaseClass(); - var baseSensors = ObservableAttribute.GetObservableSensors(b, false); + var baseSensors = ObservableAttribute.CreateObservableSensors(b, false); Assert.AreEqual(2, baseSensors.Count); } } From d0def071741a0e96818d6f685acfa59d67c04a6f Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Thu, 14 May 2020 11:13:06 -0700 Subject: [PATCH 17/23] fix tests --- com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs index 69cfc2b06f..562be4b4e0 100644 --- a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs +++ b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs @@ -30,6 +30,11 @@ public void Dispose() {} public class TestAgent : Agent { + public TestAgent() + { + ObservableAttributeBehavior = ObservableAttributeOptions.ExcludeInherited; + } + internal AgentInfo _Info { get @@ -777,7 +782,7 @@ public void TestObservableAttributeBehaviorIgnore() // No observables found (Agent.ObservableAttributeOptions.Ignore, 0), // Only DerivedField found - (SkipInherited: Agent.ObservableAttributeOptions.ExcludeInherited, 1), + (Agent.ObservableAttributeOptions.ExcludeInherited, 1), // DerivedField and BaseField found (Agent.ObservableAttributeOptions.ExamineAll, 2) }; From 580c2075137d2afe17948e72d10787c38704f27f Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Thu, 14 May 2020 11:40:52 -0700 Subject: [PATCH 18/23] rename, cleanup, revert FoodCollector --- .../Prefabs/FoodCollectorArea.prefab | 70 +++++++++++-------- .../Scripts/FoodCollectorAgent.cs | 21 +++--- com.unity.ml-agents/Editor/AgentEditor.cs | 2 +- .../Editor/BehaviorParametersEditor.cs | 2 +- com.unity.ml-agents/Runtime/Agent.cs | 12 ++-- .../Sensors/Reflection/IntReflectionSensor.cs | 2 +- .../Tests/Editor/MLAgentsEditModeTest.cs | 6 +- 7 files changed, 61 insertions(+), 54 deletions(-) diff --git a/Project/Assets/ML-Agents/Examples/FoodCollector/Prefabs/FoodCollectorArea.prefab b/Project/Assets/ML-Agents/Examples/FoodCollector/Prefabs/FoodCollectorArea.prefab index a02f2f4f5a..34bf31e657 100644 --- a/Project/Assets/ML-Agents/Examples/FoodCollector/Prefabs/FoodCollectorArea.prefab +++ b/Project/Assets/ML-Agents/Examples/FoodCollector/Prefabs/FoodCollectorArea.prefab @@ -2178,11 +2178,11 @@ MonoBehaviour: m_Name: m_EditorClassIdentifier: m_BrainParameters: - VectorObservationSize: 0 - NumStackedVectorObservations: 1 - VectorActionSize: 03000000030000000300000002000000 - VectorActionDescriptions: [] - VectorActionSpaceType: 0 + vectorObservationSize: 4 + numStackedVectorObservations: 1 + vectorActionSize: 03000000030000000300000002000000 + vectorActionDescriptions: [] + vectorActionSpaceType: 0 m_Model: {fileID: 11400000, guid: 36ab3e93020504f48858d0856f939685, type: 3} m_InferenceDevice: 0 m_BehaviorType: 0 @@ -2204,7 +2204,7 @@ MonoBehaviour: agentParameters: maxStep: 0 hasUpgradedFromAgentParameters: 1 - MaxStep: 5000 + maxStep: 5000 area: {fileID: 1819751139121548} turnSpeed: 300 moveSpeed: 2 @@ -2214,6 +2214,7 @@ MonoBehaviour: frozenMaterial: {fileID: 2100000, guid: 66163cf35956a4be08e801b750c26f33, type: 2} myLaser: {fileID: 1081721624670010} contribute: 0 + useVectorObs: 1 --- !u!114 &114725457980523372 MonoBehaviour: m_ObjectHideFlags: 0 @@ -2259,6 +2260,7 @@ MonoBehaviour: m_EditorClassIdentifier: DecisionPeriod: 5 TakeActionsBetweenDecisions: 1 + offsetStep: 0 --- !u!114 &1222199865870203693 MonoBehaviour: m_ObjectHideFlags: 0 @@ -2515,11 +2517,11 @@ MonoBehaviour: m_Name: m_EditorClassIdentifier: m_BrainParameters: - VectorObservationSize: 0 - NumStackedVectorObservations: 1 - VectorActionSize: 03000000030000000300000002000000 - VectorActionDescriptions: [] - VectorActionSpaceType: 0 + vectorObservationSize: 4 + numStackedVectorObservations: 1 + vectorActionSize: 03000000030000000300000002000000 + vectorActionDescriptions: [] + vectorActionSpaceType: 0 m_Model: {fileID: 11400000, guid: 36ab3e93020504f48858d0856f939685, type: 3} m_InferenceDevice: 0 m_BehaviorType: 0 @@ -2541,7 +2543,7 @@ MonoBehaviour: agentParameters: maxStep: 0 hasUpgradedFromAgentParameters: 1 - MaxStep: 5000 + maxStep: 5000 area: {fileID: 1819751139121548} turnSpeed: 300 moveSpeed: 2 @@ -2551,6 +2553,7 @@ MonoBehaviour: frozenMaterial: {fileID: 2100000, guid: 66163cf35956a4be08e801b750c26f33, type: 2} myLaser: {fileID: 1941433838307300} contribute: 0 + useVectorObs: 1 --- !u!114 &114443152683847924 MonoBehaviour: m_ObjectHideFlags: 0 @@ -2596,6 +2599,7 @@ MonoBehaviour: m_EditorClassIdentifier: DecisionPeriod: 5 TakeActionsBetweenDecisions: 1 + offsetStep: 0 --- !u!1 &1528397385587768 GameObject: m_ObjectHideFlags: 0 @@ -2844,11 +2848,11 @@ MonoBehaviour: m_Name: m_EditorClassIdentifier: m_BrainParameters: - VectorObservationSize: 0 - NumStackedVectorObservations: 1 - VectorActionSize: 03000000030000000300000002000000 - VectorActionDescriptions: [] - VectorActionSpaceType: 0 + vectorObservationSize: 4 + numStackedVectorObservations: 1 + vectorActionSize: 03000000030000000300000002000000 + vectorActionDescriptions: [] + vectorActionSpaceType: 0 m_Model: {fileID: 11400000, guid: 36ab3e93020504f48858d0856f939685, type: 3} m_InferenceDevice: 0 m_BehaviorType: 0 @@ -2870,7 +2874,7 @@ MonoBehaviour: agentParameters: maxStep: 0 hasUpgradedFromAgentParameters: 1 - MaxStep: 5000 + maxStep: 5000 area: {fileID: 1819751139121548} turnSpeed: 300 moveSpeed: 2 @@ -2880,6 +2884,7 @@ MonoBehaviour: frozenMaterial: {fileID: 2100000, guid: 66163cf35956a4be08e801b750c26f33, type: 2} myLaser: {fileID: 1421240237750412} contribute: 0 + useVectorObs: 1 --- !u!114 &114986980423924774 MonoBehaviour: m_ObjectHideFlags: 0 @@ -2925,6 +2930,7 @@ MonoBehaviour: m_EditorClassIdentifier: DecisionPeriod: 5 TakeActionsBetweenDecisions: 1 + offsetStep: 0 --- !u!1 &1617924810425504 GameObject: m_ObjectHideFlags: 0 @@ -3436,11 +3442,11 @@ MonoBehaviour: m_Name: m_EditorClassIdentifier: m_BrainParameters: - VectorObservationSize: 0 - NumStackedVectorObservations: 1 - VectorActionSize: 03000000030000000300000002000000 - VectorActionDescriptions: [] - VectorActionSpaceType: 0 + vectorObservationSize: 4 + numStackedVectorObservations: 1 + vectorActionSize: 03000000030000000300000002000000 + vectorActionDescriptions: [] + vectorActionSpaceType: 0 m_Model: {fileID: 11400000, guid: 36ab3e93020504f48858d0856f939685, type: 3} m_InferenceDevice: 0 m_BehaviorType: 0 @@ -3462,7 +3468,7 @@ MonoBehaviour: agentParameters: maxStep: 0 hasUpgradedFromAgentParameters: 1 - MaxStep: 5000 + maxStep: 5000 area: {fileID: 1819751139121548} turnSpeed: 300 moveSpeed: 2 @@ -3472,6 +3478,7 @@ MonoBehaviour: frozenMaterial: {fileID: 2100000, guid: 66163cf35956a4be08e801b750c26f33, type: 2} myLaser: {fileID: 1617924810425504} contribute: 0 + useVectorObs: 1 --- !u!114 &114644889237473510 MonoBehaviour: m_ObjectHideFlags: 0 @@ -3517,6 +3524,7 @@ MonoBehaviour: m_EditorClassIdentifier: DecisionPeriod: 5 TakeActionsBetweenDecisions: 1 + offsetStep: 0 --- !u!1 &1688105343773098 GameObject: m_ObjectHideFlags: 0 @@ -3751,11 +3759,11 @@ MonoBehaviour: m_Name: m_EditorClassIdentifier: m_BrainParameters: - VectorObservationSize: 0 - NumStackedVectorObservations: 1 - VectorActionSize: 03000000030000000300000002000000 - VectorActionDescriptions: [] - VectorActionSpaceType: 0 + vectorObservationSize: 4 + numStackedVectorObservations: 1 + vectorActionSize: 03000000030000000300000002000000 + vectorActionDescriptions: [] + vectorActionSpaceType: 0 m_Model: {fileID: 11400000, guid: 36ab3e93020504f48858d0856f939685, type: 3} m_InferenceDevice: 0 m_BehaviorType: 0 @@ -3777,7 +3785,7 @@ MonoBehaviour: agentParameters: maxStep: 0 hasUpgradedFromAgentParameters: 1 - MaxStep: 5000 + maxStep: 5000 area: {fileID: 1819751139121548} turnSpeed: 300 moveSpeed: 2 @@ -3787,6 +3795,7 @@ MonoBehaviour: frozenMaterial: {fileID: 2100000, guid: 66163cf35956a4be08e801b750c26f33, type: 2} myLaser: {fileID: 1045923826166930} contribute: 0 + useVectorObs: 1 --- !u!114 &114276061479012222 MonoBehaviour: m_ObjectHideFlags: 0 @@ -3832,6 +3841,7 @@ MonoBehaviour: m_EditorClassIdentifier: DecisionPeriod: 5 TakeActionsBetweenDecisions: 1 + offsetStep: 0 --- !u!1 &1729825611722018 GameObject: m_ObjectHideFlags: 0 diff --git a/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs b/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs index 9e11ad2c77..87c9d316ae 100644 --- a/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs +++ b/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs @@ -1,22 +1,16 @@ using UnityEngine; using Unity.MLAgents; -using Unity.MLAgents.Sensors.Reflection; +using Unity.MLAgents.Sensors; public class FoodCollectorAgent : Agent { FoodCollectorSettings m_FoodCollecterSettings; public GameObject area; FoodCollectorArea m_MyArea; - - [Observable] bool m_Frozen; - bool m_Poisoned; bool m_Satiated; - - [Observable] bool m_Shoot; - float m_FrozenTime; float m_EffectTime; Rigidbody m_AgentRb; @@ -32,6 +26,7 @@ public class FoodCollectorAgent : Agent public Material frozenMaterial; public GameObject myLaser; public bool contribute; + public bool useVectorObs; EnvironmentParameters m_ResetParams; @@ -44,13 +39,15 @@ public override void Initialize() SetResetParameters(); } - [Observable] - Vector2 localRbVelocity + public override void CollectObservations(VectorSensor sensor) { - get + if (useVectorObs) { - var rbVel = transform.InverseTransformDirection(m_AgentRb.velocity); - return new Vector2(rbVel.x, rbVel.z); + var localVelocity = transform.InverseTransformDirection(m_AgentRb.velocity); + sensor.AddObservation(localVelocity.x); + sensor.AddObservation(localVelocity.z); + sensor.AddObservation(System.Convert.ToInt32(m_Frozen)); + sensor.AddObservation(System.Convert.ToInt32(m_Shoot)); } } diff --git a/com.unity.ml-agents/Editor/AgentEditor.cs b/com.unity.ml-agents/Editor/AgentEditor.cs index 122f841b74..3a8de61867 100644 --- a/com.unity.ml-agents/Editor/AgentEditor.cs +++ b/com.unity.ml-agents/Editor/AgentEditor.cs @@ -22,7 +22,7 @@ public override void OnInspectorGUI() new GUIContent("Max Step", "The per-agent maximum number of steps.") ); - var observableAttributeBehavior = serializedAgent.FindProperty("m_observableAttributeBehavior"); + var observableAttributeBehavior = serializedAgent.FindProperty("m_observableAttributeHandling"); EditorGUILayout.PropertyField(observableAttributeBehavior); serializedAgent.ApplyModifiedProperties(); diff --git a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs index a3840392fe..9e99ebb193 100644 --- a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs +++ b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs @@ -104,7 +104,7 @@ void DisplayFailedModelChecks() int observableSensorSizes = 0; var agent = behaviorParameters.GetComponent(); - if (agent != null && agent.ObservableAttributeBehavior != Agent.ObservableAttributeOptions.Ignore) + if (agent != null && agent.ObservableAttributeHandling != Agent.ObservableAttributeOptions.Ignore) { List observableErrors = new List(); observableSensorSizes = ObservableAttribute.GetTotalObservationSize(agent, false, observableErrors); diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 972cab5f3e..70caa66ba7 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -251,15 +251,15 @@ public enum ObservableAttributeOptions } [HideInInspector, SerializeField] - ObservableAttributeOptions m_observableAttributeBehavior = ObservableAttributeOptions.Ignore; + ObservableAttributeOptions m_observableAttributeHandling = ObservableAttributeOptions.Ignore; /// /// Determines how the Agent class is searched for s. /// - public ObservableAttributeOptions ObservableAttributeBehavior + public ObservableAttributeOptions ObservableAttributeHandling { - get { return m_observableAttributeBehavior; } - set { m_observableAttributeBehavior = value; } + get { return m_observableAttributeHandling; } + set { m_observableAttributeHandling = value; } } @@ -865,9 +865,9 @@ public virtual void Heuristic(float[] actionsOut) /// internal void InitializeSensors() { - if (ObservableAttributeBehavior != ObservableAttributeOptions.Ignore) + if (ObservableAttributeHandling != ObservableAttributeOptions.Ignore) { - var excludeInherited = (ObservableAttributeBehavior == ObservableAttributeOptions.ExcludeInherited); + var excludeInherited = (ObservableAttributeHandling == ObservableAttributeOptions.ExcludeInherited); using (TimerStack.Instance.Scoped("CreateObservableSensors")) { var observableSensors = ObservableAttribute.CreateObservableSensors(this, excludeInherited); diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs index 645341dc19..93149275f5 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs @@ -13,7 +13,7 @@ internal IntReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) internal override void WriteReflectedField(ObservationWriter writer) { var intVal = (System.Int32)GetReflectedValue(); - writer[0] = intVal; + writer[0] = (float)intVal; } } } diff --git a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs index 562be4b4e0..f0f2ae1c8e 100644 --- a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs +++ b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs @@ -32,7 +32,7 @@ public class TestAgent : Agent { public TestAgent() { - ObservableAttributeBehavior = ObservableAttributeOptions.ExcludeInherited; + ObservableAttributeHandling = ObservableAttributeOptions.ExcludeInherited; } internal AgentInfo _Info @@ -259,7 +259,7 @@ public void TestAgent() var agentGo2 = new GameObject("TestAgent"); agentGo2.AddComponent(); var agent2 = agentGo2.GetComponent(); - agent2.ObservableAttributeBehavior = Agent.ObservableAttributeOptions.Ignore; + agent2.ObservableAttributeHandling = Agent.ObservableAttributeOptions.Ignore; Assert.AreEqual(0, agent1.agentOnEpisodeBeginCalls); Assert.AreEqual(0, agent2.agentOnEpisodeBeginCalls); @@ -791,7 +791,7 @@ public void TestObservableAttributeBehaviorIgnore() { var go = new GameObject(); var agent = go.AddComponent(); - agent.ObservableAttributeBehavior = behavior; + agent.ObservableAttributeHandling = behavior; agent.LazyInitialize(); int numAttributeSensors = 0; foreach (var sensor in agent.sensors) From b02defcc1bc8e86b39fa1276fc187b44fdc9e3ac Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Thu, 14 May 2020 11:56:48 -0700 Subject: [PATCH 19/23] warning for write-only, no exception for invalid type --- .../Sensors/Reflection/ObservableAttribute.cs | 37 ++++++++++++++----- .../Editor/Sensor/ObservableAttributeTests.cs | 8 +++- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs index caeaf40ca8..f5ab13c380 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs @@ -113,11 +113,6 @@ public ObservableAttribute(string name = null, int numStackedObservations = 1) var properties = o.GetType().GetProperties(bindingFlags); foreach (var prop in properties) { - if (!prop.CanRead) - { - // Ignore write-only properties. - continue; - } var attr = (ObservableAttribute)GetCustomAttribute(prop, typeof(ObservableAttribute)); if (attr != null) { @@ -137,12 +132,25 @@ internal static List CreateObservableSensors(object o, bool excludeInhe var sensorsOut = new List(); foreach (var(field, attr) in GetObservableFields(o, excludeInherited)) { - sensorsOut.Add(CreateReflectionSensor(o, field, null, attr)); + var sensor = CreateReflectionSensor(o, field, null, attr); + if (sensor != null) + { + sensorsOut.Add(sensor); + } } foreach (var(prop, attr) in GetObservableProperties(o, excludeInherited)) { - sensorsOut.Add(CreateReflectionSensor(o, null, prop, attr)); + if (!prop.CanRead) + { + // Skip unreadable properties. + continue; + } + var sensor = CreateReflectionSensor(o, null, prop, attr); + if (sensor != null) + { + sensorsOut.Add(sensor); + } } return sensorsOut; @@ -150,6 +158,7 @@ internal static List CreateObservableSensors(object o, bool excludeInhe /// /// Create the ISensor for either the field or property on the provided object. + /// If the data type is unsupported, or the property is write-only, returns null. /// /// /// @@ -226,7 +235,8 @@ static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, PropertyInf if (sensor == null) { - throw new UnityAgentsException($"Unsupported Observable type: {memberType.Name}"); + // For unsupported types, return null and we'll filter them out later. + return null; } // Wrap the base sensor in a StackingSensor if we're using stacking. @@ -265,11 +275,18 @@ internal static int GetTotalObservationSize(object o, bool excludeInherited, Lis { if (s_TypeSizes.ContainsKey(prop.PropertyType)) { - sizeOut += s_TypeSizes[prop.PropertyType] * attr.m_NumStackedObservations; + if (prop.CanRead) + { + sizeOut += s_TypeSizes[prop.PropertyType] * attr.m_NumStackedObservations; + } + else + { + errorsOut.Add($"Observable property {prop.Name} is write-only."); + } } else { - errorsOut.Add($"Unsupported Observable type {prop.PropertyType.Name} on field {prop.Name}"); + errorsOut.Add($"Unsupported Observable type {prop.PropertyType.Name} on property {prop.Name}"); } } diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs index 3ebf98f8a5..b7b51cfe3b 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs @@ -214,13 +214,17 @@ public float WriteOnlyProperty } [Test] - public void TestGetTotalObservationSizeErrors() + public void TestInvalidObservables() { var bad = new BadClass(); bad.WriteOnlyProperty = 1.0f; var errors = new List(); Assert.AreEqual(0, ObservableAttribute.GetTotalObservationSize(bad, false, errors)); - Assert.AreEqual(2, errors.Count); + Assert.AreEqual(3, errors.Count); + + // Should be able to safely generate sensors (and get nothing back) + var sensors = ObservableAttribute.CreateObservableSensors(bad, false); + Assert.AreEqual(0, sensors.Count); } class StackingClass From 8bb5ebbba95777e187d42b25feaf6b95b2c48339 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Fri, 15 May 2020 09:49:01 -0700 Subject: [PATCH 20/23] move observableAttributeHandling to BehaviorParameters --- com.unity.ml-agents/Editor/AgentEditor.cs | 3 -- .../Editor/BehaviorParametersEditor.cs | 14 ++++-- com.unity.ml-agents/Runtime/Agent.cs | 49 ++----------------- .../Runtime/Policies/BehaviorParameters.cs | 43 ++++++++++++++++ ...ditModeTestInternalBrainTensorGenerator.cs | 7 ++- .../Tests/Editor/MLAgentsEditModeTest.cs | 17 +++---- .../Tests/Runtime/RuntimeAPITest.cs | 5 ++ 7 files changed, 72 insertions(+), 66 deletions(-) diff --git a/com.unity.ml-agents/Editor/AgentEditor.cs b/com.unity.ml-agents/Editor/AgentEditor.cs index 3a8de61867..fdebf1cb15 100644 --- a/com.unity.ml-agents/Editor/AgentEditor.cs +++ b/com.unity.ml-agents/Editor/AgentEditor.cs @@ -22,9 +22,6 @@ public override void OnInspectorGUI() new GUIContent("Max Step", "The per-agent maximum number of steps.") ); - var observableAttributeBehavior = serializedAgent.FindProperty("m_observableAttributeHandling"); - EditorGUILayout.PropertyField(observableAttributeBehavior); - serializedAgent.ApplyModifiedProperties(); EditorGUILayout.LabelField("", GUI.skin.horizontalSlider); diff --git a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs index 9e99ebb193..3eab08da8e 100644 --- a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs +++ b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs @@ -62,6 +62,7 @@ public override void OnInspectorGUI() EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties()); { EditorGUILayout.PropertyField(so.FindProperty("m_UseChildSensors"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_ObservableAttributeHandling"), true); } EditorGUI.EndDisabledGroup(); @@ -92,6 +93,7 @@ void DisplayFailedModelChecks() var model = (NNModel)serializedObject.FindProperty("m_Model").objectReferenceValue; var behaviorParameters = (BehaviorParameters)target; + // Grab the sensor components, since we need them to determine the observation sizes. SensorComponent[] sensorComponents; if (behaviorParameters.UseChildSensors) { @@ -102,12 +104,14 @@ void DisplayFailedModelChecks() sensorComponents = behaviorParameters.GetComponents(); } - int observableSensorSizes = 0; + // Get the total size of the sensors generated by ObservableAttributes. + // If there are any errors (e.g. unsupported type, write-only properties), display them too. + int observableAttributeSensorTotalSize = 0; var agent = behaviorParameters.GetComponent(); - if (agent != null && agent.ObservableAttributeHandling != Agent.ObservableAttributeOptions.Ignore) + if (agent != null && behaviorParameters.ObservableAttributeHandling != ObservableAttributeOptions.Ignore) { List observableErrors = new List(); - observableSensorSizes = ObservableAttribute.GetTotalObservationSize(agent, false, observableErrors); + observableAttributeSensorTotalSize = ObservableAttribute.GetTotalObservationSize(agent, false, observableErrors); foreach (var check in observableErrors) { EditorGUILayout.HelpBox(check, MessageType.Warning); @@ -122,8 +126,8 @@ void DisplayFailedModelChecks() if (brainParameters != null) { var failedChecks = Inference.BarracudaModelParamLoader.CheckModel( - barracudaModel, brainParameters, sensorComponents, observableSensorSizes, - behaviorParameters.BehaviorType + barracudaModel, brainParameters, sensorComponents, + observableAttributeSensorTotalSize, behaviorParameters.BehaviorType ); foreach (var check in failedChecks) { diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 70caa66ba7..501e9c589b 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using System.Collections.ObjectModel; -using System.Reflection; using UnityEngine; using Unity.Barracuda; using Unity.MLAgents.Sensors; @@ -220,49 +219,6 @@ internal struct AgentParameters [FormerlySerializedAs("maxStep")] [HideInInspector] public int MaxStep; - /// - /// Options for controlling how the Agent class is searched for s. - /// - public enum ObservableAttributeOptions - { - /// - /// All ObservableAttributes on the Agent will be ignored. If there are no - /// ObservableAttributes on the Agent, this will result in the fastest - /// initialization time. - /// - Ignore, - - /// - /// Only members on the declared class will be examined; members that are - /// inherited are ignored. This is the default behavior, and a reasonable - /// tradeoff between performance and flexibility. - /// - /// This corresponds to setting the - /// [BindingFlags.DeclaredOnly](https://docs.microsoft.com/en-us/dotnet/api/system.reflection.bindingflags?view=netcore-3.1) - /// when examining the fields and properties of the Agent class instance. - /// - ExcludeInherited, - - /// - /// All members on the class will be examined. This can lead to slower - /// startup times - /// - ExamineAll - } - - [HideInInspector, SerializeField] - ObservableAttributeOptions m_observableAttributeHandling = ObservableAttributeOptions.Ignore; - - /// - /// Determines how the Agent class is searched for s. - /// - public ObservableAttributeOptions ObservableAttributeHandling - { - get { return m_observableAttributeHandling; } - set { m_observableAttributeHandling = value; } - } - - /// Current Agent information (message sent to Brain). AgentInfo m_Info; @@ -865,9 +821,10 @@ public virtual void Heuristic(float[] actionsOut) /// internal void InitializeSensors() { - if (ObservableAttributeHandling != ObservableAttributeOptions.Ignore) + if (m_PolicyFactory.ObservableAttributeHandling != ObservableAttributeOptions.Ignore) { - var excludeInherited = (ObservableAttributeHandling == ObservableAttributeOptions.ExcludeInherited); + var excludeInherited = + m_PolicyFactory.ObservableAttributeHandling == ObservableAttributeOptions.ExcludeInherited; using (TimerStack.Instance.Scoped("CreateObservableSensors")) { var observableSensors = ObservableAttribute.CreateObservableSensors(this, excludeInherited); diff --git a/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs index 013df3bad4..4234fc3f79 100644 --- a/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs +++ b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs @@ -2,6 +2,7 @@ using System; using UnityEngine; using UnityEngine.Serialization; +using Unity.MLAgents.Sensors.Reflection; namespace Unity.MLAgents.Policies { @@ -30,6 +31,36 @@ public enum BehaviorType InferenceOnly } + /// + /// Options for controlling how the Agent class is searched for s. + /// + public enum ObservableAttributeOptions + { + /// + /// All ObservableAttributes on the Agent will be ignored. If there are no + /// ObservableAttributes on the Agent, this will result in the fastest + /// initialization time. + /// + Ignore, + + /// + /// Only members on the declared class will be examined; members that are + /// inherited are ignored. This is the default behavior, and a reasonable + /// tradeoff between performance and flexibility. + /// + /// This corresponds to setting the + /// [BindingFlags.DeclaredOnly](https://docs.microsoft.com/en-us/dotnet/api/system.reflection.bindingflags?view=netcore-3.1) + /// when examining the fields and properties of the Agent class instance. + /// + ExcludeInherited, + + /// + /// All members on the class will be examined. This can lead to slower + /// startup times + /// + ExamineAll + } + /// /// A component for setting an instance's behavior and /// brain properties. @@ -129,6 +160,18 @@ public bool UseChildSensors set { m_UseChildSensors = value; } } + [HideInInspector, SerializeField] + ObservableAttributeOptions m_ObservableAttributeHandling = ObservableAttributeOptions.Ignore; + + /// + /// Determines how the Agent class is searched for s. + /// + public ObservableAttributeOptions ObservableAttributeHandling + { + get { return m_ObservableAttributeHandling; } + set { m_ObservableAttributeHandling = value; } + } + /// /// Returns the behavior name, concatenated with any other metadata (i.e. team id). /// diff --git a/com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs b/com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs index 3b33ffc5a3..b19ff8f967 100644 --- a/com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs +++ b/com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs @@ -4,6 +4,7 @@ using UnityEngine; using Unity.MLAgents.Inference; using Unity.MLAgents.Policies; +using Unity.MLAgents.Sensors.Reflection; namespace Unity.MLAgents.Tests { @@ -19,18 +20,20 @@ public void SetUp() } } - static List GetFakeAgents() + static List GetFakeAgents(ObservableAttributeOptions observableAttributeOptions = ObservableAttributeOptions.Ignore) { var goA = new GameObject("goA"); var bpA = goA.AddComponent(); bpA.BrainParameters.VectorObservationSize = 3; bpA.BrainParameters.NumStackedVectorObservations = 1; + bpA.ObservableAttributeHandling = observableAttributeOptions; var agentA = goA.AddComponent(); var goB = new GameObject("goB"); var bpB = goB.AddComponent(); bpB.BrainParameters.VectorObservationSize = 3; bpB.BrainParameters.NumStackedVectorObservations = 1; + bpB.ObservableAttributeHandling = observableAttributeOptions; var agentB = goB.AddComponent(); var agents = new List { agentA, agentB }; @@ -103,7 +106,7 @@ public void GenerateVectorObservation() shape = new long[] { 2, 4 } }; const int batchSize = 4; - var agentInfos = GetFakeAgents(); + var agentInfos = GetFakeAgents(ObservableAttributeOptions.ExamineAll); var alloc = new TensorCachingAllocator(); var generator = new VectorObservationGenerator(alloc); generator.AddSensorIndex(0); // ObservableAttribute (size 1) diff --git a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs index f0f2ae1c8e..a7d47b5b7c 100644 --- a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs +++ b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs @@ -30,11 +30,6 @@ public void Dispose() {} public class TestAgent : Agent { - public TestAgent() - { - ObservableAttributeHandling = ObservableAttributeOptions.ExcludeInherited; - } - internal AgentInfo _Info { get @@ -255,11 +250,12 @@ public void TestAgent() var agentGo1 = new GameObject("TestAgent"); agentGo1.AddComponent(); var agent1 = agentGo1.GetComponent(); + var bp1 = agentGo1.GetComponent(); + bp1.ObservableAttributeHandling = ObservableAttributeOptions.ExcludeInherited; var agentGo2 = new GameObject("TestAgent"); agentGo2.AddComponent(); var agent2 = agentGo2.GetComponent(); - agent2.ObservableAttributeHandling = Agent.ObservableAttributeOptions.Ignore; Assert.AreEqual(0, agent1.agentOnEpisodeBeginCalls); Assert.AreEqual(0, agent2.agentOnEpisodeBeginCalls); @@ -780,18 +776,19 @@ public void TestObservableAttributeBehaviorIgnore() var variants = new[] { // No observables found - (Agent.ObservableAttributeOptions.Ignore, 0), + (ObservableAttributeOptions.Ignore, 0), // Only DerivedField found - (Agent.ObservableAttributeOptions.ExcludeInherited, 1), + (ObservableAttributeOptions.ExcludeInherited, 1), // DerivedField and BaseField found - (Agent.ObservableAttributeOptions.ExamineAll, 2) + (ObservableAttributeOptions.ExamineAll, 2) }; foreach (var(behavior, expectedNumSensors) in variants) { var go = new GameObject(); var agent = go.AddComponent(); - agent.ObservableAttributeHandling = behavior; + var bp = go.GetComponent (); + bp.ObservableAttributeHandling = behavior; agent.LazyInitialize(); int numAttributeSensors = 0; foreach (var sensor in agent.sensors) diff --git a/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs b/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs index 535fbfb7b3..54cfd5a822 100644 --- a/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs +++ b/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs @@ -4,6 +4,7 @@ using Unity.MLAgents; using Unity.MLAgents.Policies; using Unity.MLAgents.Sensors; +using Unity.MLAgents.Sensors.Reflection; using NUnit.Framework; using UnityEngine; using UnityEngine.TestTools; @@ -15,6 +16,9 @@ public class PublicApiAgent : Agent { public int numHeuristicCalls; + [Observable] + public float ObservableFloat; + public override void Heuristic(float[] actionsOut) { numHeuristicCalls++; @@ -69,6 +73,7 @@ public IEnumerator RuntimeApiTestWithEnumeratorPasses() behaviorParams.BehaviorName = "TestBehavior"; behaviorParams.TeamId = 42; behaviorParams.UseChildSensors = true; + behaviorParams.ObservableAttributeHandling = ObservableAttributeOptions.ExamineAll; // Can't actually create an Agent with InferenceOnly and no model, so change back From 4c526d95e3e30f258878bf007139aea53fab0d27 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Fri, 15 May 2020 09:53:53 -0700 Subject: [PATCH 21/23] autoformatting --- com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs | 2 +- .../Tests/Editor/Sensor/ObservableAttributeTests.cs | 3 +-- com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs | 5 ++--- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs index a7d47b5b7c..5dcaf0dd53 100644 --- a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs +++ b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs @@ -787,7 +787,7 @@ public void TestObservableAttributeBehaviorIgnore() { var go = new GameObject(); var agent = go.AddComponent(); - var bp = go.GetComponent (); + var bp = go.GetComponent(); bp.ObservableAttributeHandling = behavior; agent.LazyInitialize(); int numAttributeSensors = 0; diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs index b7b51cfe3b..b7afb08493 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs @@ -7,7 +7,6 @@ namespace Unity.MLAgents.Tests { - [TestFixture] public class ObservableAttributeTests { @@ -179,8 +178,8 @@ public void TestGetObservableSensors() SensorTestHelper.CompareObservation(sensorsByName["quaternionMember"], new[] { 5.0f, 5.1f, 5.2f, 5.3f }); SensorTestHelper.CompareObservation(sensorsByName["quaternionProperty"], new[] { 5.4f, 5.5f, 5.5f, 5.7f }); - } + [Test] public void TestGetTotalObservationSize() { diff --git a/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs b/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs index 54cfd5a822..19a5e18bae 100644 --- a/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs +++ b/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs @@ -1,4 +1,4 @@ -#if UNITY_INCLUDE_TESTS +#if UNITY_INCLUDE_TESTS using System.Collections; using System.Collections.Generic; using Unity.MLAgents; @@ -11,7 +11,6 @@ namespace Tests { - public class PublicApiAgent : Agent { public int numHeuristicCalls; @@ -40,7 +39,7 @@ public override ISensor CreateSensor() public override int[] GetObservationShape() { - int[] shape = (int[]) wrappedComponent.GetObservationShape().Clone(); + int[] shape = (int[])wrappedComponent.GetObservationShape().Clone(); for (var i = 0; i < shape.Length; i++) { shape[i] *= numStacks; From d3ea20f0d23aa92d8bab2e1e1062ac76206770c2 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Fri, 15 May 2020 10:23:21 -0700 Subject: [PATCH 22/23] changelog --- com.unity.ml-agents/CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index cfc2544051..48417c9d94 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -15,6 +15,8 @@ and this project adheres to - `get_behavior_names()` and `get_behavior_spec()` on UnityEnvironment were replaced by the `behavior_specs` property. (#3946) ### Minor Changes #### com.unity.ml-agents (C#) +- `ObservableAttribute` was added. Adding the attribute to fields or properties on an Agent will allow it to generate + observations via reflection. #### ml-agents / ml-agents-envs / gym-unity (Python) - Curriculum and Parameter Randomization configurations have been merged into the main training configuration file. Note that this means training From 8d29f86fe16936a572d3a07af888eb94bee4c9a2 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Fri, 15 May 2020 13:00:28 -0700 Subject: [PATCH 23/23] fix up sensor creation logic --- .../Sensors/Reflection/ObservableAttribute.cs | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs index f5ab13c380..fb056fd09d 100644 --- a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs @@ -208,32 +208,31 @@ static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, PropertyInf { sensor = new IntReflectionSensor(reflectionSensorInfo); } - if (memberType == typeof(float)) + else if (memberType == typeof(float)) { sensor = new FloatReflectionSensor(reflectionSensorInfo); } - if (memberType == typeof(bool)) + else if (memberType == typeof(bool)) { sensor = new BoolReflectionSensor(reflectionSensorInfo); } - if (memberType == typeof(Vector2)) + else if (memberType == typeof(Vector2)) { sensor = new Vector2ReflectionSensor(reflectionSensorInfo); } - if (memberType == typeof(Vector3)) + else if (memberType == typeof(Vector3)) { sensor = new Vector3ReflectionSensor(reflectionSensorInfo); } - if (memberType == typeof(Vector4)) + else if (memberType == typeof(Vector4)) { sensor = new Vector4ReflectionSensor(reflectionSensorInfo); } - if (memberType == typeof(Quaternion)) + else if (memberType == typeof(Quaternion)) { sensor = new QuaternionReflectionSensor(reflectionSensorInfo); } - - if (sensor == null) + else { // For unsupported types, return null and we'll filter them out later. return null;