Skip to content

ObservableAttribute #3925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
May 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to
- `get_behavior_names()` and `get_behavior_spec()` on UnityEnvironment were replaced by the `behavior_specs` property. (#3946)
### Minor Changes
#### com.unity.ml-agents (C#)
- `ObservableAttribute` was added. Adding the attribute to fields or properties on an Agent will allow it to generate
observations via reflection.
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Curriculum and Parameter Randomization configurations have been merged
into the main training configuration file. Note that this means training
Expand Down
23 changes: 22 additions & 1 deletion com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using System.Collections.Generic;
using UnityEditor;
using Unity.Barracuda;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
using UnityEngine;

namespace Unity.MLAgents.Editor
Expand Down Expand Up @@ -60,6 +62,7 @@ public override void OnInspectorGUI()
EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
{
EditorGUILayout.PropertyField(so.FindProperty("m_UseChildSensors"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservableAttributeHandling"), true);
}
EditorGUI.EndDisabledGroup();

Expand Down Expand Up @@ -89,6 +92,8 @@ void DisplayFailedModelChecks()
Model barracudaModel = null;
var model = (NNModel)serializedObject.FindProperty("m_Model").objectReferenceValue;
var behaviorParameters = (BehaviorParameters)target;

// Grab the sensor components, since we need them to determine the observation sizes.
SensorComponent[] sensorComponents;
if (behaviorParameters.UseChildSensors)
{
Expand All @@ -98,6 +103,21 @@ void DisplayFailedModelChecks()
{
sensorComponents = behaviorParameters.GetComponents<SensorComponent>();
}

// Get the total size of the sensors generated by ObservableAttributes.
// If there are any errors (e.g. unsupported type, write-only properties), display them too.
int observableAttributeSensorTotalSize = 0;
var agent = behaviorParameters.GetComponent<Agent>();
if (agent != null && behaviorParameters.ObservableAttributeHandling != ObservableAttributeOptions.Ignore)
{
List<string> observableErrors = new List<string>();
observableAttributeSensorTotalSize = ObservableAttribute.GetTotalObservationSize(agent, false, observableErrors);
foreach (var check in observableErrors)
{
EditorGUILayout.HelpBox(check, MessageType.Warning);
}
}

var brainParameters = behaviorParameters.BrainParameters;
if (model != null)
{
Expand All @@ -106,7 +126,8 @@ void DisplayFailedModelChecks()
if (brainParameters != null)
{
var failedChecks = Inference.BarracudaModelParamLoader.CheckModel(
barracudaModel, brainParameters, sensorComponents, behaviorParameters.BehaviorType
barracudaModel, brainParameters, sensorComponents,
observableAttributeSensorTotalSize, behaviorParameters.BehaviorType
);
foreach (var check in failedChecks)
{
Expand Down
18 changes: 17 additions & 1 deletion com.unity.ml-agents/Runtime/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using UnityEngine;
using Unity.Barracuda;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
using Unity.MLAgents.Demonstrations;
using Unity.MLAgents.Policies;
using UnityEngine.Serialization;
Expand Down Expand Up @@ -395,7 +396,11 @@ public void LazyInitialize()
m_Brain = m_PolicyFactory.GeneratePolicy(Heuristic);
ResetData();
Initialize();
InitializeSensors();

using (TimerStack.Instance.Scoped("InitializeSensors"))
{
InitializeSensors();
}

// The first time the Academy resets, all Agents in the scene will be
// forced to reset through the <see cref="AgentForceReset"/> event.
Expand Down Expand Up @@ -816,6 +821,17 @@ public virtual void Heuristic(float[] actionsOut)
/// </summary>
internal void InitializeSensors()
{
if (m_PolicyFactory.ObservableAttributeHandling != ObservableAttributeOptions.Ignore)
{
var excludeInherited =
m_PolicyFactory.ObservableAttributeHandling == ObservableAttributeOptions.ExcludeInherited;
using (TimerStack.Instance.Scoped("CreateObservableSensors"))
{
var observableSensors = ObservableAttribute.CreateObservableSensors(this, excludeInherited);
sensors.AddRange(observableSensors);
}
}

// Get all attached sensor components
SensorComponent[] attachedSensorComponents;
if (m_PolicyFactory.UseChildSensors)
Expand Down
43 changes: 28 additions & 15 deletions com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,12 @@ public static string[] GetOutputNames(Model model)
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="sensorComponents">Attached sensor components</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
/// <param name="behaviorType">BehaviorType or the Agent to check.</param>
/// <returns>The list the error messages of the checks that failed</returns>
public static IEnumerable<string> CheckModel(Model model, BrainParameters brainParameters,
SensorComponent[] sensorComponents, BehaviorType behaviorType = BehaviorType.Default)
SensorComponent[] sensorComponents, int observableAttributeTotalSize = 0,
BehaviorType behaviorType = BehaviorType.Default)
{
List<string> failedModelChecks = new List<string>();
if (model == null)
Expand Down Expand Up @@ -182,7 +184,7 @@ public static IEnumerable<string> CheckModel(Model model, BrainParameters brainP
CheckOutputTensorPresence(model, memorySize))
;
failedModelChecks.AddRange(
CheckInputTensorShape(model, brainParameters, sensorComponents)
CheckInputTensorShape(model, brainParameters, sensorComponents, observableAttributeTotalSize)
);
failedModelChecks.AddRange(
CheckOutputTensorShape(model, brainParameters, isContinuous, actionSize)
Expand Down Expand Up @@ -253,6 +255,7 @@ static IEnumerable<string> CheckIntScalarPresenceHelper(
/// Whether the model is expecting continuous or discrete control.
/// </param>
/// <param name="sensorComponents">Array of attached sensor components</param>
/// <param name="observableAttributeTotalSize">Total size of ObservableAttributes</param>
/// <returns>
/// A IEnumerable of string corresponding to the failed input presence checks.
/// </returns>
Expand Down Expand Up @@ -404,25 +407,27 @@ static string CheckVisualObsShape(
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="sensorComponents">Attached sensors</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
/// <returns>The list the error messages of the checks that failed</returns>
static IEnumerable<string> CheckInputTensorShape(
Model model, BrainParameters brainParameters, SensorComponent[] sensorComponents)
Model model, BrainParameters brainParameters, SensorComponent[] sensorComponents,
int observableAttributeTotalSize)
{
var failedModelChecks = new List<string>();
var tensorTester =
new Dictionary<string, Func<BrainParameters, TensorProxy, SensorComponent[], string>>()
new Dictionary<string, Func<BrainParameters, TensorProxy, SensorComponent[], int, string>>()
{
{TensorNames.VectorObservationPlaceholder, CheckVectorObsShape},
{TensorNames.PreviousActionPlaceholder, CheckPreviousActionShape},
{TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs) => null)},
{TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs) => null)},
{TensorNames.SequenceLengthPlaceholder, ((bp, tensor, scs) => null)},
{TensorNames.RecurrentInPlaceholder, ((bp, tensor, scs) => null)},
{TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs, i) => null)},
{TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs, i) => null)},
{TensorNames.SequenceLengthPlaceholder, ((bp, tensor, scs, i) => null)},
{TensorNames.RecurrentInPlaceholder, ((bp, tensor, scs, i) => null)},
};

foreach (var mem in model.memories)
{
tensorTester[mem.input] = ((bp, tensor, scs) => null);
tensorTester[mem.input] = ((bp, tensor, scs, i) => null);
}

var visObsIndex = 0;
Expand All @@ -434,7 +439,7 @@ static IEnumerable<string> CheckInputTensorShape(
continue;
}
tensorTester[TensorNames.VisualObservationPlaceholderPrefix + visObsIndex] =
(bp, tensor, scs) => CheckVisualObsShape(tensor, sensorComponent);
(bp, tensor, scs, i) => CheckVisualObsShape(tensor, sensorComponent);
visObsIndex++;
}

Expand All @@ -452,7 +457,7 @@ static IEnumerable<string> CheckInputTensorShape(
else
{
var tester = tensorTester[tensor.name];
var error = tester.Invoke(brainParameters, tensor, sensorComponents);
var error = tester.Invoke(brainParameters, tensor, sensorComponents, observableAttributeTotalSize);
if (error != null)
{
failedModelChecks.Add(error);
Expand All @@ -471,12 +476,14 @@ static IEnumerable<string> CheckInputTensorShape(
/// </param>
/// <param name="tensorProxy">The tensor that is expected by the model</param>
/// <param name="sensorComponents">Array of attached sensor components</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
/// <returns>
/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static string CheckVectorObsShape(
BrainParameters brainParameters, TensorProxy tensorProxy, SensorComponent[] sensorComponents)
BrainParameters brainParameters, TensorProxy tensorProxy, SensorComponent[] sensorComponents,
int observableAttributeTotalSize)
{
var vecObsSizeBp = brainParameters.VectorObservationSize;
var numStackedVector = brainParameters.NumStackedVectorObservations;
Expand All @@ -491,6 +498,8 @@ static string CheckVectorObsShape(
}
}

totalVectorSensorSize += observableAttributeTotalSize;

if (vecObsSizeBp * numStackedVector + totalVectorSensorSize != totalVecObsSizeT)
{
var sensorSizes = "";
Expand All @@ -512,7 +521,9 @@ static string CheckVectorObsShape(

sensorSizes += "]";
return $"Vector Observation Size of the model does not match. Was expecting {totalVecObsSizeT} " +
$"but received {vecObsSizeBp} x {numStackedVector} vector observations and " +
$"but received: \n" +
$"Vector observations: {vecObsSizeBp} x {numStackedVector}\n" +
$"Total [Observable] attributes: {observableAttributeTotalSize}\n" +
$"SensorComponent sizes: {sensorSizes}.";
}
return null;
Expand All @@ -526,11 +537,13 @@ static string CheckVectorObsShape(
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="tensorProxy"> The tensor that is expected by the model</param>
/// <param name="sensorComponents">Array of attached sensor components</param>
/// <param name="sensorComponents">Array of attached sensor components (unused).</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes (unused).</param>
/// <returns>If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.</returns>
static string CheckPreviousActionShape(
BrainParameters brainParameters, TensorProxy tensorProxy, SensorComponent[] sensorComponents)
BrainParameters brainParameters, TensorProxy tensorProxy,
SensorComponent[] sensorComponents, int observableAttributeTotalSize)
{
var numberActionsBp = brainParameters.VectorActionSize.Length;
var numberActionsT = tensorProxy.shape[tensorProxy.shape.Length - 1];
Expand Down
43 changes: 43 additions & 0 deletions com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System;
using UnityEngine;
using UnityEngine.Serialization;
using Unity.MLAgents.Sensors.Reflection;

namespace Unity.MLAgents.Policies
{
Expand Down Expand Up @@ -30,6 +31,36 @@ public enum BehaviorType
InferenceOnly
}

/// <summary>
/// Options for controlling how the Agent class is searched for <see cref="ObservableAttribute"/>s.
/// </summary>
public enum ObservableAttributeOptions
{
/// <summary>
/// All ObservableAttributes on the Agent will be ignored. If there are no
/// ObservableAttributes on the Agent, this will result in the fastest
/// initialization time.
/// </summary>
Ignore,

/// <summary>
/// Only members on the declared class will be examined; members that are
/// inherited are ignored. This is the default behavior, and a reasonable
/// tradeoff between performance and flexibility.
/// </summary>
/// <remarks>This corresponds to setting the
/// [BindingFlags.DeclaredOnly](https://docs.microsoft.com/en-us/dotnet/api/system.reflection.bindingflags?view=netcore-3.1)
/// when examining the fields and properties of the Agent class instance.
/// </remarks>
ExcludeInherited,

/// <summary>
/// All members on the class will be examined. This can lead to slower
/// startup times
/// </summary>
ExamineAll
}

/// <summary>
/// A component for setting an <seealso cref="Agent"/> instance's behavior and
/// brain properties.
Expand Down Expand Up @@ -129,6 +160,18 @@ public bool UseChildSensors
set { m_UseChildSensors = value; }
}

[HideInInspector, SerializeField]
ObservableAttributeOptions m_ObservableAttributeHandling = ObservableAttributeOptions.Ignore;

/// <summary>
/// Determines how the Agent class is searched for <see cref="ObservableAttribute"/>s.
/// </summary>
public ObservableAttributeOptions ObservableAttributeHandling
{
get { return m_ObservableAttributeHandling; }
set { m_ObservableAttributeHandling = value; }
}

/// <summary>
/// Returns the behavior name, concatenated with any other metadata (i.e. team id).
/// </summary>
Expand Down
6 changes: 2 additions & 4 deletions com.unity.ml-agents/Runtime/Policies/BrainParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,9 @@ public enum SpaceType
public class BrainParameters
{
/// <summary>
/// The size of the observation space.
/// </summary>
/// <remarks>An agent creates the observation vector in its
/// The number of the observations that are added in
/// <see cref="Agent.CollectObservations(Sensors.VectorSensor)"/>
/// implementation.</remarks>
/// </summary>
/// <value>
/// The length of the vector containing observation values.
/// </value>
Expand Down
8 changes: 8 additions & 0 deletions com.unity.ml-agents/Runtime/Sensors/Reflection.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
namespace Unity.MLAgents.Sensors.Reflection
{
/// <summary>
/// Sensor that wraps a boolean field or property of an object, and returns
/// that as an observation.
/// </summary>
internal class BoolReflectionSensor : ReflectionSensorBase
{
internal BoolReflectionSensor(ReflectionSensorInfo reflectionSensorInfo)
: base(reflectionSensorInfo, 1)
{}

internal override void WriteReflectedField(ObservationWriter writer)
{
var boolVal = (System.Boolean)GetReflectedValue();
writer[0] = boolVal ? 1.0f : 0.0f;
}
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
namespace Unity.MLAgents.Sensors.Reflection
{
/// <summary>
/// Sensor that wraps a float field or property of an object, and returns
/// that as an observation.
/// </summary>
internal class FloatReflectionSensor : ReflectionSensorBase
{
internal FloatReflectionSensor(ReflectionSensorInfo reflectionSensorInfo)
: base(reflectionSensorInfo, 1)
{}

internal override void WriteReflectedField(ObservationWriter writer)
{
var floatVal = (System.Single)GetReflectedValue();
writer[0] = floatVal;
}
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading