Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
b1d4c89
add visual stacking
Sep 10, 2020
d8612a0
stack observations on the last dimemsion
Sep 11, 2020
8c5f4ea
add stacking to renderTextureComponent
Sep 11, 2020
eadb6d0
fix shape
Sep 11, 2020
d8a36e4
add comments
Sep 11, 2020
54a5280
add comments
Sep 11, 2020
c7eb180
small fixes
Sep 11, 2020
4f122ed
add compression mapping
Sep 18, 2020
240d280
Merge branch 'master' into develop-vis-stack
Sep 21, 2020
d24ae6d
change mapping to 0-index
Sep 21, 2020
38bfb8b
Add docstring for mapping and change the name in proto
Sep 21, 2020
8b5f6db
add tests
Sep 22, 2020
2743658
update communicator version
Sep 22, 2020
e9f4d14
update communicator version
Sep 23, 2020
3b3e08d
fix test
Sep 23, 2020
bb011fc
revert unneeded changes
Sep 23, 2020
f023297
Merge branch 'master' into develop-vis-stack
Sep 24, 2020
68597ab
add to change log
Sep 24, 2020
ab4fbeb
name changes and bug fix
Sep 26, 2020
7a65314
change ICompressibleSensor to ISparseChannelSensor
Sep 26, 2020
071b7a4
fix typo
Sep 26, 2020
b5bb5dd
Update com.unity.ml-agents/CHANGELOG.md
Sep 29, 2020
75dc21b
refactor
Sep 29, 2020
70cf227
add checks for mapping
Sep 29, 2020
1202030
comment
Sep 29, 2020
bd19165
fix test
Sep 29, 2020
d240077
add c# tests
Sep 29, 2020
a515dc9
Update com.unity.ml-agents/CHANGELOG.md
Sep 30, 2020
ac2a8ca
Update com.unity.ml-agents/CHANGELOG.md
Oct 1, 2020
64787b0
add tests
Oct 2, 2020
92b34b0
fix Write for uncompressed obs to stack in the last dimension
Oct 2, 2020
1af0c84
add tests
Oct 3, 2020
62ee54d
add files for tests
Oct 3, 2020
b38154c
fix bug
Oct 6, 2020
1a77ceb
fix typo
Oct 6, 2020
89ca381
add more test
Oct 6, 2020
8ef815e
minor comment formatting
Oct 6, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,19 @@ and this project adheres to
- Added the Random Network Distillation (RND) intrinsic reward signal to the Pytorch
trainers. To use RND, add a `rnd` section to the `reward_signals` section of your
yaml configuration file. [More information here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-Configuration-File.md#rnd-intrinsic-reward)

### Minor Changes
#### com.unity.ml-agents (C#)
- Stacking for compressed observations is now supported. An addtional setting
option `Observation Stacks` is added in editor to sensor components that support
compressed observations. A new class `ISparseChannelSensor` with an
additional method `GetCompressedChannelMapping()`is added to generate a mapping
of the channels in compressed data to the actual channel after decompression,
for the python side to decompress correctly. (#4476)
#### ml-agents / ml-agents-envs / gym-unity (Python)

- The Communication API was changed to 1.2.0 to indicate support for stacked
compressed observation. A new entry `compressed_channel_mapping` is added to the
proto to handle decompression correctly. Newer versions of the package that wish to
make use of this will also need a compatible version of the Python trainers. (#4476)
### Bug Fixes
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
Expand Down
1 change: 1 addition & 0 deletions com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public override void OnInspectorGUI()
EditorGUILayout.PropertyField(so.FindProperty("m_Width"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_Height"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_Grayscale"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true);
}
EditorGUI.EndDisabledGroup();
EditorGUILayout.PropertyField(so.FindProperty("m_Compression"), true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public override void OnInspectorGUI()
EditorGUILayout.PropertyField(so.FindProperty("m_RenderTexture"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_SensorName"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_Grayscale"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true);
}
EditorGUI.EndDisabledGroup();

Expand Down
6 changes: 5 additions & 1 deletion com.unity.ml-agents/Runtime/Academy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,13 @@ public class Academy : IDisposable
/// <term>1.1.0</term>
/// <description>Support concatenated PNGs for compressed observations.</description>
/// </item>
/// <item>
/// <term>1.2.0</term>
/// <description>Support compression mapping for stacked compressed observations.</description>
/// </item>
/// </list>
/// </remarks>
const string k_ApiVersion = "1.1.0";
const string k_ApiVersion = "1.2.0";

/// <summary>
/// Unity package version of com.unity.ml-agents.
Expand Down
62 changes: 57 additions & 5 deletions com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ public static List<float[]> ToAgentActionList(this UnityRLInputProto.Types.ListA
/// <summary>
/// Static flag to make sure that we only fire the warning once.
/// </summary>
private static bool s_HaveWarnedAboutTrainerCapabilities = false;
private static bool s_HaveWarnedTrainerCapabilitiesMultiPng = false;
private static bool s_HaveWarnedTrainerCapabilitiesMapping = false;

/// <summary>
/// Generate an ObservationProto for the sensor using the provided ObservationWriter.
Expand All @@ -243,10 +244,27 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat
var trainerCanHandle = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.ConcatenatedPngObservations;
if (!trainerCanHandle)
{
if (!s_HaveWarnedAboutTrainerCapabilities)
if (!s_HaveWarnedTrainerCapabilitiesMultiPng)
{
Debug.LogWarning($"Attached trainer doesn't support multiple PNGs. Switching to uncompressed observations for sensor {sensor.GetName()}.");
s_HaveWarnedAboutTrainerCapabilities = true;
s_HaveWarnedTrainerCapabilitiesMultiPng = true;
}
compressionType = SensorCompressionType.None;
}
}
// Check capabilities if we need mapping for compressed observations
if (compressionType != SensorCompressionType.None && shape.Length == 3 && shape[2] > 3)
{
var trainerCanHandleMapping = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.CompressedChannelMapping;
var isTrivialMapping = IsTrivialMapping(sensor);
if (!trainerCanHandleMapping && !isTrivialMapping)
{
if (!s_HaveWarnedTrainerCapabilitiesMapping)
{
Debug.LogWarning($"The sensor {sensor.GetName()} is using non-trivial mapping and " +
"the attached trainer doesn't support compression mapping. " +
"Switching to uncompressed observations.");
s_HaveWarnedTrainerCapabilitiesMapping = true;
}
compressionType = SensorCompressionType.None;
}
Expand Down Expand Up @@ -283,12 +301,16 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat
"return SensorCompressionType.None from GetCompressionType()."
);
}

observationProto = new ObservationProto
{
CompressedData = ByteString.CopyFrom(compressedObs),
CompressionType = (CompressionTypeProto)sensor.GetCompressionType(),
};
var compressibleSensor = sensor as ISparseChannelSensor;
if (compressibleSensor != null)
{
observationProto.CompressedChannelMapping.AddRange(compressibleSensor.GetCompressedChannelMapping());
}
}
observationProto.Shape.AddRange(shape);
return observationProto;
Expand All @@ -300,7 +322,8 @@ public static UnityRLCapabilities ToRLCapabilities(this UnityRLCapabilitiesProto
return new UnityRLCapabilities
{
BaseRLCapabilities = proto.BaseRLCapabilities,
ConcatenatedPngObservations = proto.ConcatenatedPngObservations
ConcatenatedPngObservations = proto.ConcatenatedPngObservations,
CompressedChannelMapping = proto.CompressedChannelMapping,
};
}

Expand All @@ -310,7 +333,36 @@ public static UnityRLCapabilitiesProto ToProto(this UnityRLCapabilities rlCaps)
{
BaseRLCapabilities = rlCaps.BaseRLCapabilities,
ConcatenatedPngObservations = rlCaps.ConcatenatedPngObservations,
CompressedChannelMapping = rlCaps.CompressedChannelMapping,
};
}

internal static bool IsTrivialMapping(ISensor sensor)
{
var compressibleSensor = sensor as ISparseChannelSensor;
if (compressibleSensor is null)
{
return true;
}
var mapping = compressibleSensor.GetCompressedChannelMapping();
if (mapping == null)
{
return true;
}
// check if mapping equals zero mapping
if (mapping.Length == 3 && mapping.All(m => m == 0))
{
return true;
}
// check if mapping equals identity mapping
for (var i = 0; i < mapping.Length; i++)
{
if (mapping[i] != i)
{
return false;
}
}
return true;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@ internal class UnityRLCapabilities
{
public bool BaseRLCapabilities;
public bool ConcatenatedPngObservations;
public bool CompressedChannelMapping;

/// <summary>
/// A class holding the capabilities flags for Reinforcement Learning across C# and the Trainer codebase. This
/// struct will be used to inform users if and when they are using C# / Trainer features that are mismatched.
/// </summary>
public UnityRLCapabilities(bool baseRlCapabilities = true, bool concatenatedPngObservations = true)
public UnityRLCapabilities(bool baseRlCapabilities = true, bool concatenatedPngObservations = true, bool compressedChannelMapping = true)
{
BaseRLCapabilities = baseRlCapabilities;
ConcatenatedPngObservations = concatenatedPngObservations;
CompressedChannelMapping = compressedChannelMapping;
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@ static CapabilitiesReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp",
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMiWwoYVW5pdHlSTENh",
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMifQoYVW5pdHlSTENh",
"cGFiaWxpdGllc1Byb3RvEhoKEmJhc2VSTENhcGFiaWxpdGllcxgBIAEoCBIj",
"Chtjb25jYXRlbmF0ZWRQbmdPYnNlcnZhdGlvbnMYAiABKAhCJaoCIlVuaXR5",
"Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
"Chtjb25jYXRlbmF0ZWRQbmdPYnNlcnZhdGlvbnMYAiABKAgSIAoYY29tcHJl",
"c3NlZENoYW5uZWxNYXBwaW5nGAMgASgIQiWqAiJVbml0eS5NTEFnZW50cy5D",
"b21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping" }, null, null, null)
}));
}
#endregion
Expand Down Expand Up @@ -71,6 +72,7 @@ public UnityRLCapabilitiesProto() {
public UnityRLCapabilitiesProto(UnityRLCapabilitiesProto other) : this() {
baseRLCapabilities_ = other.baseRLCapabilities_;
concatenatedPngObservations_ = other.concatenatedPngObservations_;
compressedChannelMapping_ = other.compressedChannelMapping_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

Expand Down Expand Up @@ -107,6 +109,20 @@ public bool ConcatenatedPngObservations {
}
}

/// <summary>Field number for the "compressedChannelMapping" field.</summary>
public const int CompressedChannelMappingFieldNumber = 3;
private bool compressedChannelMapping_;
/// <summary>
/// compression mapping for stacking compressed observations.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool CompressedChannelMapping {
get { return compressedChannelMapping_; }
set {
compressedChannelMapping_ = value;
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLCapabilitiesProto);
Expand All @@ -122,6 +138,7 @@ public bool Equals(UnityRLCapabilitiesProto other) {
}
if (BaseRLCapabilities != other.BaseRLCapabilities) return false;
if (ConcatenatedPngObservations != other.ConcatenatedPngObservations) return false;
if (CompressedChannelMapping != other.CompressedChannelMapping) return false;
return Equals(_unknownFields, other._unknownFields);
}

