diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index 8a34b187ea..a318fabce3 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -13,11 +13,19 @@ and this project adheres to - Added the Random Network Distillation (RND) intrinsic reward signal to the Pytorch trainers. To use RND, add a `rnd` section to the `reward_signals` section of your yaml configuration file. [More information here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-Configuration-File.md#rnd-intrinsic-reward) - ### Minor Changes #### com.unity.ml-agents (C#) + - Stacking for compressed observations is now supported. An addtional setting + option `Observation Stacks` is added in editor to sensor components that support + compressed observations. A new class `ISparseChannelSensor` with an + additional method `GetCompressedChannelMapping()`is added to generate a mapping + of the channels in compressed data to the actual channel after decompression, + for the python side to decompress correctly. (#4476) #### ml-agents / ml-agents-envs / gym-unity (Python) - + - The Communication API was changed to 1.2.0 to indicate support for stacked + compressed observation. A new entry `compressed_channel_mapping` is added to the + proto to handle decompression correctly. Newer versions of the package that wish to + make use of this will also need a compatible version of the Python trainers. (#4476) ### Bug Fixes #### com.unity.ml-agents (C#) #### ml-agents / ml-agents-envs / gym-unity (Python) diff --git a/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs b/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs index 99904e4363..baf9873324 100644 --- a/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs +++ b/com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs @@ -25,6 +25,7 @@ public override void OnInspectorGUI() EditorGUILayout.PropertyField(so.FindProperty("m_Width"), true); EditorGUILayout.PropertyField(so.FindProperty("m_Height"), true); EditorGUILayout.PropertyField(so.FindProperty("m_Grayscale"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true); } EditorGUI.EndDisabledGroup(); EditorGUILayout.PropertyField(so.FindProperty("m_Compression"), true); diff --git a/com.unity.ml-agents/Editor/RenderTextureSensorComponentEditor.cs b/com.unity.ml-agents/Editor/RenderTextureSensorComponentEditor.cs index 55c621a9c4..10a7769d8d 100644 --- a/com.unity.ml-agents/Editor/RenderTextureSensorComponentEditor.cs +++ b/com.unity.ml-agents/Editor/RenderTextureSensorComponentEditor.cs @@ -20,6 +20,7 @@ public override void OnInspectorGUI() EditorGUILayout.PropertyField(so.FindProperty("m_RenderTexture"), true); EditorGUILayout.PropertyField(so.FindProperty("m_SensorName"), true); EditorGUILayout.PropertyField(so.FindProperty("m_Grayscale"), true); + EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true); } EditorGUI.EndDisabledGroup(); diff --git a/com.unity.ml-agents/Runtime/Academy.cs b/com.unity.ml-agents/Runtime/Academy.cs index fac6934c22..383526b2e8 100644 --- a/com.unity.ml-agents/Runtime/Academy.cs +++ b/com.unity.ml-agents/Runtime/Academy.cs @@ -74,9 +74,13 @@ public class Academy : IDisposable /// 1.1.0 /// Support concatenated PNGs for compressed observations. /// + /// + /// 1.2.0 + /// Support compression mapping for stacked compressed observations. + /// /// /// - const string k_ApiVersion = "1.1.0"; + const string k_ApiVersion = "1.2.0"; /// /// Unity package version of com.unity.ml-agents. diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs index 131685f8d9..5f79b5f068 100644 --- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs +++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs @@ -222,7 +222,8 @@ public static List ToAgentActionList(this UnityRLInputProto.Types.ListA /// /// Static flag to make sure that we only fire the warning once. /// - private static bool s_HaveWarnedAboutTrainerCapabilities = false; + private static bool s_HaveWarnedTrainerCapabilitiesMultiPng = false; + private static bool s_HaveWarnedTrainerCapabilitiesMapping = false; /// /// Generate an ObservationProto for the sensor using the provided ObservationWriter. @@ -243,10 +244,27 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat var trainerCanHandle = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.ConcatenatedPngObservations; if (!trainerCanHandle) { - if (!s_HaveWarnedAboutTrainerCapabilities) + if (!s_HaveWarnedTrainerCapabilitiesMultiPng) { Debug.LogWarning($"Attached trainer doesn't support multiple PNGs. Switching to uncompressed observations for sensor {sensor.GetName()}."); - s_HaveWarnedAboutTrainerCapabilities = true; + s_HaveWarnedTrainerCapabilitiesMultiPng = true; + } + compressionType = SensorCompressionType.None; + } + } + // Check capabilities if we need mapping for compressed observations + if (compressionType != SensorCompressionType.None && shape.Length == 3 && shape[2] > 3) + { + var trainerCanHandleMapping = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.CompressedChannelMapping; + var isTrivialMapping = IsTrivialMapping(sensor); + if (!trainerCanHandleMapping && !isTrivialMapping) + { + if (!s_HaveWarnedTrainerCapabilitiesMapping) + { + Debug.LogWarning($"The sensor {sensor.GetName()} is using non-trivial mapping and " + + "the attached trainer doesn't support compression mapping. " + + "Switching to uncompressed observations."); + s_HaveWarnedTrainerCapabilitiesMapping = true; } compressionType = SensorCompressionType.None; } @@ -283,12 +301,16 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat "return SensorCompressionType.None from GetCompressionType()." ); } - observationProto = new ObservationProto { CompressedData = ByteString.CopyFrom(compressedObs), CompressionType = (CompressionTypeProto)sensor.GetCompressionType(), }; + var compressibleSensor = sensor as ISparseChannelSensor; + if (compressibleSensor != null) + { + observationProto.CompressedChannelMapping.AddRange(compressibleSensor.GetCompressedChannelMapping()); + } } observationProto.Shape.AddRange(shape); return observationProto; @@ -300,7 +322,8 @@ public static UnityRLCapabilities ToRLCapabilities(this UnityRLCapabilitiesProto return new UnityRLCapabilities { BaseRLCapabilities = proto.BaseRLCapabilities, - ConcatenatedPngObservations = proto.ConcatenatedPngObservations + ConcatenatedPngObservations = proto.ConcatenatedPngObservations, + CompressedChannelMapping = proto.CompressedChannelMapping, }; } @@ -310,7 +333,36 @@ public static UnityRLCapabilitiesProto ToProto(this UnityRLCapabilities rlCaps) { BaseRLCapabilities = rlCaps.BaseRLCapabilities, ConcatenatedPngObservations = rlCaps.ConcatenatedPngObservations, + CompressedChannelMapping = rlCaps.CompressedChannelMapping, }; } + + internal static bool IsTrivialMapping(ISensor sensor) + { + var compressibleSensor = sensor as ISparseChannelSensor; + if (compressibleSensor is null) + { + return true; + } + var mapping = compressibleSensor.GetCompressedChannelMapping(); + if (mapping == null) + { + return true; + } + // check if mapping equals zero mapping + if (mapping.Length == 3 && mapping.All(m => m == 0)) + { + return true; + } + // check if mapping equals identity mapping + for (var i = 0; i < mapping.Length; i++) + { + if (mapping[i] != i) + { + return false; + } + } + return true; + } } } diff --git a/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs b/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs index 7e188af421..30a72c19e3 100644 --- a/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs +++ b/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs @@ -6,15 +6,17 @@ internal class UnityRLCapabilities { public bool BaseRLCapabilities; public bool ConcatenatedPngObservations; + public bool CompressedChannelMapping; /// /// A class holding the capabilities flags for Reinforcement Learning across C# and the Trainer codebase. This /// struct will be used to inform users if and when they are using C# / Trainer features that are mismatched. /// - public UnityRLCapabilities(bool baseRlCapabilities = true, bool concatenatedPngObservations = true) + public UnityRLCapabilities(bool baseRlCapabilities = true, bool concatenatedPngObservations = true, bool compressedChannelMapping = true) { BaseRLCapabilities = baseRlCapabilities; ConcatenatedPngObservations = concatenatedPngObservations; + CompressedChannelMapping = compressedChannelMapping; } /// diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs index c30d7ec6f9..57050424e7 100644 --- a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs @@ -25,14 +25,15 @@ static CapabilitiesReflection() { byte[] descriptorData = global::System.Convert.FromBase64String( string.Concat( "CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp", - "dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMiWwoYVW5pdHlSTENh", + "dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMifQoYVW5pdHlSTENh", "cGFiaWxpdGllc1Byb3RvEhoKEmJhc2VSTENhcGFiaWxpdGllcxgBIAEoCBIj", - "Chtjb25jYXRlbmF0ZWRQbmdPYnNlcnZhdGlvbnMYAiABKAhCJaoCIlVuaXR5", - "Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw==")); + "Chtjb25jYXRlbmF0ZWRQbmdPYnNlcnZhdGlvbnMYAiABKAgSIAoYY29tcHJl", + "c3NlZENoYW5uZWxNYXBwaW5nGAMgASgIQiWqAiJVbml0eS5NTEFnZW50cy5D", + "b21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM=")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { }, new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations" }, null, null, null) + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping" }, null, null, null) })); } #endregion @@ -71,6 +72,7 @@ public UnityRLCapabilitiesProto() { public UnityRLCapabilitiesProto(UnityRLCapabilitiesProto other) : this() { baseRLCapabilities_ = other.baseRLCapabilities_; concatenatedPngObservations_ = other.concatenatedPngObservations_; + compressedChannelMapping_ = other.compressedChannelMapping_; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } @@ -107,6 +109,20 @@ public bool ConcatenatedPngObservations { } } + /// Field number for the "compressedChannelMapping" field. + public const int CompressedChannelMappingFieldNumber = 3; + private bool compressedChannelMapping_; + /// + /// compression mapping for stacking compressed observations. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool CompressedChannelMapping { + get { return compressedChannelMapping_; } + set { + compressedChannelMapping_ = value; + } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public override bool Equals(object other) { return Equals(other as UnityRLCapabilitiesProto); @@ -122,6 +138,7 @@ public bool Equals(UnityRLCapabilitiesProto other) { } if (BaseRLCapabilities != other.BaseRLCapabilities) return false; if (ConcatenatedPngObservations != other.ConcatenatedPngObservations) return false; + if (CompressedChannelMapping != other.CompressedChannelMapping) return false; return Equals(_unknownFields, other._unknownFields); } @@ -130,6 +147,7 @@ public override int GetHashCode() { int hash = 1; if (BaseRLCapabilities != false) hash ^= BaseRLCapabilities.GetHashCode(); if (ConcatenatedPngObservations != false) hash ^= ConcatenatedPngObservations.GetHashCode(); + if (CompressedChannelMapping != false) hash ^= CompressedChannelMapping.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -151,6 +169,10 @@ public void WriteTo(pb::CodedOutputStream output) { output.WriteRawTag(16); output.WriteBool(ConcatenatedPngObservations); } + if (CompressedChannelMapping != false) { + output.WriteRawTag(24); + output.WriteBool(CompressedChannelMapping); + } if (_unknownFields != null) { _unknownFields.WriteTo(output); } @@ -165,6 +187,9 @@ public int CalculateSize() { if (ConcatenatedPngObservations != false) { size += 1 + 1; } + if (CompressedChannelMapping != false) { + size += 1 + 1; + } if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -182,6 +207,9 @@ public void MergeFrom(UnityRLCapabilitiesProto other) { if (other.ConcatenatedPngObservations != false) { ConcatenatedPngObservations = other.ConcatenatedPngObservations; } + if (other.CompressedChannelMapping != false) { + CompressedChannelMapping = other.CompressedChannelMapping; + } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } @@ -201,6 +229,10 @@ public void MergeFrom(pb::CodedInputStream input) { ConcatenatedPngObservations = input.ReadBool(); break; } + case 24: { + CompressedChannelMapping = input.ReadBool(); + break; + } } } } diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs index 8c38ef31a8..fee69f568c 100644 --- a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs @@ -25,19 +25,20 @@ static ObservationReflection() { byte[] descriptorData = global::System.Convert.FromBase64String( string.Concat( "CjRtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0", - "aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyL5AQoQT2JzZXJ2YXRp", + "aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyKdAgoQT2JzZXJ2YXRp", "b25Qcm90bxINCgVzaGFwZRgBIAMoBRJEChBjb21wcmVzc2lvbl90eXBlGAIg", "ASgOMiouY29tbXVuaWNhdG9yX29iamVjdHMuQ29tcHJlc3Npb25UeXBlUHJv", "dG8SGQoPY29tcHJlc3NlZF9kYXRhGAMgASgMSAASRgoKZmxvYXRfZGF0YRgE", "IAEoCzIwLmNvbW11bmljYXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8u", - "RmxvYXREYXRhSAAaGQoJRmxvYXREYXRhEgwKBGRhdGEYASADKAJCEgoQb2Jz", - "ZXJ2YXRpb25fZGF0YSopChRDb21wcmVzc2lvblR5cGVQcm90bxIICgROT05F", - "EAASBwoDUE5HEAFCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11bmljYXRvck9i", - "amVjdHNiBnByb3RvMw==")); + "RmxvYXREYXRhSAASIgoaY29tcHJlc3NlZF9jaGFubmVsX21hcHBpbmcYBSAD", + "KAUaGQoJRmxvYXREYXRhEgwKBGRhdGEYASADKAJCEgoQb2JzZXJ2YXRpb25f", + "ZGF0YSopChRDb21wcmVzc2lvblR5cGVQcm90bxIICgROT05FEAASBwoDUE5H", + "EAFCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnBy", + "b3RvMw==")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { }, new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto), }, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)}) + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData", "CompressedChannelMapping" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)}) })); } #endregion @@ -79,6 +80,7 @@ public ObservationProto() { public ObservationProto(ObservationProto other) : this() { shape_ = other.shape_.Clone(); compressionType_ = other.compressionType_; + compressedChannelMapping_ = other.compressedChannelMapping_.Clone(); switch (other.ObservationDataCase) { case ObservationDataOneofCase.CompressedData: CompressedData = other.CompressedData; @@ -139,6 +141,16 @@ public ObservationProto Clone() { } } + /// Field number for the "compressed_channel_mapping" field. + public const int CompressedChannelMappingFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_compressedChannelMapping_codec + = pb::FieldCodec.ForInt32(42); + private readonly pbc::RepeatedField compressedChannelMapping_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField CompressedChannelMapping { + get { return compressedChannelMapping_; } + } + private object observationData_; /// Enum of possible cases for the "observation_data" oneof. public enum ObservationDataOneofCase { @@ -175,6 +187,7 @@ public bool Equals(ObservationProto other) { if (CompressionType != other.CompressionType) return false; if (CompressedData != other.CompressedData) return false; if (!object.Equals(FloatData, other.FloatData)) return false; + if(!compressedChannelMapping_.Equals(other.compressedChannelMapping_)) return false; if (ObservationDataCase != other.ObservationDataCase) return false; return Equals(_unknownFields, other._unknownFields); } @@ -186,6 +199,7 @@ public override int GetHashCode() { if (CompressionType != 0) hash ^= CompressionType.GetHashCode(); if (observationDataCase_ == ObservationDataOneofCase.CompressedData) hash ^= CompressedData.GetHashCode(); if (observationDataCase_ == ObservationDataOneofCase.FloatData) hash ^= FloatData.GetHashCode(); + hash ^= compressedChannelMapping_.GetHashCode(); hash ^= (int) observationDataCase_; if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); @@ -213,6 +227,7 @@ public void WriteTo(pb::CodedOutputStream output) { output.WriteRawTag(34); output.WriteMessage(FloatData); } + compressedChannelMapping_.WriteTo(output, _repeated_compressedChannelMapping_codec); if (_unknownFields != null) { _unknownFields.WriteTo(output); } @@ -231,6 +246,7 @@ public int CalculateSize() { if (observationDataCase_ == ObservationDataOneofCase.FloatData) { size += 1 + pb::CodedOutputStream.ComputeMessageSize(FloatData); } + size += compressedChannelMapping_.CalculateSize(_repeated_compressedChannelMapping_codec); if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -246,6 +262,7 @@ public void MergeFrom(ObservationProto other) { if (other.CompressionType != 0) { CompressionType = other.CompressionType; } + compressedChannelMapping_.Add(other.compressedChannelMapping_); switch (other.ObservationDataCase) { case ObservationDataOneofCase.CompressedData: CompressedData = other.CompressedData; @@ -291,6 +308,11 @@ public void MergeFrom(pb::CodedInputStream input) { FloatData = subBuilder; break; } + case 42: + case 40: { + compressedChannelMapping_.AddEntriesFrom(input, _repeated_compressedChannelMapping_codec); + break; + } } } } diff --git a/com.unity.ml-agents/Runtime/SensorHelper.cs b/com.unity.ml-agents/Runtime/SensorHelper.cs index 471a768b27..5870d0789a 100644 --- a/com.unity.ml-agents/Runtime/SensorHelper.cs +++ b/com.unity.ml-agents/Runtime/SensorHelper.cs @@ -1,4 +1,5 @@ using UnityEngine; +using Unity.Barracuda; namespace Unity.MLAgents.Sensors { @@ -62,5 +63,58 @@ public static bool CompareObservation(ISensor sensor, float[] expected, out stri errorMessage = null; return true; } + + public static bool CompareObservation(ISensor sensor, float[,,] expected, out string errorMessage) + { + var tensorShape = new TensorShape(0, expected.GetLength(0), expected.GetLength(1), expected.GetLength(2)); + var numExpected = tensorShape.height * tensorShape.width * tensorShape.channels; + const float fill = -1337f; + var output = new float[numExpected]; + for (var i = 0; i < numExpected; i++) + { + output[i] = fill; + } + + if (numExpected > 0) + { + if (fill != output[0]) + { + errorMessage = "Error setting output buffer."; + return false; + } + } + + ObservationWriter writer = new ObservationWriter(); + writer.SetTarget(output, sensor.GetObservationShape(), 0); + + // Make sure ObservationWriter didn't touch anything + if (numExpected > 0) + { + if (fill != output[0]) + { + errorMessage = "ObservationWriter.SetTarget modified a buffer it shouldn't have."; + return false; + } + } + + sensor.Write(writer); + for (var h = 0; h < tensorShape.height; h++) + { + for (var w = 0; w < tensorShape.width; w++) + { + for (var c = 0; c < tensorShape.channels; c++) + { + if (expected[h, w, c] != output[tensorShape.Index(0, h, w, c)]) + { + errorMessage = $"Expected and actual differed in position [{h}, {w}, {c}]. " + + "Expected: {expected[h, w, c]} Actual: {output[tensorShape.Index(0, h, w, c)]} "; + return false; + } + } + } + } + errorMessage = null; + return true; + } } } diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs index fcfa350797..ebb9d9dc73 100644 --- a/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs +++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs @@ -75,6 +75,11 @@ public bool Grayscale set { m_Grayscale = value; } } + [HideInInspector, SerializeField] + [Range(1, 50)] + [Tooltip("Number of camera frames that will be stacked before being fed to the neural network.")] + int m_ObservationStacks = 1; + [HideInInspector, SerializeField, FormerlySerializedAs("compression")] SensorCompressionType m_Compression = SensorCompressionType.PNG; @@ -87,6 +92,16 @@ public SensorCompressionType CompressionType set { m_Compression = value; UpdateSensor(); } } + /// + /// Whether to stack previous observations. Using 1 means no previous observations. + /// Note that changing this after the sensor is created has no effect. + /// + public int ObservationStacks + { + get { return m_ObservationStacks; } + set { m_ObservationStacks = value; } + } + /// /// Creates the /// @@ -94,6 +109,11 @@ public SensorCompressionType CompressionType public override ISensor CreateSensor() { m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression); + + if (ObservationStacks != 1) + { + return new StackingSensor(m_Sensor, ObservationStacks); + } return m_Sensor; } @@ -103,7 +123,13 @@ public override ISensor CreateSensor() /// The observation shape of the associated object. public override int[] GetObservationShape() { - return CameraSensor.GenerateShape(m_Width, m_Height, Grayscale); + var stacks = ObservationStacks > 1 ? ObservationStacks : 1; + var cameraSensorshape = CameraSensor.GenerateShape(m_Width, m_Height, Grayscale); + if (stacks > 1) + { + cameraSensorshape[cameraSensorshape.Length - 1] *= stacks; + } + return cameraSensorshape; } /// diff --git a/com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs b/com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs new file mode 100644 index 0000000000..06b517eb42 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs @@ -0,0 +1,20 @@ +namespace Unity.MLAgents.Sensors +{ + /// + /// Sensor interface for sparse channel sensor which requires a compressed channel mapping. + /// + public interface ISparseChannelSensor : ISensor + { + /// + /// Returns the mapping of the channels in compressed data to the actual channel after decompression. + /// The mapping is a list of interger index with the same length as + /// the number of output observation layers (channels), including padding if there's any. + /// Each index indicates the actual channel the layer will go into. + /// Layers with the same index will be averaged, and layers with negative index will be dropped. + /// For example, mapping for CameraSensor using grayscale and stacking of two: [0, 0, 0, 1, 1, 1] + /// Mapping for GridSensor of 4 channels and stacking of two: [0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1] + /// + /// Mapping of the compressed data + int[] GetCompressedChannelMapping(); + } +} diff --git a/com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs.meta b/com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs.meta new file mode 100644 index 0000000000..bebec4f1cf --- /dev/null +++ b/com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 63bb76c1e31c24fa5b4a384ea0edbfb0 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs index f308507e0f..d22a3ffcaf 100644 --- a/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs +++ b/com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs @@ -54,6 +54,11 @@ public bool Grayscale set { m_Grayscale = value; } } + [HideInInspector, SerializeField] + [Range(1, 50)] + [Tooltip("Number of frames that will be stacked before being fed to the neural network.")] + int m_ObservationStacks = 1; + [HideInInspector, SerializeField, FormerlySerializedAs("compression")] SensorCompressionType m_Compression = SensorCompressionType.PNG; @@ -66,10 +71,24 @@ public SensorCompressionType CompressionType set { m_Compression = value; UpdateSensor(); } } + /// + /// Whether to stack previous observations. Using 1 means no previous observations. + /// Note that changing this after the sensor is created has no effect. + /// + public int ObservationStacks + { + get { return m_ObservationStacks; } + set { m_ObservationStacks = value; } + } + /// public override ISensor CreateSensor() { m_Sensor = new RenderTextureSensor(RenderTexture, Grayscale, SensorName, m_Compression); + if (ObservationStacks != 1) + { + return new StackingSensor(m_Sensor, ObservationStacks); + } return m_Sensor; } @@ -78,8 +97,15 @@ public override int[] GetObservationShape() { var width = RenderTexture != null ? RenderTexture.width : 0; var height = RenderTexture != null ? RenderTexture.height : 0; + var observationShape = new[] { height, width, Grayscale ? 1 : 3 }; + + var stacks = ObservationStacks > 1 ? ObservationStacks : 1; + if (stacks > 1) + { + observationShape[2] *= stacks; + } - return new[] { height, width, Grayscale ? 1 : 3 }; + return observationShape; } /// diff --git a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs index 962ffe1a9e..6064a758dd 100644 --- a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs @@ -1,4 +1,11 @@ using System; +using System.Linq; +using System.Runtime.CompilerServices; +using UnityEngine; +using Unity.Barracuda; + + +[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor.Tests")] namespace Unity.MLAgents.Sensors { @@ -8,10 +15,9 @@ namespace Unity.MLAgents.Sensors /// For example, 4 stacked sets of observations would be output like /// | t = now - 3 | t = now -3 | t = now - 2 | t = now | /// Internally, a circular buffer of arrays is used. The m_CurrentIndex represents the most recent observation. - /// - /// Currently, compressed and multidimensional observations are not supported. + /// Currently, observations are stacked on the last dimension. /// - public class StackingSensor : ISensor + public class StackingSensor : ISparseChannelSensor { /// /// The wrapped sensor. @@ -26,15 +32,22 @@ public class StackingSensor : ISensor string m_Name; int[] m_Shape; + int[] m_WrappedShape; /// /// Buffer of previous observations /// float[][] m_StackedObservations; + byte[][] m_StackedCompressedObservations; + int m_CurrentIndex; ObservationWriter m_LocalWriter = new ObservationWriter(); + byte[] m_EmptyCompressedObservation; + int[] m_CompressionMapping; + TensorShape m_tensorShape; + /// /// Initializes the sensor. /// @@ -48,48 +61,78 @@ public StackingSensor(ISensor wrapped, int numStackedObservations) m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}"; - if (wrapped.GetCompressionType() != SensorCompressionType.None) - { - throw new UnityAgentsException("StackingSensor doesn't support compressed observations.'"); - } - - var shape = wrapped.GetObservationShape(); - if (shape.Length != 1) - { - throw new UnityAgentsException("Only 1-D observations are supported by StackingSensor"); - } - m_Shape = new int[shape.Length]; + m_WrappedShape = wrapped.GetObservationShape(); + m_Shape = new int[m_WrappedShape.Length]; m_UnstackedObservationSize = wrapped.ObservationSize(); - for (int d = 0; d < shape.Length; d++) + for (int d = 0; d < m_WrappedShape.Length; d++) { - m_Shape[d] = shape[d]; + m_Shape[d] = m_WrappedShape[d]; } // TODO support arbitrary stacking dimension - m_Shape[0] *= numStackedObservations; + m_Shape[m_Shape.Length - 1] *= numStackedObservations; + + // Initialize uncompressed buffer anyway in case python trainer does not + // support the compression mapping and has to fall back to uncompressed obs. m_StackedObservations = new float[numStackedObservations][]; for (var i = 0; i < numStackedObservations; i++) { m_StackedObservations[i] = new float[m_UnstackedObservationSize]; } + + if (m_WrappedSensor.GetCompressionType() != SensorCompressionType.None) + { + m_StackedCompressedObservations = new byte[numStackedObservations][]; + m_EmptyCompressedObservation = CreateEmptyPNG(); + for (var i = 0; i < numStackedObservations; i++) + { + m_StackedCompressedObservations[i] = m_EmptyCompressedObservation; + } + m_CompressionMapping = ConstructStackedCompressedChannelMapping(wrapped); + } + + if (m_Shape.Length != 1) + { + m_tensorShape = new TensorShape(0, m_WrappedShape[0], m_WrappedShape[1], m_WrappedShape[2]); + } } /// public int Write(ObservationWriter writer) { // First, call the wrapped sensor's write method. Make sure to use our own writer, not the passed one. - var wrappedShape = m_WrappedSensor.GetObservationShape(); - m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], wrappedShape, 0); + m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], m_WrappedShape, 0); m_WrappedSensor.Write(m_LocalWriter); // Now write the saved observations (oldest first) var numWritten = 0; - for (var i = 0; i < m_NumStackedObservations; i++) + if (m_WrappedShape.Length == 1) { - var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations; - writer.AddRange(m_StackedObservations[obsIndex], numWritten); - numWritten += m_UnstackedObservationSize; + for (var i = 0; i < m_NumStackedObservations; i++) + { + var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations; + writer.AddRange(m_StackedObservations[obsIndex], numWritten); + numWritten += m_UnstackedObservationSize; + } + } + else + { + for (var i = 0; i < m_NumStackedObservations; i++) + { + var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations; + for (var h = 0; h < m_WrappedShape[0]; h++) + { + for (var w = 0; w < m_WrappedShape[1]; w++) + { + for (var c = 0; c < m_WrappedShape[2]; c++) + { + writer[h, w, i * m_WrappedShape[2] + c] = m_StackedObservations[obsIndex][m_tensorShape.Index(0, h, w, c)]; + } + } + } + } + numWritten = m_WrappedShape[0] * m_WrappedShape[1] * m_WrappedShape[2] * m_NumStackedObservations; } return numWritten; @@ -113,6 +156,13 @@ public void Reset() { Array.Clear(m_StackedObservations[i], 0, m_StackedObservations[i].Length); } + if (m_WrappedSensor.GetCompressionType() != SensorCompressionType.None) + { + for (var i = 0; i < m_NumStackedObservations; i++) + { + m_StackedCompressedObservations[i] = m_EmptyCompressedObservation; + } + } } /// @@ -128,17 +178,101 @@ public string GetName() } /// - public virtual byte[] GetCompressedObservation() + public byte[] GetCompressedObservation() + { + var compressed = m_WrappedSensor.GetCompressedObservation(); + m_StackedCompressedObservations[m_CurrentIndex] = compressed; + + int bytesLength = 0; + foreach (byte[] compressedObs in m_StackedCompressedObservations) + { + bytesLength += compressedObs.Length; + } + + byte[] outputBytes = new byte[bytesLength]; + int offset = 0; + for (var i = 0; i < m_NumStackedObservations; i++) + { + var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations; + Buffer.BlockCopy(m_StackedCompressedObservations[obsIndex], + 0, outputBytes, offset, m_StackedCompressedObservations[obsIndex].Length); + offset += m_StackedCompressedObservations[obsIndex].Length; + } + + return outputBytes; + } + + public int[] GetCompressedChannelMapping() { - return null; + return m_CompressionMapping; } /// - public virtual SensorCompressionType GetCompressionType() + public SensorCompressionType GetCompressionType() + { + return m_WrappedSensor.GetCompressionType(); + } + + /// + /// Create Empty PNG for initializing the buffer for stacking. + /// + internal byte[] CreateEmptyPNG() { - return SensorCompressionType.None; + int height = m_WrappedSensor.GetObservationShape()[0]; + int width = m_WrappedSensor.GetObservationShape()[1]; + var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false); + return texture2D.EncodeToPNG(); } - // TODO support stacked compressed observations (byte stream) + + /// + /// Constrct stacked CompressedChannelMapping. + /// + internal int[] ConstructStackedCompressedChannelMapping(ISensor wrappedSenesor) + { + // Get CompressedChannelMapping of the wrapped sensor. If the + // wrapped sensor doesn't have one, use default mapping. + // Default mapping: {0, 0, 0} for grayscale, identity mapping {1, 2, ..., n} otherwise. + int[] wrappedMapping = null; + int wrappedNumChannel = wrappedSenesor.GetObservationShape()[2]; + var sparseChannelSensor = m_WrappedSensor as ISparseChannelSensor; + if (sparseChannelSensor != null) + { + wrappedMapping = sparseChannelSensor.GetCompressedChannelMapping(); + } + if (wrappedMapping == null) + { + if (wrappedNumChannel == 1) + { + wrappedMapping = new int[] { 0, 0, 0 }; + } + else + { + wrappedMapping = Enumerable.Range(0, wrappedNumChannel).ToArray(); + } + } + + // Construct stacked mapping using the mapping of wrapped sensor. + // First pad the wrapped mapping to multiple of 3, then repeat + // and add offset to each copy to form the stacked mapping. + int paddedMapLength = (wrappedMapping.Length + 2) / 3 * 3; + var compressionMapping = new int[paddedMapLength * m_NumStackedObservations]; + for (var i = 0; i < m_NumStackedObservations; i++) + { + var offset = wrappedNumChannel * i; + for (var j = 0; j < paddedMapLength; j++) + { + if (j < wrappedMapping.Length) + { + compressionMapping[j + paddedMapLength * i] = wrappedMapping[j] >= 0 ? wrappedMapping[j] + offset : -1; + } + else + { + compressionMapping[j + paddedMapLength * i] = -1; + } + } + } + return compressionMapping; + } } } diff --git a/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs index 675f473db4..88039d2b3f 100644 --- a/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs @@ -90,6 +90,19 @@ public string GetName() } } + class DummySparseChannelSensor : DummySensor, ISparseChannelSensor + { + public int[] Mapping; + internal DummySparseChannelSensor() + { + } + + public int[] GetCompressedChannelMapping() + { + return Mapping; + } + } + [Test] public void TestGetObservationProtoCapabilities() { @@ -139,5 +152,23 @@ public void TestGetObservationProtoCapabilities() } + + [Test] + public void TestIsTrivialMapping() + { + Assert.AreEqual(GrpcExtensions.IsTrivialMapping(new DummySensor()), true); + + var sparseChannelSensor = new DummySparseChannelSensor(); + sparseChannelSensor.Mapping = null; + Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), true); + sparseChannelSensor.Mapping = new int[] { 0, 0, 0 }; + Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), true); + sparseChannelSensor.Mapping = new int[] { 0, 1, 2, 3, 4 }; + Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), true); + sparseChannelSensor.Mapping = new int[] { 1, 2, 3, 4, -1, -1 }; + Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), false); + sparseChannelSensor.Mapping = new int[] { 0, 0, 0, 1, 1, 1 }; + Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), false); + } } } diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/SensorTestHelper.cs b/com.unity.ml-agents/Tests/Editor/Sensor/SensorTestHelper.cs new file mode 100644 index 0000000000..ea035b9875 --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Sensor/SensorTestHelper.cs @@ -0,0 +1,23 @@ +using NUnit.Framework; +using UnityEngine; +using Unity.MLAgents.Sensors; + +namespace Unity.MLAgents.Tests +{ + public static class SensorTestHelper + { + public static void CompareObservation(ISensor sensor, float[] expected) + { + string errorMessage; + bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage); + Assert.IsTrue(isOK, errorMessage); + } + + public static void CompareObservation(ISensor sensor, float[,,] expected) + { + string errorMessage; + bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage); + Assert.IsTrue(isOK, errorMessage); + } + } +} diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/SensorTestHelper.cs.meta b/com.unity.ml-agents/Tests/Editor/Sensor/SensorTestHelper.cs.meta new file mode 100644 index 0000000000..487ace557e --- /dev/null +++ b/com.unity.ml-agents/Tests/Editor/Sensor/SensorTestHelper.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: e769354f8bd404ca180d7cd7302a5d61 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs index 8626fe3121..1954139691 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs @@ -1,4 +1,7 @@ using NUnit.Framework; +using System; +using System.Linq; +using UnityEngine; using Unity.MLAgents.Sensors; namespace Unity.MLAgents.Tests @@ -15,7 +18,7 @@ public void TestCtor() } [Test] - public void TestStacking() + public void TestVectorStacking() { VectorSensor wrapped = new VectorSensor(2); ISensor sensor = new StackingSensor(wrapped, 3); @@ -44,7 +47,7 @@ public void TestStacking() } [Test] - public void TestStackingReset() + public void TestVectorStackingReset() { VectorSensor wrapped = new VectorSensor(2); ISensor sensor = new StackingSensor(wrapped, 3); @@ -60,5 +63,154 @@ public void TestStackingReset() wrapped.AddObservation(new[] { 5f, 6f }); SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f, 5f, 6f }); } + + class Dummy3DSensor : ISparseChannelSensor + { + public SensorCompressionType CompressionType = SensorCompressionType.PNG; + public int[] Mapping; + public int[] Shape; + public float[,,] CurrentObservation; + + internal Dummy3DSensor() + { + } + + public int[] GetObservationShape() + { + return Shape; + } + + public int Write(ObservationWriter writer) + { + for (var h = 0; h < Shape[0]; h++) + { + for (var w = 0; w < Shape[1]; w++) + { + for (var c = 0; c < Shape[2]; c++) + { + writer[h, w, c] = CurrentObservation[h, w, c]; + } + } + } + return Shape[0] * Shape[1] * Shape[2]; + } + + public byte[] GetCompressedObservation() + { + var writer = new ObservationWriter(); + var flattenedObservation = new float[Shape[0] * Shape[1] * Shape[2]]; + writer.SetTarget(flattenedObservation, Shape, 0); + Write(writer); + byte[] bytes = Array.ConvertAll(flattenedObservation, (z) => (byte)z); + return bytes; + } + + public void Update() { } + + public void Reset() { } + + public SensorCompressionType GetCompressionType() + { + return CompressionType; + } + + public string GetName() + { + return "Dummy"; + } + + public int[] GetCompressedChannelMapping() + { + return Mapping; + } + } + + [Test] + public void TestStackingMapping() + { + // Test grayscale stacked mapping with CameraSensor + var cameraSensor = new CameraSensor(new Camera(), 64, 64, + true, "grayscaleCamera", SensorCompressionType.PNG); + var stackedCameraSensor = new StackingSensor(cameraSensor, 2); + Assert.AreEqual(stackedCameraSensor.GetCompressedChannelMapping(), new[] { 0, 0, 0, 1, 1, 1 }); + + // Test RGB stacked mapping with RenderTextureSensor + var renderTextureSensor = new RenderTextureSensor(new RenderTexture(24, 16, 0), + false, "renderTexture", SensorCompressionType.PNG); + var stackedRenderTextureSensor = new StackingSensor(renderTextureSensor, 2); + Assert.AreEqual(stackedRenderTextureSensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, 4, 5 }); + + // Test mapping with number of layers not being multiple of 3 + var dummySensor = new Dummy3DSensor(); + dummySensor.Shape = new int[] { 2, 2, 4 }; + dummySensor.Mapping = new int[] { 0, 1, 2, 3 }; + var stackedDummySensor = new StackingSensor(dummySensor, 2); + Assert.AreEqual(stackedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 }); + + // Test mapping with dummy layers that should be dropped + var paddedDummySensor = new Dummy3DSensor(); + paddedDummySensor.Shape = new int[] { 2, 2, 4 }; + paddedDummySensor.Mapping = new int[] { 0, 1, 2, 3, -1, -1 }; + var stackedPaddedDummySensor = new StackingSensor(paddedDummySensor, 2); + Assert.AreEqual(stackedPaddedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 }); + } + + [Test] + public void Test3DStacking() + { + var wrapped = new Dummy3DSensor(); + wrapped.Shape = new int[] { 2, 1, 2 }; + var sensor = new StackingSensor(wrapped, 2); + + // Check the stacking is on the last dimension + wrapped.CurrentObservation = new[, ,] { { { 1f, 2f } }, { { 3f, 4f } } }; + SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 0f, 0f, 1f, 2f } }, { { 0f, 0f, 3f, 4f } } }); + + sensor.Update(); + wrapped.CurrentObservation = new[, ,] { { { 5f, 6f } }, { { 7f, 8f } } }; + SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 1f, 2f, 5f, 6f } }, { { 3f, 4f, 7f, 8f } } }); + + sensor.Update(); + wrapped.CurrentObservation = new[, ,] { { { 9f, 10f } }, { { 11f, 12f } } }; + SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 5f, 6f, 9f, 10f } }, { { 7f, 8f, 11f, 12f } } }); + + // Check that if we don't call Update(), the same observations are produced + SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 5f, 6f, 9f, 10f } }, { { 7f, 8f, 11f, 12f } } }); + + // Test reset + sensor.Reset(); + wrapped.CurrentObservation = new[, ,] { { { 13f, 14f } }, { { 15f, 16f } } }; + SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 0f, 0f, 13f, 14f } }, { { 0f, 0f, 15f, 16f } } }); + } + + [Test] + public void TestStackedGetCompressedObservation() + { + var wrapped = new Dummy3DSensor(); + wrapped.Shape = new int[] { 1, 1, 3 }; + var sensor = new StackingSensor(wrapped, 2); + + wrapped.CurrentObservation = new[, ,] { { { 1f, 2f, 3f } } }; + var expected1 = sensor.CreateEmptyPNG(); + expected1 = expected1.Concat(Array.ConvertAll(new[] { 1f, 2f, 3f }, (z) => (byte)z)).ToArray(); + Assert.AreEqual(sensor.GetCompressedObservation(), expected1); + + sensor.Update(); + wrapped.CurrentObservation = new[, ,] { { { 4f, 5f, 6f } } }; + var expected2 = Array.ConvertAll(new[] { 1f, 2f, 3f, 4f, 5f, 6f }, (z) => (byte)z); + Assert.AreEqual(sensor.GetCompressedObservation(), expected2); + + sensor.Update(); + wrapped.CurrentObservation = new[, ,] { { { 7f, 8f, 9f } } }; + var expected3 = Array.ConvertAll(new[] { 4f, 5f, 6f, 7f, 8f, 9f }, (z) => (byte)z); + Assert.AreEqual(sensor.GetCompressedObservation(), expected3); + + // Test reset + sensor.Reset(); + wrapped.CurrentObservation = new[, ,] { { { 10f, 11f, 12f } } }; + var expected4 = sensor.CreateEmptyPNG(); + expected4 = expected4.Concat(Array.ConvertAll(new[] { 10f, 11f, 12f }, (z) => (byte)z)).ToArray(); + Assert.AreEqual(sensor.GetCompressedObservation(), expected4); + } } } diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs index 42cd377a86..5326bca868 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs @@ -4,16 +4,6 @@ namespace Unity.MLAgents.Tests { - public static class SensorTestHelper - { - public static void CompareObservation(ISensor sensor, float[] expected) - { - string errorMessage; - bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage); - Assert.IsTrue(isOK, errorMessage); - } - } - public class VectorSensorTests { [Test] diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py index 87a0cbd641..563e483a8e 100644 --- a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py +++ b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py @@ -19,7 +19,7 @@ name='mlagents_envs/communicator_objects/capabilities.proto', package='communicator_objects', syntax='proto3', - serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"[\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') + serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"}\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x12 \n\x18\x63ompressedChannelMapping\x18\x03 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') ) @@ -46,6 +46,13 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='compressedChannelMapping', full_name='communicator_objects.UnityRLCapabilitiesProto.compressedChannelMapping', index=2, + number=3, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), ], extensions=[ ], @@ -59,7 +66,7 @@ oneofs=[ ], serialized_start=79, - serialized_end=170, + serialized_end=204, ) DESCRIPTOR.message_types_by_name['UnityRLCapabilitiesProto'] = _UNITYRLCAPABILITIESPROTO diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi index 600b817173..d2bc066912 100644 --- a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi +++ b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi @@ -27,17 +27,19 @@ class UnityRLCapabilitiesProto(google___protobuf___message___Message): DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... baseRLCapabilities = ... # type: builtin___bool concatenatedPngObservations = ... # type: builtin___bool + compressedChannelMapping = ... # type: builtin___bool def __init__(self, *, baseRLCapabilities : typing___Optional[builtin___bool] = None, concatenatedPngObservations : typing___Optional[builtin___bool] = None, + compressedChannelMapping : typing___Optional[builtin___bool] = None, ) -> None: ... @classmethod def FromString(cls, s: builtin___bytes) -> UnityRLCapabilitiesProto: ... def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... if sys.version_info >= (3,): - def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"concatenatedPngObservations"]) -> None: ... + def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations"]) -> None: ... else: - def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"concatenatedPngObservations",b"concatenatedPngObservations"]) -> None: ... + def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations"]) -> None: ... diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py b/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py index e31a806676..43c840eaa0 100644 --- a/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py +++ b/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py @@ -20,7 +20,7 @@ name='mlagents_envs/communicator_objects/observation.proto', package='communicator_objects', syntax='proto3', - serialized_pb=_b('\n4mlagents_envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\xf9\x01\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') + serialized_pb=_b('\n4mlagents_envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\x9d\x02\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x12\"\n\x1a\x63ompressed_channel_mapping\x18\x05 \x03(\x05\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') ) _COMPRESSIONTYPEPROTO = _descriptor.EnumDescriptor( @@ -40,8 +40,8 @@ ], containing_type=None, options=None, - serialized_start=330, - serialized_end=371, + serialized_start=366, + serialized_end=407, ) _sym_db.RegisterEnumDescriptor(_COMPRESSIONTYPEPROTO) @@ -77,8 +77,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=283, - serialized_end=308, + serialized_start=319, + serialized_end=344, ) _OBSERVATIONPROTO = _descriptor.Descriptor( @@ -116,6 +116,13 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='compressed_channel_mapping', full_name='communicator_objects.ObservationProto.compressed_channel_mapping', index=4, + number=5, type=5, cpp_type=1, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), ], extensions=[ ], @@ -132,7 +139,7 @@ index=0, containing_type=None, fields=[]), ], serialized_start=79, - serialized_end=328, + serialized_end=364, ) _OBSERVATIONPROTO_FLOATDATA.containing_type = _OBSERVATIONPROTO diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi b/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi index 79681430fb..bf38830c00 100644 --- a/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi +++ b/ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi @@ -72,6 +72,7 @@ class ObservationProto(google___protobuf___message___Message): shape = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___int] compression_type = ... # type: CompressionTypeProto compressed_data = ... # type: builtin___bytes + compressed_channel_mapping = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___int] @property def float_data(self) -> ObservationProto.FloatData: ... @@ -82,6 +83,7 @@ class ObservationProto(google___protobuf___message___Message): compression_type : typing___Optional[CompressionTypeProto] = None, compressed_data : typing___Optional[builtin___bytes] = None, float_data : typing___Optional[ObservationProto.FloatData] = None, + compressed_channel_mapping : typing___Optional[typing___Iterable[builtin___int]] = None, ) -> None: ... @classmethod def FromString(cls, s: builtin___bytes) -> ObservationProto: ... @@ -89,8 +91,8 @@ class ObservationProto(google___protobuf___message___Message): def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... if sys.version_info >= (3,): def HasField(self, field_name: typing_extensions___Literal[u"compressed_data",u"float_data",u"observation_data"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"compressed_data",u"compression_type",u"float_data",u"observation_data",u"shape"]) -> None: ... + def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",u"compressed_data",u"compression_type",u"float_data",u"observation_data",u"shape"]) -> None: ... else: def HasField(self, field_name: typing_extensions___Literal[u"compressed_data",b"compressed_data",u"float_data",b"float_data",u"observation_data",b"observation_data"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"compressed_data",b"compressed_data",u"compression_type",b"compression_type",u"float_data",b"float_data",u"observation_data",b"observation_data",u"shape",b"shape"]) -> None: ... + def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",b"compressed_channel_mapping",u"compressed_data",b"compressed_data",u"compression_type",b"compression_type",u"float_data",b"float_data",u"observation_data",b"observation_data",u"shape",b"shape"]) -> None: ... def WhichOneof(self, oneof_group: typing_extensions___Literal[u"observation_data",b"observation_data"]) -> typing_extensions___Literal["compressed_data","float_data"]: ... diff --git a/ml-agents-envs/mlagents_envs/environment.py b/ml-agents-envs/mlagents_envs/environment.py index 9cd32e364a..2997b72680 100644 --- a/ml-agents-envs/mlagents_envs/environment.py +++ b/ml-agents-envs/mlagents_envs/environment.py @@ -59,7 +59,8 @@ class UnityEnvironment(BaseEnv): # Revision history: # * 1.0.0 - initial version # * 1.1.0 - support concatenated PNGs for compressed observations. - API_VERSION = "1.1.0" + # * 1.2.0 - support compression mapping for stacked compressed observations. + API_VERSION = "1.2.0" # Default port that the editor listens on. If an environment executable # isn't specified, this port will be used. @@ -118,6 +119,7 @@ def _get_capabilities_proto() -> UnityRLCapabilitiesProto: capabilities = UnityRLCapabilitiesProto() capabilities.baseRLCapabilities = True capabilities.concatenatedPngObservations = True + capabilities.compressedChannelMapping = True return capabilities @staticmethod diff --git a/ml-agents-envs/mlagents_envs/rpc_utils.py b/ml-agents-envs/mlagents_envs/rpc_utils.py index 9a903fc743..abe7d9c83d 100644 --- a/ml-agents-envs/mlagents_envs/rpc_utils.py +++ b/ml-agents-envs/mlagents_envs/rpc_utils.py @@ -78,7 +78,9 @@ def original_tell(self) -> int: @timed -def process_pixels(image_bytes: bytes, expected_channels: int) -> np.ndarray: +def process_pixels( + image_bytes: bytes, expected_channels: int, mappings: Optional[List[int]] = None +) -> np.ndarray: """ Converts byte array observation image into numpy array, re-sizes it, and optionally converts it to grey scale @@ -88,23 +90,12 @@ def process_pixels(image_bytes: bytes, expected_channels: int) -> np.ndarray: """ image_fp = OffsetBytesIO(image_bytes) - if expected_channels == 1: - # Convert to grayscale - with hierarchical_timer("image_decompress"): - image = Image.open(image_fp) - # Normally Image loads lazily, load() forces it to do loading in the timer scope. - image.load() - s = np.array(image, dtype=np.float32) / 255.0 - s = np.mean(s, axis=2) - s = np.reshape(s, [s.shape[0], s.shape[1], 1]) - return s - image_arrays = [] - # Read the images back from the bytes (without knowing the sizes). while True: with hierarchical_timer("image_decompress"): image = Image.open(image_fp) + # Normally Image loads lazily, load() forces it to do loading in the timer scope. image.load() image_arrays.append(np.array(image, dtype=np.float32) / 255.0) @@ -116,13 +107,61 @@ def process_pixels(image_bytes: bytes, expected_channels: int) -> np.ndarray: # Didn't find the header, so must be at the end. break - img = np.concatenate(image_arrays, axis=2) - # We can drop additional channels since they may need to be added to include - # numbers of observation channels not divisible by 3. - actual_channels = list(img.shape)[2] - if actual_channels > expected_channels: - img = img[..., 0:expected_channels] + if mappings is not None and len(mappings) > 0: + return _process_images_mapping(image_arrays, mappings) + else: + return _process_images_num_channels(image_arrays, expected_channels) + + +def _process_images_mapping(image_arrays, mappings): + """ + Helper function for processing decompressed images with compressed channel mappings. + """ + image_arrays = np.concatenate(image_arrays, axis=2).transpose((2, 0, 1)) + + if len(mappings) != len(image_arrays): + raise UnityObservationException( + f"Compressed observation and its mapping had different number of channels - " + f"observation had {len(image_arrays)} channels but its mapping had {len(mappings)} channels" + ) + if len({m for m in mappings if m > -1}) != max(mappings) + 1: + raise UnityObservationException( + f"Invalid Compressed Channel Mapping: the mapping {mappings} does not have the correct format." + ) + if max(mappings) >= len(image_arrays): + raise UnityObservationException( + f"Invalid Compressed Channel Mapping: the mapping has index larger than the total " + f"number of channels in observation - mapping index {max(mappings)} is" + f"invalid for input observation with {len(image_arrays)} channels." + ) + + processed_image_arrays: List[np.array] = [[] for _ in range(max(mappings) + 1)] + for mapping_idx, img in zip(mappings, image_arrays): + if mapping_idx > -1: + processed_image_arrays[mapping_idx].append(img) + + for i, img_array in enumerate(processed_image_arrays): + processed_image_arrays[i] = np.mean(img_array, axis=0) + img = np.stack(processed_image_arrays, axis=2) + return img + +def _process_images_num_channels(image_arrays, expected_channels): + """ + Helper function for processing decompressed images with number of expected channels. + This is for old API without mapping provided. Use the first n channel, n=expected_channels. + """ + if expected_channels == 1: + # Convert to grayscale + img = np.mean(image_arrays[0], axis=2) + img = np.reshape(img, [img.shape[0], img.shape[1], 1]) + else: + img = np.concatenate(image_arrays, axis=2) + # We can drop additional channels since they may need to be added to include + # numbers of observation channels not divisible by 3. + actual_channels = list(img.shape)[2] + if actual_channels > expected_channels: + img = img[..., 0:expected_channels] return img @@ -147,7 +186,9 @@ def observation_to_np_array( img = np.reshape(img, obs.shape) return img else: - img = process_pixels(obs.compressed_data, expected_channels) + img = process_pixels( + obs.compressed_data, expected_channels, list(obs.compressed_channel_mapping) + ) # Compare decompressed image size to observation shape and make sure they match if list(obs.shape) != list(img.shape): raise UnityObservationException( diff --git a/ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py b/ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py index 5f1a1825fc..87db474e93 100644 --- a/ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py +++ b/ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py @@ -82,11 +82,39 @@ def generate_compressed_data(in_array: np.ndarray) -> bytes: return bytes_out -def generate_compressed_proto_obs(in_array: np.ndarray) -> ObservationProto: +# test helper function for old C# API (no compressed channel mapping) +def generate_compressed_proto_obs( + in_array: np.ndarray, grayscale: bool = False +) -> ObservationProto: obs_proto = ObservationProto() obs_proto.compressed_data = generate_compressed_data(in_array) obs_proto.compression_type = PNG - obs_proto.shape.extend(in_array.shape) + if grayscale: + # grayscale flag is only used for old API without mapping + expected_shape = [in_array.shape[0], in_array.shape[1], 1] + obs_proto.shape.extend(expected_shape) + else: + obs_proto.shape.extend(in_array.shape) + return obs_proto + + +# test helper function for new C# API (with compressed channel mapping) +def generate_compressed_proto_obs_with_mapping( + in_array: np.ndarray, mapping: List[int] +) -> ObservationProto: + obs_proto = ObservationProto() + obs_proto.compressed_data = generate_compressed_data(in_array) + obs_proto.compression_type = PNG + if mapping is not None: + obs_proto.compressed_channel_mapping.extend(mapping) + expected_shape = [ + in_array.shape[0], + in_array.shape[1], + len({m for m in mapping if m >= 0}), + ] + obs_proto.shape.extend(expected_shape) + else: + obs_proto.shape.extend(in_array.shape) return obs_proto @@ -231,7 +259,11 @@ def test_process_visual_observation(): in_array_1 = np.random.rand(128, 64, 3) proto_obs_1 = generate_compressed_proto_obs(in_array_1) in_array_2 = np.random.rand(128, 64, 3) - proto_obs_2 = generate_uncompressed_proto_obs(in_array_2) + in_array_2_mapping = [0, 1, 2] + proto_obs_2 = generate_compressed_proto_obs_with_mapping( + in_array_2, in_array_2_mapping + ) + ap1 = AgentInfoProto() ap1.observations.extend([proto_obs_1]) ap2 = AgentInfoProto() @@ -243,6 +275,44 @@ def test_process_visual_observation(): assert np.allclose(arr[1, :, :, :], in_array_2, atol=0.01) +def test_process_visual_observation_grayscale(): + in_array_1 = np.random.rand(128, 64, 3) + proto_obs_1 = generate_compressed_proto_obs(in_array_1, grayscale=True) + expected_out_array_1 = np.mean(in_array_1, axis=2, keepdims=True) + in_array_2 = np.random.rand(128, 64, 3) + in_array_2_mapping = [0, 0, 0] + proto_obs_2 = generate_compressed_proto_obs_with_mapping( + in_array_2, in_array_2_mapping + ) + expected_out_array_2 = np.mean(in_array_2, axis=2, keepdims=True) + + ap1 = AgentInfoProto() + ap1.observations.extend([proto_obs_1]) + ap2 = AgentInfoProto() + ap2.observations.extend([proto_obs_2]) + ap_list = [ap1, ap2] + arr = _process_visual_observation(0, (128, 64, 1), ap_list) + assert list(arr.shape) == [2, 128, 64, 1] + assert np.allclose(arr[0, :, :, :], expected_out_array_1, atol=0.01) + assert np.allclose(arr[1, :, :, :], expected_out_array_2, atol=0.01) + + +def test_process_visual_observation_padded_channels(): + in_array_1 = np.random.rand(128, 64, 12) + in_array_1_mapping = [0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1] + proto_obs_1 = generate_compressed_proto_obs_with_mapping( + in_array_1, in_array_1_mapping + ) + expected_out_array_1 = np.take(in_array_1, [0, 1, 2, 3, 6, 7, 8, 9], axis=2) + + ap1 = AgentInfoProto() + ap1.observations.extend([proto_obs_1]) + ap_list = [ap1] + arr = _process_visual_observation(0, (128, 64, 8), ap_list) + assert list(arr.shape) == [1, 128, 64, 8] + assert np.allclose(arr[0, :, :, :], expected_out_array_1, atol=0.01) + + def test_process_visual_observation_bad_shape(): in_array_1 = np.random.rand(128, 64, 3) proto_obs_1 = generate_compressed_proto_obs(in_array_1) diff --git a/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto b/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto index 99e78a4e1f..7f03b20886 100644 --- a/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto +++ b/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto @@ -13,4 +13,7 @@ message UnityRLCapabilitiesProto { // concatenated PNG files for compressed visual observations with >3 channels. bool concatenatedPngObservations = 2; + + // compression mapping for stacking compressed observations. + bool compressedChannelMapping = 3; } diff --git a/protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto b/protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto index 3b57ba5bbd..6bde365afc 100644 --- a/protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto +++ b/protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto @@ -19,4 +19,5 @@ message ObservationProto { bytes compressed_data = 3; FloatData float_data = 4; } + repeated int32 compressed_channel_mapping = 5; }