diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md
index 8a34b187ea..a318fabce3 100755
--- a/com.unity.ml-agents/CHANGELOG.md
+++ b/com.unity.ml-agents/CHANGELOG.md
@@ -13,11 +13,19 @@ and this project adheres to
- Added the Random Network Distillation (RND) intrinsic reward signal to the Pytorch
trainers. To use RND, add a `rnd` section to the `reward_signals` section of your
yaml configuration file. [More information here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-Configuration-File.md#rnd-intrinsic-reward)
-
### Minor Changes
#### com.unity.ml-agents (C#)
+ - Stacking for compressed observations is now supported. An additional setting
+ option `Observation Stacks` is added in editor to sensor components that support
+ compressed observations. A new interface `ISparseChannelSensor` with an
+ additional method `GetCompressedChannelMapping()` is added to generate a mapping
+ of the channels in compressed data to the actual channel after decompression,
+ for the python side to decompress correctly. (#4476)
#### ml-agents / ml-agents-envs / gym-unity (Python)
-
+ - The Communication API was changed to 1.2.0 to indicate support for stacked
+ compressed observations. A new entry `compressed_channel_mapping` is added to the
+ proto to handle decompression correctly. Newer versions of the package that wish to
+ make use of this will also need a compatible version of the Python trainers. (#4476)
### Bug Fixes
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
diff --git a/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs b/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs
index 99904e4363..baf9873324 100644
--- a/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs
+++ b/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs
@@ -25,6 +25,7 @@ public override void OnInspectorGUI()
EditorGUILayout.PropertyField(so.FindProperty("m_Width"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_Height"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_Grayscale"), true);
+ EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true);
}
EditorGUI.EndDisabledGroup();
EditorGUILayout.PropertyField(so.FindProperty("m_Compression"), true);
diff --git a/com.unity.ml-agents/Editor/RenderTextureSensorComponentEditor.cs b/com.unity.ml-agents/Editor/RenderTextureSensorComponentEditor.cs
index 55c621a9c4..10a7769d8d 100644
--- a/com.unity.ml-agents/Editor/RenderTextureSensorComponentEditor.cs
+++ b/com.unity.ml-agents/Editor/RenderTextureSensorComponentEditor.cs
@@ -20,6 +20,7 @@ public override void OnInspectorGUI()
EditorGUILayout.PropertyField(so.FindProperty("m_RenderTexture"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_SensorName"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_Grayscale"), true);
+ EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true);
}
EditorGUI.EndDisabledGroup();
diff --git a/com.unity.ml-agents/Runtime/Academy.cs b/com.unity.ml-agents/Runtime/Academy.cs
index fac6934c22..383526b2e8 100644
--- a/com.unity.ml-agents/Runtime/Academy.cs
+++ b/com.unity.ml-agents/Runtime/Academy.cs
@@ -74,9 +74,13 @@ public class Academy : IDisposable
/// 1.1.0
/// Support concatenated PNGs for compressed observations.
///
+ /// -
+ /// 1.2.0
+ /// Support compression mapping for stacked compressed observations.
+ ///
///
///
- const string k_ApiVersion = "1.1.0";
+ const string k_ApiVersion = "1.2.0";
///
/// Unity package version of com.unity.ml-agents.
diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
index 131685f8d9..5f79b5f068 100644
--- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
+++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
@@ -222,7 +222,8 @@ public static List ToAgentActionList(this UnityRLInputProto.Types.ListA
///
/// Static flag to make sure that we only fire the warning once.
///
- private static bool s_HaveWarnedAboutTrainerCapabilities = false;
+ private static bool s_HaveWarnedTrainerCapabilitiesMultiPng = false;
+ private static bool s_HaveWarnedTrainerCapabilitiesMapping = false;
///
/// Generate an ObservationProto for the sensor using the provided ObservationWriter.
@@ -243,10 +244,27 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat
var trainerCanHandle = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.ConcatenatedPngObservations;
if (!trainerCanHandle)
{
- if (!s_HaveWarnedAboutTrainerCapabilities)
+ if (!s_HaveWarnedTrainerCapabilitiesMultiPng)
{
Debug.LogWarning($"Attached trainer doesn't support multiple PNGs. Switching to uncompressed observations for sensor {sensor.GetName()}.");
- s_HaveWarnedAboutTrainerCapabilities = true;
+ s_HaveWarnedTrainerCapabilitiesMultiPng = true;
+ }
+ compressionType = SensorCompressionType.None;
+ }
+ }
+ // Check capabilities if we need mapping for compressed observations
+ if (compressionType != SensorCompressionType.None && shape.Length == 3 && shape[2] > 3)
+ {
+ var trainerCanHandleMapping = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.CompressedChannelMapping;
+ var isTrivialMapping = IsTrivialMapping(sensor);
+ if (!trainerCanHandleMapping && !isTrivialMapping)
+ {
+ if (!s_HaveWarnedTrainerCapabilitiesMapping)
+ {
+ Debug.LogWarning($"The sensor {sensor.GetName()} is using non-trivial mapping and " +
+ "the attached trainer doesn't support compression mapping. " +
+ "Switching to uncompressed observations.");
+ s_HaveWarnedTrainerCapabilitiesMapping = true;
}
compressionType = SensorCompressionType.None;
}
@@ -283,12 +301,16 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat
"return SensorCompressionType.None from GetCompressionType()."
);
}
-
observationProto = new ObservationProto
{
CompressedData = ByteString.CopyFrom(compressedObs),
CompressionType = (CompressionTypeProto)sensor.GetCompressionType(),
};
+ var compressibleSensor = sensor as ISparseChannelSensor;
+ if (compressibleSensor != null)
+ {
+ observationProto.CompressedChannelMapping.AddRange(compressibleSensor.GetCompressedChannelMapping());
+ }
}
observationProto.Shape.AddRange(shape);
return observationProto;
@@ -300,7 +322,8 @@ public static UnityRLCapabilities ToRLCapabilities(this UnityRLCapabilitiesProto
return new UnityRLCapabilities
{
BaseRLCapabilities = proto.BaseRLCapabilities,
- ConcatenatedPngObservations = proto.ConcatenatedPngObservations
+ ConcatenatedPngObservations = proto.ConcatenatedPngObservations,
+ CompressedChannelMapping = proto.CompressedChannelMapping,
};
}
@@ -310,7 +333,36 @@ public static UnityRLCapabilitiesProto ToProto(this UnityRLCapabilities rlCaps)
{
BaseRLCapabilities = rlCaps.BaseRLCapabilities,
ConcatenatedPngObservations = rlCaps.ConcatenatedPngObservations,
+ CompressedChannelMapping = rlCaps.CompressedChannelMapping,
};
}
+
+ internal static bool IsTrivialMapping(ISensor sensor)
+ {
+ var compressibleSensor = sensor as ISparseChannelSensor;
+ if (compressibleSensor is null)
+ {
+ return true;
+ }
+ var mapping = compressibleSensor.GetCompressedChannelMapping();
+ if (mapping == null)
+ {
+ return true;
+ }
+ // check if mapping equals zero mapping
+ if (mapping.Length == 3 && mapping.All(m => m == 0))
+ {
+ return true;
+ }
+ // check if mapping equals identity mapping
+ for (var i = 0; i < mapping.Length; i++)
+ {
+ if (mapping[i] != i)
+ {
+ return false;
+ }
+ }
+ return true;
+ }
}
}
diff --git a/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs b/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
index 7e188af421..30a72c19e3 100644
--- a/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
+++ b/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
@@ -6,15 +6,17 @@ internal class UnityRLCapabilities
{
public bool BaseRLCapabilities;
public bool ConcatenatedPngObservations;
+ public bool CompressedChannelMapping;
///
/// A class holding the capabilities flags for Reinforcement Learning across C# and the Trainer codebase. This
/// struct will be used to inform users if and when they are using C# / Trainer features that are mismatched.
///
- public UnityRLCapabilities(bool baseRlCapabilities = true, bool concatenatedPngObservations = true)
+ public UnityRLCapabilities(bool baseRlCapabilities = true, bool concatenatedPngObservations = true, bool compressedChannelMapping = true)
{
BaseRLCapabilities = baseRlCapabilities;
ConcatenatedPngObservations = concatenatedPngObservations;
+ CompressedChannelMapping = compressedChannelMapping;
}
///
diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs
index c30d7ec6f9..57050424e7 100644
--- a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs
+++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs
@@ -25,14 +25,15 @@ static CapabilitiesReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp",
- "dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMiWwoYVW5pdHlSTENh",
+ "dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMifQoYVW5pdHlSTENh",
"cGFiaWxpdGllc1Byb3RvEhoKEmJhc2VSTENhcGFiaWxpdGllcxgBIAEoCBIj",
- "Chtjb25jYXRlbmF0ZWRQbmdPYnNlcnZhdGlvbnMYAiABKAhCJaoCIlVuaXR5",
- "Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
+ "Chtjb25jYXRlbmF0ZWRQbmdPYnNlcnZhdGlvbnMYAiABKAgSIAoYY29tcHJl",
+ "c3NlZENoYW5uZWxNYXBwaW5nGAMgASgIQiWqAiJVbml0eS5NTEFnZW50cy5D",
+ "b21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
- new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations" }, null, null, null)
+ new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping" }, null, null, null)
}));
}
#endregion
@@ -71,6 +72,7 @@ public UnityRLCapabilitiesProto() {
public UnityRLCapabilitiesProto(UnityRLCapabilitiesProto other) : this() {
baseRLCapabilities_ = other.baseRLCapabilities_;
concatenatedPngObservations_ = other.concatenatedPngObservations_;
+ compressedChannelMapping_ = other.compressedChannelMapping_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}
@@ -107,6 +109,20 @@ public bool ConcatenatedPngObservations {
}
}
+ /// Field number for the "compressedChannelMapping" field.
+ public const int CompressedChannelMappingFieldNumber = 3;
+ private bool compressedChannelMapping_;
+ ///
+ /// compression mapping for stacking compressed observations.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool CompressedChannelMapping {
+ get { return compressedChannelMapping_; }
+ set {
+ compressedChannelMapping_ = value;
+ }
+ }
+
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLCapabilitiesProto);
@@ -122,6 +138,7 @@ public bool Equals(UnityRLCapabilitiesProto other) {
}
if (BaseRLCapabilities != other.BaseRLCapabilities) return false;
if (ConcatenatedPngObservations != other.ConcatenatedPngObservations) return false;
+ if (CompressedChannelMapping != other.CompressedChannelMapping) return false;
return Equals(_unknownFields, other._unknownFields);
}
@@ -130,6 +147,7 @@ public override int GetHashCode() {
int hash = 1;
if (BaseRLCapabilities != false) hash ^= BaseRLCapabilities.GetHashCode();
if (ConcatenatedPngObservations != false) hash ^= ConcatenatedPngObservations.GetHashCode();
+ if (CompressedChannelMapping != false) hash ^= CompressedChannelMapping.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
@@ -151,6 +169,10 @@ public void WriteTo(pb::CodedOutputStream output) {
output.WriteRawTag(16);
output.WriteBool(ConcatenatedPngObservations);
}
+ if (CompressedChannelMapping != false) {
+ output.WriteRawTag(24);
+ output.WriteBool(CompressedChannelMapping);
+ }
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
@@ -165,6 +187,9 @@ public int CalculateSize() {
if (ConcatenatedPngObservations != false) {
size += 1 + 1;
}
+ if (CompressedChannelMapping != false) {
+ size += 1 + 1;
+ }
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
@@ -182,6 +207,9 @@ public void MergeFrom(UnityRLCapabilitiesProto other) {
if (other.ConcatenatedPngObservations != false) {
ConcatenatedPngObservations = other.ConcatenatedPngObservations;
}
+ if (other.CompressedChannelMapping != false) {
+ CompressedChannelMapping = other.CompressedChannelMapping;
+ }
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}
@@ -201,6 +229,10 @@ public void MergeFrom(pb::CodedInputStream input) {
ConcatenatedPngObservations = input.ReadBool();
break;
}
+ case 24: {
+ CompressedChannelMapping = input.ReadBool();
+ break;
+ }
}
}
}
diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs
index 8c38ef31a8..fee69f568c 100644
--- a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs
+++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs
@@ -25,19 +25,20 @@ static ObservationReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjRtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0",
- "aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyL5AQoQT2JzZXJ2YXRp",
+ "aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyKdAgoQT2JzZXJ2YXRp",
"b25Qcm90bxINCgVzaGFwZRgBIAMoBRJEChBjb21wcmVzc2lvbl90eXBlGAIg",
"ASgOMiouY29tbXVuaWNhdG9yX29iamVjdHMuQ29tcHJlc3Npb25UeXBlUHJv",
"dG8SGQoPY29tcHJlc3NlZF9kYXRhGAMgASgMSAASRgoKZmxvYXRfZGF0YRgE",
"IAEoCzIwLmNvbW11bmljYXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8u",
- "RmxvYXREYXRhSAAaGQoJRmxvYXREYXRhEgwKBGRhdGEYASADKAJCEgoQb2Jz",
- "ZXJ2YXRpb25fZGF0YSopChRDb21wcmVzc2lvblR5cGVQcm90bxIICgROT05F",
- "EAASBwoDUE5HEAFCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11bmljYXRvck9i",
- "amVjdHNiBnByb3RvMw=="));
+ "RmxvYXREYXRhSAASIgoaY29tcHJlc3NlZF9jaGFubmVsX21hcHBpbmcYBSAD",
+ "KAUaGQoJRmxvYXREYXRhEgwKBGRhdGEYASADKAJCEgoQb2JzZXJ2YXRpb25f",
+ "ZGF0YSopChRDb21wcmVzc2lvblR5cGVQcm90bxIICgROT05FEAASBwoDUE5H",
+ "EAFCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnBy",
+ "b3RvMw=="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto), }, new pbr::GeneratedClrTypeInfo[] {
- new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)})
+ new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData", "CompressedChannelMapping" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)})
}));
}
#endregion
@@ -79,6 +80,7 @@ public ObservationProto() {
public ObservationProto(ObservationProto other) : this() {
shape_ = other.shape_.Clone();
compressionType_ = other.compressionType_;
+ compressedChannelMapping_ = other.compressedChannelMapping_.Clone();
switch (other.ObservationDataCase) {
case ObservationDataOneofCase.CompressedData:
CompressedData = other.CompressedData;
@@ -139,6 +141,16 @@ public ObservationProto Clone() {
}
}
+ /// Field number for the "compressed_channel_mapping" field.
+ public const int CompressedChannelMappingFieldNumber = 5;
+ private static readonly pb::FieldCodec _repeated_compressedChannelMapping_codec
+ = pb::FieldCodec.ForInt32(42);
+ private readonly pbc::RepeatedField compressedChannelMapping_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField CompressedChannelMapping {
+ get { return compressedChannelMapping_; }
+ }
+
private object observationData_;
/// Enum of possible cases for the "observation_data" oneof.
public enum ObservationDataOneofCase {
@@ -175,6 +187,7 @@ public bool Equals(ObservationProto other) {
if (CompressionType != other.CompressionType) return false;
if (CompressedData != other.CompressedData) return false;
if (!object.Equals(FloatData, other.FloatData)) return false;
+ if(!compressedChannelMapping_.Equals(other.compressedChannelMapping_)) return false;
if (ObservationDataCase != other.ObservationDataCase) return false;
return Equals(_unknownFields, other._unknownFields);
}
@@ -186,6 +199,7 @@ public override int GetHashCode() {
if (CompressionType != 0) hash ^= CompressionType.GetHashCode();
if (observationDataCase_ == ObservationDataOneofCase.CompressedData) hash ^= CompressedData.GetHashCode();
if (observationDataCase_ == ObservationDataOneofCase.FloatData) hash ^= FloatData.GetHashCode();
+ hash ^= compressedChannelMapping_.GetHashCode();
hash ^= (int) observationDataCase_;
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
@@ -213,6 +227,7 @@ public void WriteTo(pb::CodedOutputStream output) {
output.WriteRawTag(34);
output.WriteMessage(FloatData);
}
+ compressedChannelMapping_.WriteTo(output, _repeated_compressedChannelMapping_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
@@ -231,6 +246,7 @@ public int CalculateSize() {
if (observationDataCase_ == ObservationDataOneofCase.FloatData) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(FloatData);
}
+ size += compressedChannelMapping_.CalculateSize(_repeated_compressedChannelMapping_codec);
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
@@ -246,6 +262,7 @@ public void MergeFrom(ObservationProto other) {
if (other.CompressionType != 0) {
CompressionType = other.CompressionType;
}
+ compressedChannelMapping_.Add(other.compressedChannelMapping_);
switch (other.ObservationDataCase) {
case ObservationDataOneofCase.CompressedData:
CompressedData = other.CompressedData;
@@ -291,6 +308,11 @@ public void MergeFrom(pb::CodedInputStream input) {
FloatData = subBuilder;
break;
}
+ case 42:
+ case 40: {
+ compressedChannelMapping_.AddEntriesFrom(input, _repeated_compressedChannelMapping_codec);
+ break;
+ }
}
}
}
diff --git a/com.unity.ml-agents/Runtime/SensorHelper.cs b/com.unity.ml-agents/Runtime/SensorHelper.cs
index 471a768b27..5870d0789a 100644
--- a/com.unity.ml-agents/Runtime/SensorHelper.cs
+++ b/com.unity.ml-agents/Runtime/SensorHelper.cs
@@ -1,4 +1,5 @@
using UnityEngine;
+using Unity.Barracuda;
namespace Unity.MLAgents.Sensors
{
@@ -62,5 +63,58 @@ public static bool CompareObservation(ISensor sensor, float[] expected, out stri
errorMessage = null;
return true;
}
+
+ public static bool CompareObservation(ISensor sensor, float[,,] expected, out string errorMessage)
+ {
+ var tensorShape = new TensorShape(0, expected.GetLength(0), expected.GetLength(1), expected.GetLength(2));
+ var numExpected = tensorShape.height * tensorShape.width * tensorShape.channels;
+ const float fill = -1337f;
+ var output = new float[numExpected];
+ for (var i = 0; i < numExpected; i++)
+ {
+ output[i] = fill;
+ }
+
+ if (numExpected > 0)
+ {
+ if (fill != output[0])
+ {
+ errorMessage = "Error setting output buffer.";
+ return false;
+ }
+ }
+
+ ObservationWriter writer = new ObservationWriter();
+ writer.SetTarget(output, sensor.GetObservationShape(), 0);
+
+ // Make sure ObservationWriter didn't touch anything
+ if (numExpected > 0)
+ {
+ if (fill != output[0])
+ {
+ errorMessage = "ObservationWriter.SetTarget modified a buffer it shouldn't have.";
+ return false;
+ }
+ }
+
+ sensor.Write(writer);
+ for (var h = 0; h < tensorShape.height; h++)
+ {
+ for (var w = 0; w < tensorShape.width; w++)
+ {
+ for (var c = 0; c < tensorShape.channels; c++)
+ {
+ if (expected[h, w, c] != output[tensorShape.Index(0, h, w, c)])
+ {
+ errorMessage = $"Expected and actual differed in position [{h}, {w}, {c}]. " +
+ $"Expected: {expected[h, w, c]} Actual: {output[tensorShape.Index(0, h, w, c)]} ";
+ return false;
+ }
+ }
+ }
+ }
+ errorMessage = null;
+ return true;
+ }
}
}
diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs
index fcfa350797..ebb9d9dc73 100644
--- a/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs
@@ -75,6 +75,11 @@ public bool Grayscale
set { m_Grayscale = value; }
}
+ [HideInInspector, SerializeField]
+ [Range(1, 50)]
+ [Tooltip("Number of camera frames that will be stacked before being fed to the neural network.")]
+ int m_ObservationStacks = 1;
+
[HideInInspector, SerializeField, FormerlySerializedAs("compression")]
SensorCompressionType m_Compression = SensorCompressionType.PNG;
@@ -87,6 +92,16 @@ public SensorCompressionType CompressionType
set { m_Compression = value; UpdateSensor(); }
}
+ ///
+ /// Whether to stack previous observations. Using 1 means no previous observations.
+ /// Note that changing this after the sensor is created has no effect.
+ ///
+ public int ObservationStacks
+ {
+ get { return m_ObservationStacks; }
+ set { m_ObservationStacks = value; }
+ }
+
///
/// Creates the
///
@@ -94,6 +109,11 @@ public SensorCompressionType CompressionType
public override ISensor CreateSensor()
{
m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression);
+
+ if (ObservationStacks != 1)
+ {
+ return new StackingSensor(m_Sensor, ObservationStacks);
+ }
return m_Sensor;
}
@@ -103,7 +123,13 @@ public override ISensor CreateSensor()
/// The observation shape of the associated object.
public override int[] GetObservationShape()
{
- return CameraSensor.GenerateShape(m_Width, m_Height, Grayscale);
+ var stacks = ObservationStacks > 1 ? ObservationStacks : 1;
+ var cameraSensorshape = CameraSensor.GenerateShape(m_Width, m_Height, Grayscale);
+ if (stacks > 1)
+ {
+ cameraSensorshape[cameraSensorshape.Length - 1] *= stacks;
+ }
+ return cameraSensorshape;
}
///
diff --git a/com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs b/com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs
new file mode 100644
index 0000000000..06b517eb42
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs
@@ -0,0 +1,20 @@
+namespace Unity.MLAgents.Sensors
+{
+ ///
+ /// Sensor interface for sparse channel sensor which requires a compressed channel mapping.
+ ///
+ public interface ISparseChannelSensor : ISensor
+ {
+ ///
+ /// Returns the mapping of the channels in compressed data to the actual channel after decompression.
+ /// The mapping is a list of interger index with the same length as
+ /// the number of output observation layers (channels), including padding if there's any.
+ /// Each index indicates the actual channel the layer will go into.
+ /// Layers with the same index will be averaged, and layers with negative index will be dropped.
+ /// For example, mapping for CameraSensor using grayscale and stacking of two: [0, 0, 0, 1, 1, 1]
+ /// Mapping for GridSensor of 4 channels and stacking of two: [0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1]
+ ///
+ /// Mapping of the compressed data
+ int[] GetCompressedChannelMapping();
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs.meta
new file mode 100644
index 0000000000..bebec4f1cf
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 63bb76c1e31c24fa5b4a384ea0edbfb0
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs
index f308507e0f..d22a3ffcaf 100644
--- a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs
@@ -54,6 +54,11 @@ public bool Grayscale
set { m_Grayscale = value; }
}
+ [HideInInspector, SerializeField]
+ [Range(1, 50)]
+ [Tooltip("Number of frames that will be stacked before being fed to the neural network.")]
+ int m_ObservationStacks = 1;
+
[HideInInspector, SerializeField, FormerlySerializedAs("compression")]
SensorCompressionType m_Compression = SensorCompressionType.PNG;
@@ -66,10 +71,24 @@ public SensorCompressionType CompressionType
set { m_Compression = value; UpdateSensor(); }
}
+ ///
+ /// Whether to stack previous observations. Using 1 means no previous observations.
+ /// Note that changing this after the sensor is created has no effect.
+ ///
+ public int ObservationStacks
+ {
+ get { return m_ObservationStacks; }
+ set { m_ObservationStacks = value; }
+ }
+
///
public override ISensor CreateSensor()
{
m_Sensor = new RenderTextureSensor(RenderTexture, Grayscale, SensorName, m_Compression);
+ if (ObservationStacks != 1)
+ {
+ return new StackingSensor(m_Sensor, ObservationStacks);
+ }
return m_Sensor;
}
@@ -78,8 +97,15 @@ public override int[] GetObservationShape()
{
var width = RenderTexture != null ? RenderTexture.width : 0;
var height = RenderTexture != null ? RenderTexture.height : 0;
+ var observationShape = new[] { height, width, Grayscale ? 1 : 3 };
+
+ var stacks = ObservationStacks > 1 ? ObservationStacks : 1;
+ if (stacks > 1)
+ {
+ observationShape[2] *= stacks;
+ }
- return new[] { height, width, Grayscale ? 1 : 3 };
+ return observationShape;
}
///
diff --git a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
index 962ffe1a9e..6064a758dd 100644
--- a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
@@ -1,4 +1,11 @@
using System;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using UnityEngine;
+using Unity.Barracuda;
+
+
+[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor.Tests")]
namespace Unity.MLAgents.Sensors
{
@@ -8,10 +15,9 @@ namespace Unity.MLAgents.Sensors
/// For example, 4 stacked sets of observations would be output like
/// | t = now - 3 | t = now -3 | t = now - 2 | t = now |
/// Internally, a circular buffer of arrays is used. The m_CurrentIndex represents the most recent observation.
- ///
- /// Currently, compressed and multidimensional observations are not supported.
+ /// Currently, observations are stacked on the last dimension.
///
- public class StackingSensor : ISensor
+ public class StackingSensor : ISparseChannelSensor
{
///
/// The wrapped sensor.
@@ -26,15 +32,22 @@ public class StackingSensor : ISensor
string m_Name;
int[] m_Shape;
+ int[] m_WrappedShape;
///
/// Buffer of previous observations
///
float[][] m_StackedObservations;
+ byte[][] m_StackedCompressedObservations;
+
int m_CurrentIndex;
ObservationWriter m_LocalWriter = new ObservationWriter();
+ byte[] m_EmptyCompressedObservation;
+ int[] m_CompressionMapping;
+ TensorShape m_tensorShape;
+
///
/// Initializes the sensor.
///
@@ -48,48 +61,78 @@ public StackingSensor(ISensor wrapped, int numStackedObservations)
m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}";
- if (wrapped.GetCompressionType() != SensorCompressionType.None)
- {
- throw new UnityAgentsException("StackingSensor doesn't support compressed observations.'");
- }
-
- var shape = wrapped.GetObservationShape();
- if (shape.Length != 1)
- {
- throw new UnityAgentsException("Only 1-D observations are supported by StackingSensor");
- }
- m_Shape = new int[shape.Length];
+ m_WrappedShape = wrapped.GetObservationShape();
+ m_Shape = new int[m_WrappedShape.Length];
m_UnstackedObservationSize = wrapped.ObservationSize();
- for (int d = 0; d < shape.Length; d++)
+ for (int d = 0; d < m_WrappedShape.Length; d++)
{
- m_Shape[d] = shape[d];
+ m_Shape[d] = m_WrappedShape[d];
}
// TODO support arbitrary stacking dimension
- m_Shape[0] *= numStackedObservations;
+ m_Shape[m_Shape.Length - 1] *= numStackedObservations;
+
+ // Initialize uncompressed buffer anyway in case python trainer does not
+ // support the compression mapping and has to fall back to uncompressed obs.
m_StackedObservations = new float[numStackedObservations][];
for (var i = 0; i < numStackedObservations; i++)
{
m_StackedObservations[i] = new float[m_UnstackedObservationSize];
}
+
+ if (m_WrappedSensor.GetCompressionType() != SensorCompressionType.None)
+ {
+ m_StackedCompressedObservations = new byte[numStackedObservations][];
+ m_EmptyCompressedObservation = CreateEmptyPNG();
+ for (var i = 0; i < numStackedObservations; i++)
+ {
+ m_StackedCompressedObservations[i] = m_EmptyCompressedObservation;
+ }
+ m_CompressionMapping = ConstructStackedCompressedChannelMapping(wrapped);
+ }
+
+ if (m_Shape.Length != 1)
+ {
+ m_tensorShape = new TensorShape(0, m_WrappedShape[0], m_WrappedShape[1], m_WrappedShape[2]);
+ }
}
///
public int Write(ObservationWriter writer)
{
// First, call the wrapped sensor's write method. Make sure to use our own writer, not the passed one.
- var wrappedShape = m_WrappedSensor.GetObservationShape();
- m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], wrappedShape, 0);
+ m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], m_WrappedShape, 0);
m_WrappedSensor.Write(m_LocalWriter);
// Now write the saved observations (oldest first)
var numWritten = 0;
- for (var i = 0; i < m_NumStackedObservations; i++)
+ if (m_WrappedShape.Length == 1)
{
- var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations;
- writer.AddRange(m_StackedObservations[obsIndex], numWritten);
- numWritten += m_UnstackedObservationSize;
+ for (var i = 0; i < m_NumStackedObservations; i++)
+ {
+ var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations;
+ writer.AddRange(m_StackedObservations[obsIndex], numWritten);
+ numWritten += m_UnstackedObservationSize;
+ }
+ }
+ else
+ {
+ for (var i = 0; i < m_NumStackedObservations; i++)
+ {
+ var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations;
+ for (var h = 0; h < m_WrappedShape[0]; h++)
+ {
+ for (var w = 0; w < m_WrappedShape[1]; w++)
+ {
+ for (var c = 0; c < m_WrappedShape[2]; c++)
+ {
+ writer[h, w, i * m_WrappedShape[2] + c] = m_StackedObservations[obsIndex][m_tensorShape.Index(0, h, w, c)];
+ }
+ }
+ }
+ }
+ numWritten = m_WrappedShape[0] * m_WrappedShape[1] * m_WrappedShape[2] * m_NumStackedObservations;
}
return numWritten;
@@ -113,6 +156,13 @@ public void Reset()
{
Array.Clear(m_StackedObservations[i], 0, m_StackedObservations[i].Length);
}
+ if (m_WrappedSensor.GetCompressionType() != SensorCompressionType.None)
+ {
+ for (var i = 0; i < m_NumStackedObservations; i++)
+ {
+ m_StackedCompressedObservations[i] = m_EmptyCompressedObservation;
+ }
+ }
}
///
@@ -128,17 +178,101 @@ public string GetName()
}
///
- public virtual byte[] GetCompressedObservation()
+ public byte[] GetCompressedObservation()
+ {
+ var compressed = m_WrappedSensor.GetCompressedObservation();
+ m_StackedCompressedObservations[m_CurrentIndex] = compressed;
+
+ int bytesLength = 0;
+ foreach (byte[] compressedObs in m_StackedCompressedObservations)
+ {
+ bytesLength += compressedObs.Length;
+ }
+
+ byte[] outputBytes = new byte[bytesLength];
+ int offset = 0;
+ for (var i = 0; i < m_NumStackedObservations; i++)
+ {
+ var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations;
+ Buffer.BlockCopy(m_StackedCompressedObservations[obsIndex],
+ 0, outputBytes, offset, m_StackedCompressedObservations[obsIndex].Length);
+ offset += m_StackedCompressedObservations[obsIndex].Length;
+ }
+
+ return outputBytes;
+ }
+
+ public int[] GetCompressedChannelMapping()
{
- return null;
+ return m_CompressionMapping;
}
///
- public virtual SensorCompressionType GetCompressionType()
+ public SensorCompressionType GetCompressionType()
+ {
+ return m_WrappedSensor.GetCompressionType();
+ }
+
+ ///
+ /// Create Empty PNG for initializing the buffer for stacking.
+ ///
+ internal byte[] CreateEmptyPNG()
{
- return SensorCompressionType.None;
+ int height = m_WrappedSensor.GetObservationShape()[0];
+ int width = m_WrappedSensor.GetObservationShape()[1];
+ var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false);
+ return texture2D.EncodeToPNG();
}
- // TODO support stacked compressed observations (byte stream)
+
+ ///
+ /// Construct stacked CompressedChannelMapping.
+ ///
+ internal int[] ConstructStackedCompressedChannelMapping(ISensor wrappedSenesor)
+ {
+ // Get CompressedChannelMapping of the wrapped sensor. If the
+ // wrapped sensor doesn't have one, use default mapping.
+ // Default mapping: {0, 0, 0} for grayscale, identity mapping {0, 1, ..., n-1} otherwise.
+ int[] wrappedMapping = null;
+ int wrappedNumChannel = wrappedSenesor.GetObservationShape()[2];
+ var sparseChannelSensor = m_WrappedSensor as ISparseChannelSensor;
+ if (sparseChannelSensor != null)
+ {
+ wrappedMapping = sparseChannelSensor.GetCompressedChannelMapping();
+ }
+ if (wrappedMapping == null)
+ {
+ if (wrappedNumChannel == 1)
+ {
+ wrappedMapping = new int[] { 0, 0, 0 };
+ }
+ else
+ {
+ wrappedMapping = Enumerable.Range(0, wrappedNumChannel).ToArray();
+ }
+ }
+
+ // Construct stacked mapping using the mapping of wrapped sensor.
+ // First pad the wrapped mapping to multiple of 3, then repeat
+ // and add offset to each copy to form the stacked mapping.
+ int paddedMapLength = (wrappedMapping.Length + 2) / 3 * 3;
+ var compressionMapping = new int[paddedMapLength * m_NumStackedObservations];
+ for (var i = 0; i < m_NumStackedObservations; i++)
+ {
+ var offset = wrappedNumChannel * i;
+ for (var j = 0; j < paddedMapLength; j++)
+ {
+ if (j < wrappedMapping.Length)
+ {
+ compressionMapping[j + paddedMapLength * i] = wrappedMapping[j] >= 0 ? wrappedMapping[j] + offset : -1;
+ }
+ else
+ {
+ compressionMapping[j + paddedMapLength * i] = -1;
+ }
+ }
+ }
+ return compressionMapping;
+ }
}
}
diff --git a/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
index 675f473db4..88039d2b3f 100644
--- a/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
+++ b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
@@ -90,6 +90,19 @@ public string GetName()
}
}
+ class DummySparseChannelSensor : DummySensor, ISparseChannelSensor
+ {
+ public int[] Mapping;
+ internal DummySparseChannelSensor()
+ {
+ }
+
+ public int[] GetCompressedChannelMapping()
+ {
+ return Mapping;
+ }
+ }
+
[Test]
public void TestGetObservationProtoCapabilities()
{
@@ -139,5 +152,23 @@ public void TestGetObservationProtoCapabilities()
}
+
+ [Test]
+ public void TestIsTrivialMapping()
+ {
+ Assert.AreEqual(GrpcExtensions.IsTrivialMapping(new DummySensor()), true);
+
+ var sparseChannelSensor = new DummySparseChannelSensor();
+ sparseChannelSensor.Mapping = null;
+ Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), true);
+ sparseChannelSensor.Mapping = new int[] { 0, 0, 0 };
+ Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), true);
+ sparseChannelSensor.Mapping = new int[] { 0, 1, 2, 3, 4 };
+ Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), true);
+ sparseChannelSensor.Mapping = new int[] { 1, 2, 3, 4, -1, -1 };
+ Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), false);
+ sparseChannelSensor.Mapping = new int[] { 0, 0, 0, 1, 1, 1 };
+ Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), false);
+ }
}
}
diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/SensorTestHelper.cs b/com.unity.ml-agents/Tests/Editor/Sensor/SensorTestHelper.cs
new file mode 100644
index 0000000000..ea035b9875
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Sensor/SensorTestHelper.cs
@@ -0,0 +1,23 @@
+using NUnit.Framework;
+using UnityEngine;
+using Unity.MLAgents.Sensors;
+
+namespace Unity.MLAgents.Tests
+{
+ public static class SensorTestHelper
+ {
+ public static void CompareObservation(ISensor sensor, float[] expected)
+ {
+ string errorMessage;
+ bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage);
+ Assert.IsTrue(isOK, errorMessage);
+ }
+
+ public static void CompareObservation(ISensor sensor, float[,,] expected)
+ {
+ string errorMessage;
+ bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage);
+ Assert.IsTrue(isOK, errorMessage);
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/SensorTestHelper.cs.meta b/com.unity.ml-agents/Tests/Editor/Sensor/SensorTestHelper.cs.meta
new file mode 100644
index 0000000000..487ace557e
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Sensor/SensorTestHelper.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: e769354f8bd404ca180d7cd7302a5d61
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs
index 8626fe3121..1954139691 100644
--- a/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs
+++ b/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs
@@ -1,4 +1,7 @@
using NUnit.Framework;
+using System;
+using System.Linq;
+using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Tests
@@ -15,7 +18,7 @@ public void TestCtor()
}
[Test]
- public void TestStacking()
+ public void TestVectorStacking()
{
VectorSensor wrapped = new VectorSensor(2);
ISensor sensor = new StackingSensor(wrapped, 3);
@@ -44,7 +47,7 @@ public void TestStacking()
}
[Test]
- public void TestStackingReset()
+ public void TestVectorStackingReset()
{
VectorSensor wrapped = new VectorSensor(2);
ISensor sensor = new StackingSensor(wrapped, 3);
@@ -60,5 +63,154 @@ public void TestStackingReset()
wrapped.AddObservation(new[] { 5f, 6f });
SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f, 5f, 6f });
}
+
+ class Dummy3DSensor : ISparseChannelSensor
+ {
+ public SensorCompressionType CompressionType = SensorCompressionType.PNG;
+ public int[] Mapping;
+ public int[] Shape;
+ public float[,,] CurrentObservation;
+
+ internal Dummy3DSensor()
+ {
+ }
+
+ public int[] GetObservationShape()
+ {
+ return Shape;
+ }
+
+ public int Write(ObservationWriter writer)
+ {
+ for (var h = 0; h < Shape[0]; h++)
+ {
+ for (var w = 0; w < Shape[1]; w++)
+ {
+ for (var c = 0; c < Shape[2]; c++)
+ {
+ writer[h, w, c] = CurrentObservation[h, w, c];
+ }
+ }
+ }
+ return Shape[0] * Shape[1] * Shape[2];
+ }
+
+ public byte[] GetCompressedObservation()
+ {
+ var writer = new ObservationWriter();
+ var flattenedObservation = new float[Shape[0] * Shape[1] * Shape[2]];
+ writer.SetTarget(flattenedObservation, Shape, 0);
+ Write(writer);
+ byte[] bytes = Array.ConvertAll(flattenedObservation, (z) => (byte)z);
+ return bytes;
+ }
+
+ public void Update() { }
+
+ public void Reset() { }
+
+ public SensorCompressionType GetCompressionType()
+ {
+ return CompressionType;
+ }
+
+ public string GetName()
+ {
+ return "Dummy";
+ }
+
+ public int[] GetCompressedChannelMapping()
+ {
+ return Mapping;
+ }
+ }
+
+ [Test]
+ public void TestStackingMapping()
+ {
+ // Test grayscale stacked mapping with CameraSensor
+ var cameraSensor = new CameraSensor(new Camera(), 64, 64,
+ true, "grayscaleCamera", SensorCompressionType.PNG);
+ var stackedCameraSensor = new StackingSensor(cameraSensor, 2);
+ Assert.AreEqual(stackedCameraSensor.GetCompressedChannelMapping(), new[] { 0, 0, 0, 1, 1, 1 });
+
+ // Test RGB stacked mapping with RenderTextureSensor
+ var renderTextureSensor = new RenderTextureSensor(new RenderTexture(24, 16, 0),
+ false, "renderTexture", SensorCompressionType.PNG);
+ var stackedRenderTextureSensor = new StackingSensor(renderTextureSensor, 2);
+ Assert.AreEqual(stackedRenderTextureSensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, 4, 5 });
+
+ // Test mapping with number of layers not being multiple of 3
+ var dummySensor = new Dummy3DSensor();
+ dummySensor.Shape = new int[] { 2, 2, 4 };
+ dummySensor.Mapping = new int[] { 0, 1, 2, 3 };
+ var stackedDummySensor = new StackingSensor(dummySensor, 2);
+ Assert.AreEqual(stackedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 });
+
+ // Test mapping with dummy layers that should be dropped
+ var paddedDummySensor = new Dummy3DSensor();
+ paddedDummySensor.Shape = new int[] { 2, 2, 4 };
+ paddedDummySensor.Mapping = new int[] { 0, 1, 2, 3, -1, -1 };
+ var stackedPaddedDummySensor = new StackingSensor(paddedDummySensor, 2);
+ Assert.AreEqual(stackedPaddedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 });
+ }
+
+ [Test]
+ public void Test3DStacking()
+ {
+ var wrapped = new Dummy3DSensor();
+ wrapped.Shape = new int[] { 2, 1, 2 };
+ var sensor = new StackingSensor(wrapped, 2);
+
+ // Check the stacking is on the last dimension
+ wrapped.CurrentObservation = new[, ,] { { { 1f, 2f } }, { { 3f, 4f } } };
+ SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 0f, 0f, 1f, 2f } }, { { 0f, 0f, 3f, 4f } } });
+
+ sensor.Update();
+ wrapped.CurrentObservation = new[, ,] { { { 5f, 6f } }, { { 7f, 8f } } };
+ SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 1f, 2f, 5f, 6f } }, { { 3f, 4f, 7f, 8f } } });
+
+ sensor.Update();
+ wrapped.CurrentObservation = new[, ,] { { { 9f, 10f } }, { { 11f, 12f } } };
+ SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 5f, 6f, 9f, 10f } }, { { 7f, 8f, 11f, 12f } } });
+
+ // Check that if we don't call Update(), the same observations are produced
+ SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 5f, 6f, 9f, 10f } }, { { 7f, 8f, 11f, 12f } } });
+
+ // Test reset
+ sensor.Reset();
+ wrapped.CurrentObservation = new[, ,] { { { 13f, 14f } }, { { 15f, 16f } } };
+ SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 0f, 0f, 13f, 14f } }, { { 0f, 0f, 15f, 16f } } });
+ }
+
+ [Test]
+ public void TestStackedGetCompressedObservation()
+ {
+ var wrapped = new Dummy3DSensor();
+ wrapped.Shape = new int[] { 1, 1, 3 };
+ var sensor = new StackingSensor(wrapped, 2);
+
+ wrapped.CurrentObservation = new[, ,] { { { 1f, 2f, 3f } } };
+ var expected1 = sensor.CreateEmptyPNG();
+ expected1 = expected1.Concat(Array.ConvertAll(new[] { 1f, 2f, 3f }, (z) => (byte)z)).ToArray();
+ Assert.AreEqual(sensor.GetCompressedObservation(), expected1);
+
+ sensor.Update();
+ wrapped.CurrentObservation = new[, ,] { { { 4f, 5f, 6f } } };
+ var expected2 = Array.ConvertAll(new[] { 1f, 2f, 3f, 4f, 5f, 6f }, (z) => (byte)z);
+ Assert.AreEqual(sensor.GetCompressedObservation(), expected2);
+
+ sensor.Update();
+ wrapped.CurrentObservation = new[, ,] { { { 7f, 8f, 9f } } };
+ var expected3 = Array.ConvertAll(new[] { 4f, 5f, 6f, 7f, 8f, 9f }, (z) => (byte)z);
+ Assert.AreEqual(sensor.GetCompressedObservation(), expected3);
+
+ // Test reset
+ sensor.Reset();
+ wrapped.CurrentObservation = new[, ,] { { { 10f, 11f, 12f } } };
+ var expected4 = sensor.CreateEmptyPNG();
+ expected4 = expected4.Concat(Array.ConvertAll(new[] { 10f, 11f, 12f }, (z) => (byte)z)).ToArray();
+ Assert.AreEqual(sensor.GetCompressedObservation(), expected4);
+ }
}
}
diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs
index 42cd377a86..5326bca868 100644
--- a/com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs
+++ b/com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs
@@ -4,16 +4,6 @@
namespace Unity.MLAgents.Tests
{
- public static class SensorTestHelper
- {
- public static void CompareObservation(ISensor sensor, float[] expected)
- {
- string errorMessage;
- bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage);
- Assert.IsTrue(isOK, errorMessage);
- }
- }
-
public class VectorSensorTests
{
[Test]
diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py
index 87a0cbd641..563e483a8e 100644
--- a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py
+++ b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py
@@ -19,7 +19,7 @@
name='mlagents_envs/communicator_objects/capabilities.proto',
package='communicator_objects',
syntax='proto3',
- serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"[\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
+ serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"}\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x12 \n\x18\x63ompressedChannelMapping\x18\x03 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
)
@@ -46,6 +46,13 @@
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='compressedChannelMapping', full_name='communicator_objects.UnityRLCapabilitiesProto.compressedChannelMapping', index=2,
+ number=3, type=8, cpp_type=7, label=1,
+ has_default_value=False, default_value=False,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
],
extensions=[
],
@@ -59,7 +66,7 @@
oneofs=[
],
serialized_start=79,
- serialized_end=170,
+ serialized_end=204,
)
DESCRIPTOR.message_types_by_name['UnityRLCapabilitiesProto'] = _UNITYRLCAPABILITIESPROTO
diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi
index 600b817173..d2bc066912 100644
--- a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi
+++ b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi
@@ -27,17 +27,19 @@ class UnityRLCapabilitiesProto(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
baseRLCapabilities = ... # type: builtin___bool
concatenatedPngObservations = ... # type: builtin___bool
+ compressedChannelMapping = ... # type: builtin___bool
def __init__(self,
*,
baseRLCapabilities : typing___Optional[builtin___bool] = None,
concatenatedPngObservations : typing___Optional[builtin___bool] = None,
+ compressedChannelMapping : typing___Optional[builtin___bool] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> UnityRLCapabilitiesProto: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
- def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"concatenatedPngObservations"]) -> None: ...
+ def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations"]) -> None: ...
else:
- def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"concatenatedPngObservations",b"concatenatedPngObservations"]) -> None: ...
+ def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations"]) -> None: ...
diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py b/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py
index e31a806676..43c840eaa0 100644
--- a/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py
+++ b/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py
@@ -20,7 +20,7 @@
name='mlagents_envs/communicator_objects/observation.proto',
package='communicator_objects',
syntax='proto3',
- serialized_pb=_b('\n4mlagents_envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\xf9\x01\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
+ serialized_pb=_b('\n4mlagents_envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\x9d\x02\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x12\"\n\x1a\x63ompressed_channel_mapping\x18\x05 \x03(\x05\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
)
_COMPRESSIONTYPEPROTO = _descriptor.EnumDescriptor(
@@ -40,8 +40,8 @@
],
containing_type=None,
options=None,
- serialized_start=330,
- serialized_end=371,
+ serialized_start=366,
+ serialized_end=407,
)
_sym_db.RegisterEnumDescriptor(_COMPRESSIONTYPEPROTO)
@@ -77,8 +77,8 @@
extension_ranges=[],
oneofs=[
],
- serialized_start=283,
- serialized_end=308,
+ serialized_start=319,
+ serialized_end=344,
)
_OBSERVATIONPROTO = _descriptor.Descriptor(
@@ -116,6 +116,13 @@
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='compressed_channel_mapping', full_name='communicator_objects.ObservationProto.compressed_channel_mapping', index=4,
+ number=5, type=5, cpp_type=1, label=3,
+ has_default_value=False, default_value=[],
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
],
extensions=[
],
@@ -132,7 +139,7 @@
index=0, containing_type=None, fields=[]),
],
serialized_start=79,
- serialized_end=328,
+ serialized_end=364,
)
_OBSERVATIONPROTO_FLOATDATA.containing_type = _OBSERVATIONPROTO
diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi b/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi
index 79681430fb..bf38830c00 100644
--- a/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi
+++ b/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi
@@ -72,6 +72,7 @@ class ObservationProto(google___protobuf___message___Message):
shape = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___int]
compression_type = ... # type: CompressionTypeProto
compressed_data = ... # type: builtin___bytes
+ compressed_channel_mapping = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___int]
@property
def float_data(self) -> ObservationProto.FloatData: ...
@@ -82,6 +83,7 @@ class ObservationProto(google___protobuf___message___Message):
compression_type : typing___Optional[CompressionTypeProto] = None,
compressed_data : typing___Optional[builtin___bytes] = None,
float_data : typing___Optional[ObservationProto.FloatData] = None,
+ compressed_channel_mapping : typing___Optional[typing___Iterable[builtin___int]] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> ObservationProto: ...
@@ -89,8 +91,8 @@ class ObservationProto(google___protobuf___message___Message):
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
def HasField(self, field_name: typing_extensions___Literal[u"compressed_data",u"float_data",u"observation_data"]) -> builtin___bool: ...
- def ClearField(self, field_name: typing_extensions___Literal[u"compressed_data",u"compression_type",u"float_data",u"observation_data",u"shape"]) -> None: ...
+ def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",u"compressed_data",u"compression_type",u"float_data",u"observation_data",u"shape"]) -> None: ...
else:
def HasField(self, field_name: typing_extensions___Literal[u"compressed_data",b"compressed_data",u"float_data",b"float_data",u"observation_data",b"observation_data"]) -> builtin___bool: ...
- def ClearField(self, field_name: typing_extensions___Literal[u"compressed_data",b"compressed_data",u"compression_type",b"compression_type",u"float_data",b"float_data",u"observation_data",b"observation_data",u"shape",b"shape"]) -> None: ...
+ def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",b"compressed_channel_mapping",u"compressed_data",b"compressed_data",u"compression_type",b"compression_type",u"float_data",b"float_data",u"observation_data",b"observation_data",u"shape",b"shape"]) -> None: ...
def WhichOneof(self, oneof_group: typing_extensions___Literal[u"observation_data",b"observation_data"]) -> typing_extensions___Literal["compressed_data","float_data"]: ...
diff --git a/ml-agents-envs/mlagents_envs/environment.py b/ml-agents-envs/mlagents_envs/environment.py
index 9cd32e364a..2997b72680 100644
--- a/ml-agents-envs/mlagents_envs/environment.py
+++ b/ml-agents-envs/mlagents_envs/environment.py
@@ -59,7 +59,8 @@ class UnityEnvironment(BaseEnv):
# Revision history:
# * 1.0.0 - initial version
# * 1.1.0 - support concatenated PNGs for compressed observations.
- API_VERSION = "1.1.0"
+ # * 1.2.0 - support compression mapping for stacked compressed observations.
+ API_VERSION = "1.2.0"
# Default port that the editor listens on. If an environment executable
# isn't specified, this port will be used.
@@ -118,6 +119,7 @@ def _get_capabilities_proto() -> UnityRLCapabilitiesProto:
capabilities = UnityRLCapabilitiesProto()
capabilities.baseRLCapabilities = True
capabilities.concatenatedPngObservations = True
+ capabilities.compressedChannelMapping = True
return capabilities
@staticmethod
diff --git a/ml-agents-envs/mlagents_envs/rpc_utils.py b/ml-agents-envs/mlagents_envs/rpc_utils.py
index 9a903fc743..abe7d9c83d 100644
--- a/ml-agents-envs/mlagents_envs/rpc_utils.py
+++ b/ml-agents-envs/mlagents_envs/rpc_utils.py
@@ -78,7 +78,9 @@ def original_tell(self) -> int:
@timed
-def process_pixels(image_bytes: bytes, expected_channels: int) -> np.ndarray:
+def process_pixels(
+ image_bytes: bytes, expected_channels: int, mappings: Optional[List[int]] = None
+) -> np.ndarray:
"""
Converts byte array observation image into numpy array, re-sizes it,
and optionally converts it to grey scale
@@ -88,23 +90,12 @@ def process_pixels(image_bytes: bytes, expected_channels: int) -> np.ndarray:
"""
image_fp = OffsetBytesIO(image_bytes)
- if expected_channels == 1:
- # Convert to grayscale
- with hierarchical_timer("image_decompress"):
- image = Image.open(image_fp)
- # Normally Image loads lazily, load() forces it to do loading in the timer scope.
- image.load()
- s = np.array(image, dtype=np.float32) / 255.0
- s = np.mean(s, axis=2)
- s = np.reshape(s, [s.shape[0], s.shape[1], 1])
- return s
-
image_arrays = []
-
# Read the images back from the bytes (without knowing the sizes).
while True:
with hierarchical_timer("image_decompress"):
image = Image.open(image_fp)
+ # Normally Image loads lazily, load() forces it to do loading in the timer scope.
image.load()
image_arrays.append(np.array(image, dtype=np.float32) / 255.0)
@@ -116,13 +107,61 @@ def process_pixels(image_bytes: bytes, expected_channels: int) -> np.ndarray:
# Didn't find the header, so must be at the end.
break
- img = np.concatenate(image_arrays, axis=2)
- # We can drop additional channels since they may need to be added to include
- # numbers of observation channels not divisible by 3.
- actual_channels = list(img.shape)[2]
- if actual_channels > expected_channels:
- img = img[..., 0:expected_channels]
+ if mappings is not None and len(mappings) > 0:
+ return _process_images_mapping(image_arrays, mappings)
+ else:
+ return _process_images_num_channels(image_arrays, expected_channels)
+
+
+def _process_images_mapping(image_arrays, mappings):
+ """
+ Helper function for processing decompressed images with compressed channel mappings.
+ """
+ image_arrays = np.concatenate(image_arrays, axis=2).transpose((2, 0, 1))
+
+ if len(mappings) != len(image_arrays):
+ raise UnityObservationException(
+ f"Compressed observation and its mapping had different number of channels - "
+ f"observation had {len(image_arrays)} channels but its mapping had {len(mappings)} channels"
+ )
+ if len({m for m in mappings if m > -1}) != max(mappings) + 1:
+ raise UnityObservationException(
+ f"Invalid Compressed Channel Mapping: the mapping {mappings} does not have the correct format."
+ )
+ if max(mappings) >= len(image_arrays):
+ raise UnityObservationException(
+ f"Invalid Compressed Channel Mapping: the mapping has index larger than the total "
+ f"number of channels in observation - mapping index {max(mappings)} is"
+ f"invalid for input observation with {len(image_arrays)} channels."
+ )
+
+ processed_image_arrays: List[np.array] = [[] for _ in range(max(mappings) + 1)]
+ for mapping_idx, img in zip(mappings, image_arrays):
+ if mapping_idx > -1:
+ processed_image_arrays[mapping_idx].append(img)
+
+ for i, img_array in enumerate(processed_image_arrays):
+ processed_image_arrays[i] = np.mean(img_array, axis=0)
+ img = np.stack(processed_image_arrays, axis=2)
+ return img
+
+def _process_images_num_channels(image_arrays, expected_channels):
+ """
+ Helper function for processing decompressed images with number of expected channels.
+ This is for the old API without a mapping provided. Uses the first n channels, where n = expected_channels.
+ """
+ if expected_channels == 1:
+ # Convert to grayscale
+ img = np.mean(image_arrays[0], axis=2)
+ img = np.reshape(img, [img.shape[0], img.shape[1], 1])
+ else:
+ img = np.concatenate(image_arrays, axis=2)
+ # We can drop additional channels since they may need to be added to include
+ # numbers of observation channels not divisible by 3.
+ actual_channels = list(img.shape)[2]
+ if actual_channels > expected_channels:
+ img = img[..., 0:expected_channels]
return img
@@ -147,7 +186,9 @@ def observation_to_np_array(
img = np.reshape(img, obs.shape)
return img
else:
- img = process_pixels(obs.compressed_data, expected_channels)
+ img = process_pixels(
+ obs.compressed_data, expected_channels, list(obs.compressed_channel_mapping)
+ )
# Compare decompressed image size to observation shape and make sure they match
if list(obs.shape) != list(img.shape):
raise UnityObservationException(
diff --git a/ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py b/ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py
index 5f1a1825fc..87db474e93 100644
--- a/ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py
+++ b/ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py
@@ -82,11 +82,39 @@ def generate_compressed_data(in_array: np.ndarray) -> bytes:
return bytes_out
-def generate_compressed_proto_obs(in_array: np.ndarray) -> ObservationProto:
+# test helper function for old C# API (no compressed channel mapping)
+def generate_compressed_proto_obs(
+ in_array: np.ndarray, grayscale: bool = False
+) -> ObservationProto:
obs_proto = ObservationProto()
obs_proto.compressed_data = generate_compressed_data(in_array)
obs_proto.compression_type = PNG
- obs_proto.shape.extend(in_array.shape)
+ if grayscale:
+ # grayscale flag is only used for old API without mapping
+ expected_shape = [in_array.shape[0], in_array.shape[1], 1]
+ obs_proto.shape.extend(expected_shape)
+ else:
+ obs_proto.shape.extend(in_array.shape)
+ return obs_proto
+
+
+# test helper function for new C# API (with compressed channel mapping)
+def generate_compressed_proto_obs_with_mapping(
+ in_array: np.ndarray, mapping: List[int]
+) -> ObservationProto:
+ obs_proto = ObservationProto()
+ obs_proto.compressed_data = generate_compressed_data(in_array)
+ obs_proto.compression_type = PNG
+ if mapping is not None:
+ obs_proto.compressed_channel_mapping.extend(mapping)
+ expected_shape = [
+ in_array.shape[0],
+ in_array.shape[1],
+ len({m for m in mapping if m >= 0}),
+ ]
+ obs_proto.shape.extend(expected_shape)
+ else:
+ obs_proto.shape.extend(in_array.shape)
return obs_proto
@@ -231,7 +259,11 @@ def test_process_visual_observation():
in_array_1 = np.random.rand(128, 64, 3)
proto_obs_1 = generate_compressed_proto_obs(in_array_1)
in_array_2 = np.random.rand(128, 64, 3)
- proto_obs_2 = generate_uncompressed_proto_obs(in_array_2)
+ in_array_2_mapping = [0, 1, 2]
+ proto_obs_2 = generate_compressed_proto_obs_with_mapping(
+ in_array_2, in_array_2_mapping
+ )
+
ap1 = AgentInfoProto()
ap1.observations.extend([proto_obs_1])
ap2 = AgentInfoProto()
@@ -243,6 +275,44 @@ def test_process_visual_observation():
assert np.allclose(arr[1, :, :, :], in_array_2, atol=0.01)
+def test_process_visual_observation_grayscale():
+ in_array_1 = np.random.rand(128, 64, 3)
+ proto_obs_1 = generate_compressed_proto_obs(in_array_1, grayscale=True)
+ expected_out_array_1 = np.mean(in_array_1, axis=2, keepdims=True)
+ in_array_2 = np.random.rand(128, 64, 3)
+ in_array_2_mapping = [0, 0, 0]
+ proto_obs_2 = generate_compressed_proto_obs_with_mapping(
+ in_array_2, in_array_2_mapping
+ )
+ expected_out_array_2 = np.mean(in_array_2, axis=2, keepdims=True)
+
+ ap1 = AgentInfoProto()
+ ap1.observations.extend([proto_obs_1])
+ ap2 = AgentInfoProto()
+ ap2.observations.extend([proto_obs_2])
+ ap_list = [ap1, ap2]
+ arr = _process_visual_observation(0, (128, 64, 1), ap_list)
+ assert list(arr.shape) == [2, 128, 64, 1]
+ assert np.allclose(arr[0, :, :, :], expected_out_array_1, atol=0.01)
+ assert np.allclose(arr[1, :, :, :], expected_out_array_2, atol=0.01)
+
+
+def test_process_visual_observation_padded_channels():
+ in_array_1 = np.random.rand(128, 64, 12)
+ in_array_1_mapping = [0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1]
+ proto_obs_1 = generate_compressed_proto_obs_with_mapping(
+ in_array_1, in_array_1_mapping
+ )
+ expected_out_array_1 = np.take(in_array_1, [0, 1, 2, 3, 6, 7, 8, 9], axis=2)
+
+ ap1 = AgentInfoProto()
+ ap1.observations.extend([proto_obs_1])
+ ap_list = [ap1]
+ arr = _process_visual_observation(0, (128, 64, 8), ap_list)
+ assert list(arr.shape) == [1, 128, 64, 8]
+ assert np.allclose(arr[0, :, :, :], expected_out_array_1, atol=0.01)
+
+
def test_process_visual_observation_bad_shape():
in_array_1 = np.random.rand(128, 64, 3)
proto_obs_1 = generate_compressed_proto_obs(in_array_1)
diff --git a/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto b/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto
index 99e78a4e1f..7f03b20886 100644
--- a/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto
+++ b/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto
@@ -13,4 +13,7 @@ message UnityRLCapabilitiesProto {
// concatenated PNG files for compressed visual observations with >3 channels.
bool concatenatedPngObservations = 2;
+
+ // compression mapping for stacking compressed observations.
+ bool compressedChannelMapping = 3;
}
diff --git a/protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto b/protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto
index 3b57ba5bbd..6bde365afc 100644
--- a/protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto
+++ b/protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto
@@ -19,4 +19,5 @@ message ObservationProto {
bytes compressed_data = 3;
FloatData float_data = 4;
}
+ repeated int32 compressed_channel_mapping = 5;
}