Expand All @@ -130,6 +147,7 @@ public override int GetHashCode() {
int hash = 1;
if (BaseRLCapabilities != false) hash ^= BaseRLCapabilities.GetHashCode();
if (ConcatenatedPngObservations != false) hash ^= ConcatenatedPngObservations.GetHashCode();
if (CompressedChannelMapping != false) hash ^= CompressedChannelMapping.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
Expand All @@ -151,6 +169,10 @@ public void WriteTo(pb::CodedOutputStream output) {
output.WriteRawTag(16);
output.WriteBool(ConcatenatedPngObservations);
}
if (CompressedChannelMapping != false) {
output.WriteRawTag(24);
output.WriteBool(CompressedChannelMapping);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
Expand All @@ -165,6 +187,9 @@ public int CalculateSize() {
if (ConcatenatedPngObservations != false) {
size += 1 + 1;
}
if (CompressedChannelMapping != false) {
size += 1 + 1;
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
Expand All @@ -182,6 +207,9 @@ public void MergeFrom(UnityRLCapabilitiesProto other) {
if (other.ConcatenatedPngObservations != false) {
ConcatenatedPngObservations = other.ConcatenatedPngObservations;
}
if (other.CompressedChannelMapping != false) {
CompressedChannelMapping = other.CompressedChannelMapping;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

Expand All @@ -201,6 +229,10 @@ public void MergeFrom(pb::CodedInputStream input) {
ConcatenatedPngObservations = input.ReadBool();
break;
}
case 24: {
CompressedChannelMapping = input.ReadBool();
break;
}
}
}
}
Expand Down
Loading