diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index cfc2544051..48417c9d94 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -15,6 +15,8 @@ and this project adheres to - `get_behavior_names()` and `get_behavior_spec()` on UnityEnvironment were replaced by the `behavior_specs` property. (#3946) ### Minor Changes #### com.unity.ml-agents (C#) +- `ObservableAttribute` was added. Adding the attribute to fields or properties on an Agent will allow it to generate + observations via reflection. #### ml-agents / ml-agents-envs / gym-unity (Python) - Curriculum and Parameter Randomization configurations have been merged into the main training configuration file. Note that this means training diff --git a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs index 08b49a1826..3eab08da8e 100644 --- a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs +++ b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs @@ -1,7 +1,9 @@ +using System.Collections.Generic; using UnityEditor; using Unity.Barracuda; using Unity.MLAgents.Policies; using Unity.MLAgents.Sensors; +using Unity.MLAgents.Sensors.Reflection; using UnityEngine; namespace Unity.MLAgents.Editor @@ -60,6 +62,7 @@ public override void OnInspectorGUI() EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties()); { EditorGUILayout.PropertyField(so.FindProperty("m_UseChildSensors"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_ObservableAttributeHandling"), true); } EditorGUI.EndDisabledGroup(); @@ -89,6 +92,8 @@ void DisplayFailedModelChecks() Model barracudaModel = null; var model = (NNModel)serializedObject.FindProperty("m_Model").objectReferenceValue; var behaviorParameters = (BehaviorParameters)target; + + // Grab the sensor components, since we need them to determine the observation sizes. SensorComponent[] sensorComponents; if (behaviorParameters.UseChildSensors) { @@ -98,6 +103,21 @@ void DisplayFailedModelChecks() { sensorComponents = behaviorParameters.GetComponents(); } + + // Get the total size of the sensors generated by ObservableAttributes. + // If there are any errors (e.g. unsupported type, write-only properties), display them too. + int observableAttributeSensorTotalSize = 0; + var agent = behaviorParameters.GetComponent(); + if (agent != null && behaviorParameters.ObservableAttributeHandling != ObservableAttributeOptions.Ignore) + { + List observableErrors = new List(); + observableAttributeSensorTotalSize = ObservableAttribute.GetTotalObservationSize(agent, false, observableErrors); + foreach (var check in observableErrors) + { + EditorGUILayout.HelpBox(check, MessageType.Warning); + } + } + var brainParameters = behaviorParameters.BrainParameters; if (model != null) { @@ -106,7 +126,8 @@ void DisplayFailedModelChecks() if (brainParameters != null) { var failedChecks = Inference.BarracudaModelParamLoader.CheckModel( - barracudaModel, brainParameters, sensorComponents, behaviorParameters.BehaviorType + barracudaModel, brainParameters, sensorComponents, + observableAttributeSensorTotalSize, behaviorParameters.BehaviorType ); foreach (var check in failedChecks) { diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index c142abbd52..501e9c589b 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -4,6 +4,7 @@ using UnityEngine; using Unity.Barracuda; using Unity.MLAgents.Sensors; +using Unity.MLAgents.Sensors.Reflection; using Unity.MLAgents.Demonstrations; using Unity.MLAgents.Policies; using UnityEngine.Serialization; @@ -395,7 +396,11 @@ public void LazyInitialize() m_Brain = m_PolicyFactory.GeneratePolicy(Heuristic); ResetData(); Initialize(); - InitializeSensors(); + + using (TimerStack.Instance.Scoped("InitializeSensors")) + { + InitializeSensors(); + } // The first time the Academy resets, all Agents in the scene will be // forced to reset through the event. @@ -816,6 +821,17 @@ public virtual void Heuristic(float[] actionsOut) /// internal void InitializeSensors() { + if (m_PolicyFactory.ObservableAttributeHandling != ObservableAttributeOptions.Ignore) + { + var excludeInherited = + m_PolicyFactory.ObservableAttributeHandling == ObservableAttributeOptions.ExcludeInherited; + using (TimerStack.Instance.Scoped("CreateObservableSensors")) + { + var observableSensors = ObservableAttribute.CreateObservableSensors(this, excludeInherited); + sensors.AddRange(observableSensors); + } + } + // Get all attached sensor components SensorComponent[] attachedSensorComponents; if (m_PolicyFactory.UseChildSensors) diff --git a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs index b83ed2a9a6..41efa93653 100644 --- a/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs +++ b/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs @@ -126,10 +126,12 @@ public static string[] GetOutputNames(Model model) /// The BrainParameters that are used verify the compatibility with the InferenceEngine /// /// Attached sensor components + /// Sum of the sizes of all ObservableAttributes. /// BehaviorType or the Agent to check. /// The list the error messages of the checks that failed public static IEnumerable CheckModel(Model model, BrainParameters brainParameters, - SensorComponent[] sensorComponents, BehaviorType behaviorType = BehaviorType.Default) + SensorComponent[] sensorComponents, int observableAttributeTotalSize = 0, + BehaviorType behaviorType = BehaviorType.Default) { List failedModelChecks = new List(); if (model == null) @@ -182,7 +184,7 @@ public static IEnumerable CheckModel(Model model, BrainParameters brainP CheckOutputTensorPresence(model, memorySize)) ; failedModelChecks.AddRange( - CheckInputTensorShape(model, brainParameters, sensorComponents) + CheckInputTensorShape(model, brainParameters, sensorComponents, observableAttributeTotalSize) ); failedModelChecks.AddRange( CheckOutputTensorShape(model, brainParameters, isContinuous, actionSize) @@ -253,6 +255,7 @@ static IEnumerable CheckIntScalarPresenceHelper( /// Whether the model is expecting continuous or discrete control. /// /// Array of attached sensor components + /// Total size of ObservableAttributes /// /// A IEnumerable of string corresponding to the failed input presence checks. /// @@ -404,25 +407,27 @@ static string CheckVisualObsShape( /// The BrainParameters that are used verify the compatibility with the InferenceEngine /// /// Attached sensors + /// Sum of the sizes of all ObservableAttributes. /// The list the error messages of the checks that failed static IEnumerable CheckInputTensorShape( - Model model, BrainParameters brainParameters, SensorComponent[] sensorComponents) + Model model, BrainParameters brainParameters, SensorComponent[] sensorComponents, + int observableAttributeTotalSize) { var failedModelChecks = new List(); var tensorTester = - new Dictionary>() + new Dictionary>() { {TensorNames.VectorObservationPlaceholder, CheckVectorObsShape}, {TensorNames.PreviousActionPlaceholder, CheckPreviousActionShape}, - {TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs) => null)}, - {TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs) => null)}, - {TensorNames.SequenceLengthPlaceholder, ((bp, tensor, scs) => null)}, - {TensorNames.RecurrentInPlaceholder, ((bp, tensor, scs) => null)}, + {TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs, i) => null)}, + {TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs, i) => null)}, + {TensorNames.SequenceLengthPlaceholder, ((bp, tensor, scs, i) => null)}, + {TensorNames.RecurrentInPlaceholder, ((bp, tensor, scs, i) => null)}, }; foreach (var mem in model.memories) { - tensorTester[mem.input] = ((bp, tensor, scs) => null); + tensorTester[mem.input] = ((bp, tensor, scs, i) => null); } var visObsIndex = 0; @@ -434,7 +439,7 @@ static IEnumerable CheckInputTensorShape( continue; } tensorTester[TensorNames.VisualObservationPlaceholderPrefix + visObsIndex] = - (bp, tensor, scs) => CheckVisualObsShape(tensor, sensorComponent); + (bp, tensor, scs, i) => CheckVisualObsShape(tensor, sensorComponent); visObsIndex++; } @@ -452,7 +457,7 @@ static IEnumerable CheckInputTensorShape( else { var tester = tensorTester[tensor.name]; - var error = tester.Invoke(brainParameters, tensor, sensorComponents); + var error = tester.Invoke(brainParameters, tensor, sensorComponents, observableAttributeTotalSize); if (error != null) { failedModelChecks.Add(error); @@ -471,12 +476,14 @@ static IEnumerable CheckInputTensorShape( /// /// The tensor that is expected by the model /// Array of attached sensor components + /// Sum of the sizes of all ObservableAttributes. /// /// If the Check failed, returns a string containing information about why the /// check failed. If the check passed, returns null. /// static string CheckVectorObsShape( - BrainParameters brainParameters, TensorProxy tensorProxy, SensorComponent[] sensorComponents) + BrainParameters brainParameters, TensorProxy tensorProxy, SensorComponent[] sensorComponents, + int observableAttributeTotalSize) { var vecObsSizeBp = brainParameters.VectorObservationSize; var numStackedVector = brainParameters.NumStackedVectorObservations; @@ -491,6 +498,8 @@ static string CheckVectorObsShape( } } + totalVectorSensorSize += observableAttributeTotalSize; + if (vecObsSizeBp * numStackedVector + totalVectorSensorSize != totalVecObsSizeT) { var sensorSizes = ""; @@ -512,7 +521,9 @@ static string CheckVectorObsShape( sensorSizes += "]"; return $"Vector Observation Size of the model does not match. Was expecting {totalVecObsSizeT} " + - $"but received {vecObsSizeBp} x {numStackedVector} vector observations and " + + $"but received: \n" + + $"Vector observations: {vecObsSizeBp} x {numStackedVector}\n" + + $"Total [Observable] attributes: {observableAttributeTotalSize}\n" + $"SensorComponent sizes: {sensorSizes}."; } return null; @@ -526,11 +537,13 @@ static string CheckVectorObsShape( /// The BrainParameters that are used verify the compatibility with the InferenceEngine /// /// The tensor that is expected by the model - /// Array of attached sensor components + /// Array of attached sensor components (unused). + /// Sum of the sizes of all ObservableAttributes (unused). /// If the Check failed, returns a string containing information about why the /// check failed. If the check passed, returns null. static string CheckPreviousActionShape( - BrainParameters brainParameters, TensorProxy tensorProxy, SensorComponent[] sensorComponents) + BrainParameters brainParameters, TensorProxy tensorProxy, + SensorComponent[] sensorComponents, int observableAttributeTotalSize) { var numberActionsBp = brainParameters.VectorActionSize.Length; var numberActionsT = tensorProxy.shape[tensorProxy.shape.Length - 1]; diff --git a/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs index 013df3bad4..4234fc3f79 100644 --- a/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs +++ b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs @@ -2,6 +2,7 @@ using System; using UnityEngine; using UnityEngine.Serialization; +using Unity.MLAgents.Sensors.Reflection; namespace Unity.MLAgents.Policies { @@ -30,6 +31,36 @@ public enum BehaviorType InferenceOnly } + /// + /// Options for controlling how the Agent class is searched for s. + /// + public enum ObservableAttributeOptions + { + /// + /// All ObservableAttributes on the Agent will be ignored. If there are no + /// ObservableAttributes on the Agent, this will result in the fastest + /// initialization time. + /// + Ignore, + + /// + /// Only members on the declared class will be examined; members that are + /// inherited are ignored. This is the default behavior, and a reasonable + /// tradeoff between performance and flexibility. + /// + /// This corresponds to setting the + /// [BindingFlags.DeclaredOnly](https://docs.microsoft.com/en-us/dotnet/api/system.reflection.bindingflags?view=netcore-3.1) + /// when examining the fields and properties of the Agent class instance. + /// + ExcludeInherited, + + /// + /// All members on the class will be examined. This can lead to slower + /// startup times + /// + ExamineAll + } + /// /// A component for setting an instance's behavior and /// brain properties. @@ -129,6 +160,18 @@ public bool UseChildSensors set { m_UseChildSensors = value; } } + [HideInInspector, SerializeField] + ObservableAttributeOptions m_ObservableAttributeHandling = ObservableAttributeOptions.Ignore; + + /// + /// Determines how the Agent class is searched for s. + /// + public ObservableAttributeOptions ObservableAttributeHandling + { + get { return m_ObservableAttributeHandling; } + set { m_ObservableAttributeHandling = value; } + } + /// /// Returns the behavior name, concatenated with any other metadata (i.e. team id). /// diff --git a/com.unity.ml-agents/Runtime/Policies/BrainParameters.cs b/com.unity.ml-agents/Runtime/Policies/BrainParameters.cs index e427ead794..2141cc0e49 100644 --- a/com.unity.ml-agents/Runtime/Policies/BrainParameters.cs +++ b/com.unity.ml-agents/Runtime/Policies/BrainParameters.cs @@ -34,11 +34,9 @@ public enum SpaceType public class BrainParameters { /// - /// The size of the observation space. - /// - /// An agent creates the observation vector in its + /// The number of the observations that are added in /// - /// implementation. + /// /// /// The length of the vector containing observation values. /// diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection.meta new file mode 100644 index 0000000000..fb7288f717 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 08ece3d7e9bb94089a9d59c6f269ab0a +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs new file mode 100644 index 0000000000..0bd8e60b88 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs @@ -0,0 +1,19 @@ +namespace Unity.MLAgents.Sensors.Reflection +{ + /// + /// Sensor that wraps a boolean field or property of an object, and returns + /// that as an observation. + /// + internal class BoolReflectionSensor : ReflectionSensorBase + { + internal BoolReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, 1) + {} + + internal override void WriteReflectedField(ObservationWriter writer) + { + var boolVal = (System.Boolean)GetReflectedValue(); + writer[0] = boolVal ? 1.0f : 0.0f; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs.meta new file mode 100644 index 0000000000..5cac420f11 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: be795c90750a6420d93f569b69ddc1ba +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs new file mode 100644 index 0000000000..47daa282d3 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs @@ -0,0 +1,19 @@ +namespace Unity.MLAgents.Sensors.Reflection +{ + /// + /// Sensor that wraps a float field or property of an object, and returns + /// that as an observation. + /// + internal class FloatReflectionSensor : ReflectionSensorBase + { + internal FloatReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, 1) + {} + + internal override void WriteReflectedField(ObservationWriter writer) + { + var floatVal = (System.Single)GetReflectedValue(); + writer[0] = floatVal; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs.meta new file mode 100644 index 0000000000..2de8b18c7c --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 51ed837d5b7cd44349287ac8066120fc +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs new file mode 100644 index 0000000000..93149275f5 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs @@ -0,0 +1,19 @@ +namespace Unity.MLAgents.Sensors.Reflection +{ + /// + /// Sensor that wraps an integer field or property of an object, and returns + /// that as an observation. + /// + internal class IntReflectionSensor : ReflectionSensorBase + { + internal IntReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, 1) + {} + + internal override void WriteReflectedField(ObservationWriter writer) + { + var intVal = (System.Int32)GetReflectedValue(); + writer[0] = (float)intVal; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs.meta new file mode 100644 index 0000000000..a07726937f --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 5cae4c843cc074d11a549aaa3904c898 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs new file mode 100644 index 0000000000..fb056fd09d --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs @@ -0,0 +1,295 @@ +using System; +using System.Collections.Generic; +using System.Reflection; +using UnityEngine; + +namespace Unity.MLAgents.Sensors.Reflection +{ + /// + /// Specify that a field or property should be used to generate observations for an Agent. + /// For each field or property that uses ObservableAttribute, a corresponding + /// will be created during Agent initialization, and this + /// sensor will read the values during training and inference. + /// + /// + /// ObservableAttribute is intended to make initial setup of an Agent easier. Because it + /// uses reflection to read the values of fields and properties at runtime, this may + /// be much slower than reading the values directly. If the performance of + /// ObservableAttribute is an issue, you can get the same functionality by overriding + /// or creating a custom + /// implementation to read the values without reflection. + /// + /// Note that you do not need to adjust the VectorObservationSize in + /// when adding ObservableAttribute + /// to fields or properties. + /// + /// + /// This sample class will produce two observations, one for the m_Health field, and one + /// for the HealthPercent property. + /// + /// using Unity.MLAgents; + /// using Unity.MLAgents.Sensors.Reflection; + /// + /// public class MyAgent : Agent + /// { + /// [Observable] + /// int m_Health; + /// + /// [Observable] + /// float HealthPercent + /// { + /// get => return 100.0f * m_Health / float(m_MaxHealth); + /// } + /// } + /// + /// + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)] + public class ObservableAttribute : Attribute + { + string m_Name; + int m_NumStackedObservations; + + /// + /// Default binding flags used for reflection of members and properties. + /// + const BindingFlags k_BindingFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic; + + /// + /// Supported types and their observation sizes. + /// + static Dictionary s_TypeSizes = new Dictionary() + { + {typeof(int), 1}, + {typeof(bool), 1}, + {typeof(float), 1}, + {typeof(Vector2), 2}, + {typeof(Vector3), 3}, + {typeof(Vector4), 4}, + {typeof(Quaternion), 4}, + }; + + /// + /// ObservableAttribute constructor. + /// + /// Optional override for the sensor name. Note that all sensors for an Agent + /// must have a unique name. + /// Number of frames to concatenate observations from. + public ObservableAttribute(string name = null, int numStackedObservations = 1) + { + m_Name = name; + m_NumStackedObservations = numStackedObservations; + } + + /// + /// Returns a FieldInfo for all fields that have an ObservableAttribute + /// + /// Object being reflected + /// Whether to exclude inherited properties or not. + /// + static IEnumerable<(FieldInfo, ObservableAttribute)> GetObservableFields(object o, bool excludeInherited) + { + // TODO cache these (and properties) by type, so that we only have to reflect once. + var bindingFlags = k_BindingFlags | (excludeInherited ? BindingFlags.DeclaredOnly : 0); + var fields = o.GetType().GetFields(bindingFlags); + foreach (var field in fields) + { + var attr = (ObservableAttribute)GetCustomAttribute(field, typeof(ObservableAttribute)); + if (attr != null) + { + yield return (field, attr); + } + } + } + + /// + /// Returns a PropertyInfo for all fields that have an ObservableAttribute + /// + /// Object being reflected + /// Whether to exclude inherited properties or not. + /// + static IEnumerable<(PropertyInfo, ObservableAttribute)> GetObservableProperties(object o, bool excludeInherited) + { + var bindingFlags = k_BindingFlags | (excludeInherited ? BindingFlags.DeclaredOnly : 0); + var properties = o.GetType().GetProperties(bindingFlags); + foreach (var prop in properties) + { + var attr = (ObservableAttribute)GetCustomAttribute(prop, typeof(ObservableAttribute)); + if (attr != null) + { + yield return (prop, attr); + } + } + } + + /// + /// Creates sensors for each field and property with ObservableAttribute. + /// + /// Object being reflected + /// Whether to exclude inherited properties or not. + /// + internal static List CreateObservableSensors(object o, bool excludeInherited) + { + var sensorsOut = new List(); + foreach (var(field, attr) in GetObservableFields(o, excludeInherited)) + { + var sensor = CreateReflectionSensor(o, field, null, attr); + if (sensor != null) + { + sensorsOut.Add(sensor); + } + } + + foreach (var(prop, attr) in GetObservableProperties(o, excludeInherited)) + { + if (!prop.CanRead) + { + // Skip unreadable properties. + continue; + } + var sensor = CreateReflectionSensor(o, null, prop, attr); + if (sensor != null) + { + sensorsOut.Add(sensor); + } + } + + return sensorsOut; + } + + /// + /// Create the ISensor for either the field or property on the provided object. + /// If the data type is unsupported, or the property is write-only, returns null. + /// + /// + /// + /// + /// + /// + /// + static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, PropertyInfo propertyInfo, ObservableAttribute observableAttribute) + { + string memberName; + string declaringTypeName; + Type memberType; + if (fieldInfo != null) + { + declaringTypeName = fieldInfo.DeclaringType.Name; + memberName = fieldInfo.Name; + memberType = fieldInfo.FieldType; + } + else + { + declaringTypeName = propertyInfo.DeclaringType.Name; + memberName = propertyInfo.Name; + memberType = propertyInfo.PropertyType; + } + + string sensorName; + if (string.IsNullOrEmpty(observableAttribute.m_Name)) + { + sensorName = $"ObservableAttribute:{declaringTypeName}.{memberName}"; + } + else + { + sensorName = observableAttribute.m_Name; + } + + var reflectionSensorInfo = new ReflectionSensorInfo + { + Object = o, + FieldInfo = fieldInfo, + PropertyInfo = propertyInfo, + ObservableAttribute = observableAttribute, + SensorName = sensorName + }; + + ISensor sensor = null; + if (memberType == typeof(Int32)) + { + sensor = new IntReflectionSensor(reflectionSensorInfo); + } + else if (memberType == typeof(float)) + { + sensor = new FloatReflectionSensor(reflectionSensorInfo); + } + else if (memberType == typeof(bool)) + { + sensor = new BoolReflectionSensor(reflectionSensorInfo); + } + else if (memberType == typeof(Vector2)) + { + sensor = new Vector2ReflectionSensor(reflectionSensorInfo); + } + else if (memberType == typeof(Vector3)) + { + sensor = new Vector3ReflectionSensor(reflectionSensorInfo); + } + else if (memberType == typeof(Vector4)) + { + sensor = new Vector4ReflectionSensor(reflectionSensorInfo); + } + else if (memberType == typeof(Quaternion)) + { + sensor = new QuaternionReflectionSensor(reflectionSensorInfo); + } + else + { + // For unsupported types, return null and we'll filter them out later. + return null; + } + + // Wrap the base sensor in a StackingSensor if we're using stacking. + if (observableAttribute.m_NumStackedObservations > 1) + { + return new StackingSensor(sensor, observableAttribute.m_NumStackedObservations); + } + + return sensor; + } + + /// + /// Gets the sum of the observation sizes of the Observable fields and properties on an object. + /// Also appends errors to the errorsOut array. + /// + /// + /// + /// + /// + internal static int GetTotalObservationSize(object o, bool excludeInherited, List errorsOut) + { + int sizeOut = 0; + foreach (var(field, attr) in GetObservableFields(o, excludeInherited)) + { + if (s_TypeSizes.ContainsKey(field.FieldType)) + { + sizeOut += s_TypeSizes[field.FieldType] * attr.m_NumStackedObservations; + } + else + { + errorsOut.Add($"Unsupported Observable type {field.FieldType.Name} on field {field.Name}"); + } + } + + foreach (var(prop, attr) in GetObservableProperties(o, excludeInherited)) + { + if (s_TypeSizes.ContainsKey(prop.PropertyType)) + { + if (prop.CanRead) + { + sizeOut += s_TypeSizes[prop.PropertyType] * attr.m_NumStackedObservations; + } + else + { + errorsOut.Add($"Observable property {prop.Name} is write-only."); + } + } + else + { + errorsOut.Add($"Unsupported Observable type {prop.PropertyType.Name} on property {prop.Name}"); + } + } + + return sizeOut; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs.meta new file mode 100644 index 0000000000..41659283da --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: a75086dc66a594baea6b8b2935f5dacf +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs new file mode 100644 index 0000000000..41e9f5d22f --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs @@ -0,0 +1,22 @@ +namespace Unity.MLAgents.Sensors.Reflection +{ + /// + /// Sensor that wraps a quaternion field or property of an object, and returns + /// that as an observation. + /// + internal class QuaternionReflectionSensor : ReflectionSensorBase + { + internal QuaternionReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, 4) + {} + + internal override void WriteReflectedField(ObservationWriter writer) + { + var quatVal = (UnityEngine.Quaternion)GetReflectedValue(); + writer[0] = quatVal.x; + writer[1] = quatVal.y; + writer[2] = quatVal.z; + writer[3] = quatVal.w; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs.meta new file mode 100644 index 0000000000..f3970e6b51 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: d38241d74074d459bb4590f7f5d16c80 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs new file mode 100644 index 0000000000..410a89b2bf --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs @@ -0,0 +1,97 @@ +using System.Reflection; + +namespace Unity.MLAgents.Sensors.Reflection +{ + /// + /// Construction info for a ReflectionSensorBase. + /// + internal struct ReflectionSensorInfo + { + public object Object; + + public FieldInfo FieldInfo; + public PropertyInfo PropertyInfo; + public ObservableAttribute ObservableAttribute; + public string SensorName; + } + + /// + /// Abstract base class for reflection-based sensors. + /// + internal abstract class ReflectionSensorBase : ISensor + { + protected object m_Object; + + // Exactly one of m_FieldInfo and m_PropertyInfo should be non-null. + protected FieldInfo m_FieldInfo; + protected PropertyInfo m_PropertyInfo; + + // Not currently used, but might want later. + protected ObservableAttribute m_ObservableAttribute; + + // Cached sensor names and shapes. + string m_SensorName; + int[] m_Shape; + + public ReflectionSensorBase(ReflectionSensorInfo reflectionSensorInfo, int size) + { + m_Object = reflectionSensorInfo.Object; + m_FieldInfo = reflectionSensorInfo.FieldInfo; + m_PropertyInfo = reflectionSensorInfo.PropertyInfo; + m_ObservableAttribute = reflectionSensorInfo.ObservableAttribute; + m_SensorName = reflectionSensorInfo.SensorName; + m_Shape = new[] {size}; + } + + /// + public int[] GetObservationShape() + { + return m_Shape; + } + + /// + public int Write(ObservationWriter writer) + { + WriteReflectedField(writer); + return m_Shape[0]; + } + + internal abstract void WriteReflectedField(ObservationWriter writer); + + /// + /// Get either the reflected field, or return the reflected property. + /// This should be used by implementations in their WriteReflectedField() method. + /// + /// + protected object GetReflectedValue() + { + return m_FieldInfo != null ? + m_FieldInfo.GetValue(m_Object) : + m_PropertyInfo.GetMethod.Invoke(m_Object, null); + } + + /// + public byte[] GetCompressedObservation() + { + return null; + } + + /// + public void Update() {} + + /// + public void Reset() {} + + /// + public SensorCompressionType GetCompressionType() + { + return SensorCompressionType.None; + } + + /// + public string GetName() + { + return m_SensorName; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs.meta new file mode 100644 index 0000000000..cef19bb598 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 6b68d855fb94a45fbbeb0dbe968a35f8 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs new file mode 100644 index 0000000000..5523c89ba4 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs @@ -0,0 +1,20 @@ +namespace Unity.MLAgents.Sensors.Reflection +{ + /// + /// Sensor that wraps a Vector2 field or property of an object, and returns + /// that as an observation. + /// + internal class Vector2ReflectionSensor : ReflectionSensorBase + { + internal Vector2ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, 2) + {} + + internal override void WriteReflectedField(ObservationWriter writer) + { + var vecVal = (UnityEngine.Vector2)GetReflectedValue(); + writer[0] = vecVal.x; + writer[1] = vecVal.y; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs.meta new file mode 100644 index 0000000000..2b78c25ffe --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: da06ff33f6f2d409cbf240cffa2ba0be +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs new file mode 100644 index 0000000000..d7268084c8 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs @@ -0,0 +1,21 @@ +namespace Unity.MLAgents.Sensors.Reflection +{ + /// + /// Sensor that wraps a Vector3 field or property of an object, and returns + /// that as an observation. + /// + internal class Vector3ReflectionSensor : ReflectionSensorBase + { + internal Vector3ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, 3) + {} + + internal override void WriteReflectedField(ObservationWriter writer) + { + var vecVal = (UnityEngine.Vector3)GetReflectedValue(); + writer[0] = vecVal.x; + writer[1] = vecVal.y; + writer[2] = vecVal.z; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs.meta new file mode 100644 index 0000000000..771b690b07 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: e756976ec2a0943cfbc0f97a6550a85b +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs new file mode 100644 index 0000000000..4994d4dbbb --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs @@ -0,0 +1,22 @@ +namespace Unity.MLAgents.Sensors.Reflection +{ + /// + /// Sensor that wraps a Vector4 field or property of an object, and returns + /// that as an observation. + /// + internal class Vector4ReflectionSensor : ReflectionSensorBase + { + internal Vector4ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) + : base(reflectionSensorInfo, 4) + {} + + internal override void WriteReflectedField(ObservationWriter writer) + { + var vecVal = (UnityEngine.Vector4)GetReflectedValue(); + writer[0] = vecVal.x; + writer[1] = vecVal.y; + writer[2] = vecVal.z; + writer[3] = vecVal.w; + } + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs.meta new file mode 100644 index 0000000000..3d938af6c8 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 01d93aaa1b42b47b8960d303d7c498d3 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs.meta new file mode 100644 index 0000000000..73dd8b25ac --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: e5e4df2934c014aa3b835b9eb9ad20b3 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs b/com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs index c5434b842f..b19ff8f967 100644 --- a/com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs +++ b/com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs @@ -4,6 +4,7 @@ using UnityEngine; using Unity.MLAgents.Inference; using Unity.MLAgents.Policies; +using Unity.MLAgents.Sensors.Reflection; namespace Unity.MLAgents.Tests { @@ -19,18 +20,20 @@ public void SetUp() } } - static List GetFakeAgents() + static List GetFakeAgents(ObservableAttributeOptions observableAttributeOptions = ObservableAttributeOptions.Ignore) { var goA = new GameObject("goA"); var bpA = goA.AddComponent(); bpA.BrainParameters.VectorObservationSize = 3; bpA.BrainParameters.NumStackedVectorObservations = 1; + bpA.ObservableAttributeHandling = observableAttributeOptions; var agentA = goA.AddComponent(); var goB = new GameObject("goB"); var bpB = goB.AddComponent(); bpB.BrainParameters.VectorObservationSize = 3; bpB.BrainParameters.NumStackedVectorObservations = 1; + bpB.ObservableAttributeHandling = observableAttributeOptions; var agentB = goB.AddComponent(); var agents = new List { agentA, agentB }; @@ -100,15 +103,16 @@ public void GenerateVectorObservation() { var inputTensor = new TensorProxy { - shape = new long[] { 2, 3 } + shape = new long[] { 2, 4 } }; const int batchSize = 4; - var agentInfos = GetFakeAgents(); + var agentInfos = GetFakeAgents(ObservableAttributeOptions.ExamineAll); var alloc = new TensorCachingAllocator(); var generator = new VectorObservationGenerator(alloc); - generator.AddSensorIndex(0); - generator.AddSensorIndex(1); - generator.AddSensorIndex(2); + generator.AddSensorIndex(0); // ObservableAttribute (size 1) + generator.AddSensorIndex(1); // TestSensor (size 0) + generator.AddSensorIndex(2); // TestSensor (size 0) + generator.AddSensorIndex(3); // VectorSensor (size 3) var agent0 = agentInfos[0]; var agent1 = agentInfos[1]; var inputs = new List @@ -118,10 +122,10 @@ public void GenerateVectorObservation() }; generator.Generate(inputTensor, batchSize, inputs); Assert.IsNotNull(inputTensor.data); - Assert.AreEqual(inputTensor.data[0, 0], 1); - Assert.AreEqual(inputTensor.data[0, 2], 3); - Assert.AreEqual(inputTensor.data[1, 0], 4); - Assert.AreEqual(inputTensor.data[1, 2], 6); + Assert.AreEqual(inputTensor.data[0, 1], 1); + Assert.AreEqual(inputTensor.data[0, 3], 3); + Assert.AreEqual(inputTensor.data[1, 1], 4); + Assert.AreEqual(inputTensor.data[1, 3], 6); alloc.Dispose(); } diff --git a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs index 3b3e6ca1d4..5dcaf0dd53 100644 --- a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs +++ b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs @@ -4,6 +4,7 @@ using System.Reflection; using System.Collections.Generic; using Unity.MLAgents.Sensors; +using Unity.MLAgents.Sensors.Reflection; using Unity.MLAgents.Policies; using Unity.MLAgents.SideChannels; @@ -61,6 +62,9 @@ internal IPolicy GetPolicy() public TestSensor sensor1; public TestSensor sensor2; + [Observable("observableFloat")] + public float observableFloat; + public override void Initialize() { initializeAgentCalls += 1; @@ -246,6 +250,9 @@ public void TestAgent() var agentGo1 = new GameObject("TestAgent"); agentGo1.AddComponent(); var agent1 = agentGo1.GetComponent(); + var bp1 = agentGo1.GetComponent(); + bp1.ObservableAttributeHandling = ObservableAttributeOptions.ExcludeInherited; + var agentGo2 = new GameObject("TestAgent"); agentGo2.AddComponent(); var agent2 = agentGo2.GetComponent(); @@ -271,8 +278,13 @@ public void TestAgent() Assert.AreEqual(0, agent2.agentActionCalls); // Make sure the Sensors were sorted - Assert.AreEqual(agent1.sensors[0].GetName(), "testsensor1"); - Assert.AreEqual(agent1.sensors[1].GetName(), "testsensor2"); + Assert.AreEqual(agent1.sensors[0].GetName(), "observableFloat"); + Assert.AreEqual(agent1.sensors[1].GetName(), "testsensor1"); + Assert.AreEqual(agent1.sensors[2].GetName(), "testsensor2"); + + // agent2 should only have two sensors (no observableFloat) + Assert.AreEqual(agent2.sensors[0].GetName(), "testsensor1"); + Assert.AreEqual(agent2.sensors[1].GetName(), "testsensor2"); } } @@ -741,4 +753,53 @@ public void TestAgentDontCallBaseOnEnable() _InnerAgentTestOnEnableOverride(); } } + + [TestFixture] + public class ObservableAttributeBehaviorTests + { + public class BaseObservableAgent : Agent + { + [Observable] + public float BaseField; + } + + public class DerivedObservableAgent : BaseObservableAgent + { + [Observable] + public float DerivedField; + } + + + [Test] + public void TestObservableAttributeBehaviorIgnore() + { + var variants = new[] + { + // No observables found + (ObservableAttributeOptions.Ignore, 0), + // Only DerivedField found + (ObservableAttributeOptions.ExcludeInherited, 1), + // DerivedField and BaseField found + (ObservableAttributeOptions.ExamineAll, 2) + }; + + foreach (var(behavior, expectedNumSensors) in variants) + { + var go = new GameObject(); + var agent = go.AddComponent(); + var bp = go.GetComponent(); + bp.ObservableAttributeHandling = behavior; + agent.LazyInitialize(); + int numAttributeSensors = 0; + foreach (var sensor in agent.sensors) + { + if (sensor.GetType() != typeof(VectorSensor)) + { + numAttributeSensors++; + } + } + Assert.AreEqual(expectedNumSensors, numAttributeSensors); + } + } + } } diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs new file mode 100644 index 0000000000..b7afb08493 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs @@ -0,0 +1,292 @@ +using System; +using System.Collections.Generic; +using NUnit.Framework; +using UnityEngine; +using Unity.MLAgents.Sensors; +using Unity.MLAgents.Sensors.Reflection; + +namespace Unity.MLAgents.Tests +{ + [TestFixture] + public class ObservableAttributeTests + { + class TestClass + { + // Non-observables + int m_NonObservableInt; + float m_NonObservableFloat; + + // + // Int + // + [Observable] + public int m_IntMember; + + int m_IntProperty; + + [Observable] + public int IntProperty + { + get => m_IntProperty; + set => m_IntProperty = value; + } + + // + // Float + // + [Observable("floatMember")] + public float m_FloatMember; + + float m_FloatProperty; + [Observable("floatProperty")] + public float FloatProperty + { + get => m_FloatProperty; + set => m_FloatProperty = value; + } + + // + // Bool + // + [Observable("boolMember")] + public bool m_BoolMember; + + bool m_BoolProperty; + [Observable("boolProperty")] + public bool BoolProperty + { + get => m_BoolProperty; + set => m_BoolProperty = value; + } + + // + // Vector2 + // + + [Observable("vector2Member")] + public Vector2 m_Vector2Member; + + Vector2 m_Vector2Property; + + [Observable("vector2Property")] + public Vector2 Vector2Property + { + get => m_Vector2Property; + set => m_Vector2Property = value; + } + + // + // Vector3 + // + [Observable("vector3Member")] + public Vector3 m_Vector3Member; + + Vector3 m_Vector3Property; + + [Observable("vector3Property")] + public Vector3 Vector3Property + { + get => m_Vector3Property; + set => m_Vector3Property = value; + } + + // + // Vector4 + // + + [Observable("vector4Member")] + public Vector4 m_Vector4Member; + + Vector4 m_Vector4Property; + + [Observable("vector4Property")] + public Vector4 Vector4Property + { + get => m_Vector4Property; + set => m_Vector4Property = value; + } + + // + // Quaternion + // + [Observable("quaternionMember")] + public Quaternion m_QuaternionMember; + + Quaternion m_QuaternionProperty; + + [Observable("quaternionProperty")] + public Quaternion QuaternionProperty + { + get => m_QuaternionProperty; + set => m_QuaternionProperty = value; + } + } + + [Test] + public void TestGetObservableSensors() + { + var testClass = new TestClass(); + testClass.m_IntMember = 1; + testClass.IntProperty = 2; + + testClass.m_FloatMember = 1.1f; + testClass.FloatProperty = 1.2f; + + testClass.m_BoolMember = true; + testClass.BoolProperty = true; + + testClass.m_Vector2Member = new Vector2(2.0f, 2.1f); + testClass.Vector2Property = new Vector2(2.2f, 2.3f); + + testClass.m_Vector3Member = new Vector3(3.0f, 3.1f, 3.2f); + testClass.Vector3Property = new Vector3(3.3f, 3.4f, 3.5f); + + testClass.m_Vector4Member = new Vector4(4.0f, 4.1f, 4.2f, 4.3f); + testClass.Vector4Property = new Vector4(4.4f, 4.5f, 4.5f, 4.7f); + + testClass.m_Vector4Member = new Vector4(4.0f, 4.1f, 4.2f, 4.3f); + testClass.Vector4Property = new Vector4(4.4f, 4.5f, 4.5f, 4.7f); + + testClass.m_QuaternionMember = new Quaternion(5.0f, 5.1f, 5.2f, 5.3f); + testClass.QuaternionProperty = new Quaternion(5.4f, 5.5f, 5.5f, 5.7f); + + var sensors = ObservableAttribute.CreateObservableSensors(testClass, false); + + var sensorsByName = new Dictionary(); + foreach (var sensor in sensors) + { + sensorsByName[sensor.GetName()] = sensor; + } + + SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.m_IntMember"], new[] { 1.0f }); + SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.IntProperty"], new[] { 2.0f }); + + SensorTestHelper.CompareObservation(sensorsByName["floatMember"], new[] { 1.1f }); + SensorTestHelper.CompareObservation(sensorsByName["floatProperty"], new[] { 1.2f }); + + SensorTestHelper.CompareObservation(sensorsByName["boolMember"], new[] { 1.0f }); + SensorTestHelper.CompareObservation(sensorsByName["boolProperty"], new[] { 1.0f }); + + SensorTestHelper.CompareObservation(sensorsByName["vector2Member"], new[] { 2.0f, 2.1f }); + SensorTestHelper.CompareObservation(sensorsByName["vector2Property"], new[] { 2.2f, 2.3f }); + + SensorTestHelper.CompareObservation(sensorsByName["vector3Member"], new[] { 3.0f, 3.1f, 3.2f }); + SensorTestHelper.CompareObservation(sensorsByName["vector3Property"], new[] { 3.3f, 3.4f, 3.5f }); + + SensorTestHelper.CompareObservation(sensorsByName["vector4Member"], new[] { 4.0f, 4.1f, 4.2f, 4.3f }); + SensorTestHelper.CompareObservation(sensorsByName["vector4Property"], new[] { 4.4f, 4.5f, 4.5f, 4.7f }); + + SensorTestHelper.CompareObservation(sensorsByName["quaternionMember"], new[] { 5.0f, 5.1f, 5.2f, 5.3f }); + SensorTestHelper.CompareObservation(sensorsByName["quaternionProperty"], new[] { 5.4f, 5.5f, 5.5f, 5.7f }); + } + + [Test] + public void TestGetTotalObservationSize() + { + var testClass = new TestClass(); + var errors = new List(); + var expectedObsSize = 2 * (1 + 1 + 1 + 2 + 3 + 4 + 4); + Assert.AreEqual(expectedObsSize, ObservableAttribute.GetTotalObservationSize(testClass, false, errors)); + Assert.AreEqual(0, errors.Count); + } + + class BadClass + { + [Observable] + double m_Double; + + [Observable] + double DoubleProperty + { + get => m_Double; + set => m_Double = value; + } + + float m_WriteOnlyProperty; + + [Observable] + // No get property, so we shouldn't be able to make a sensor out of this. + public float WriteOnlyProperty + { + set => m_WriteOnlyProperty = value; + } + } + + [Test] + public void TestInvalidObservables() + { + var bad = new BadClass(); + bad.WriteOnlyProperty = 1.0f; + var errors = new List(); + Assert.AreEqual(0, ObservableAttribute.GetTotalObservationSize(bad, false, errors)); + Assert.AreEqual(3, errors.Count); + + // Should be able to safely generate sensors (and get nothing back) + var sensors = ObservableAttribute.CreateObservableSensors(bad, false); + Assert.AreEqual(0, sensors.Count); + } + + class StackingClass + { + [Observable(numStackedObservations: 2)] + public float FloatVal; + } + + [Test] + public void TestObservableAttributeStacking() + { + var c = new StackingClass(); + c.FloatVal = 1.0f; + var sensors = ObservableAttribute.CreateObservableSensors(c, false); + var sensor = sensors[0]; + Assert.AreEqual(typeof(StackingSensor), sensor.GetType()); + SensorTestHelper.CompareObservation(sensor, new[] { 0.0f, 1.0f }); + + sensor.Update(); + c.FloatVal = 3.0f; + SensorTestHelper.CompareObservation(sensor, new[] { 1.0f, 3.0f }); + + var errors = new List(); + Assert.AreEqual(2, ObservableAttribute.GetTotalObservationSize(c, false, errors)); + Assert.AreEqual(0, errors.Count); + } + + class BaseClass + { + [Observable("base")] + public float m_BaseField; + + [Observable("private")] + float m_PrivateField; + } + + class DerivedClass : BaseClass + { + [Observable("derived")] + float m_DerivedField; + } + + [Test] + public void TestObservableAttributeExcludeInherited() + { + var d = new DerivedClass(); + d.m_BaseField = 1.0f; + + // excludeInherited=false will get fields in the derived class, plus public and protected inherited fields + var sensorAll = ObservableAttribute.CreateObservableSensors(d, false); + Assert.AreEqual(2, sensorAll.Count); + // Note - actual order doesn't matter here, we can change this to use a HashSet if neeed. + Assert.AreEqual("derived", sensorAll[0].GetName()); + Assert.AreEqual("base", sensorAll[1].GetName()); + + // excludeInherited=true will only get fields in the derived class + var sensorsDerivedOnly = ObservableAttribute.CreateObservableSensors(d, true); + Assert.AreEqual(1, sensorsDerivedOnly.Count); + Assert.AreEqual("derived", sensorsDerivedOnly[0].GetName()); + + var b = new BaseClass(); + var baseSensors = ObservableAttribute.CreateObservableSensors(b, false); + Assert.AreEqual(2, baseSensors.Count); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs.meta new file mode 100644 index 0000000000..611fdcfa12 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 33d7912e6b3504412bd261b40e46df32 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs b/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs index 535fbfb7b3..19a5e18bae 100644 --- a/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs +++ b/com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs @@ -1,20 +1,23 @@ -#if UNITY_INCLUDE_TESTS +#if UNITY_INCLUDE_TESTS using System.Collections; using System.Collections.Generic; using Unity.MLAgents; using Unity.MLAgents.Policies; using Unity.MLAgents.Sensors; +using Unity.MLAgents.Sensors.Reflection; using NUnit.Framework; using UnityEngine; using UnityEngine.TestTools; namespace Tests { - public class PublicApiAgent : Agent { public int numHeuristicCalls; + [Observable] + public float ObservableFloat; + public override void Heuristic(float[] actionsOut) { numHeuristicCalls++; @@ -36,7 +39,7 @@ public override ISensor CreateSensor() public override int[] GetObservationShape() { - int[] shape = (int[]) wrappedComponent.GetObservationShape().Clone(); + int[] shape = (int[])wrappedComponent.GetObservationShape().Clone(); for (var i = 0; i < shape.Length; i++) { shape[i] *= numStacks; @@ -69,6 +72,7 @@ public IEnumerator RuntimeApiTestWithEnumeratorPasses() behaviorParams.BehaviorName = "TestBehavior"; behaviorParams.TeamId = 42; behaviorParams.UseChildSensors = true; + behaviorParams.ObservableAttributeHandling = ObservableAttributeOptions.ExamineAll; // Can't actually create an Agent with InferenceOnly and no model, so change back