diff --git a/UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs b/UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs index 2e7d27cbf9..053072d785 100644 --- a/UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs +++ b/UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs @@ -3,6 +3,7 @@ using System.IO.Abstractions.TestingHelpers; using System.Reflection; using MLAgents.CommunicatorObjects; +using MLAgents.Sensor; namespace MLAgents.Tests { @@ -64,7 +65,7 @@ public void TestStoreInitalize() storedVectorActions = new[] { 0f, 1f }, }; - demoStore.Record(agentInfo, new System.Collections.Generic.List()); + demoStore.Record(agentInfo, new System.Collections.Generic.List()); demoStore.Close(); } diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Agent.cs b/UnitySDK/Assets/ML-Agents/Scripts/Agent.cs index 148afbb370..e576af8eec 100644 --- a/UnitySDK/Assets/ML-Agents/Scripts/Agent.cs +++ b/UnitySDK/Assets/ML-Agents/Scripts/Agent.cs @@ -235,13 +235,6 @@ public AgentInfo Info /// public VectorSensor collectObservationsSensor; - /// - /// Internal buffer used for generating float observations. - /// - float[] m_VectorSensorBuffer; - - WriteAdapter m_WriteAdapter = new WriteAdapter(); - /// MonoBehaviour function that is called when the attached GameObject /// becomes enabled or active. void OnEnable() @@ -546,8 +539,6 @@ void SendInfoToBrain() } m_Info.actionMasks = m_ActionMasker.GetMask(); - // var param = m_PolicyFactory.brainParameters; // look, no brain params! - m_Info.reward = m_Reward; m_Info.done = m_Done; m_Info.maxStepReached = m_MaxStepReached; @@ -557,19 +548,7 @@ void SendInfoToBrain() if (m_Recorder != null && m_Recorder.record && Application.isEditor) { - - if (m_VectorSensorBuffer == null) - { - // Create a buffer for writing uncompressed (i.e. float) sensor data to - m_VectorSensorBuffer = new float[sensors.GetSensorFloatObservationSize()]; - } - - // This is a bit of a hack - if we're in inference mode, observations won't be generated - // But we need these to be generated for the recorder. So generate them here. - var observations = new List(); - GenerateSensorData(sensors, m_VectorSensorBuffer, m_WriteAdapter, observations); - - m_Recorder.WriteExperience(m_Info, observations); + m_Recorder.WriteExperience(m_Info, sensors); } } @@ -592,7 +571,7 @@ void UpdateSensors() /// A float array that will be used as buffer when generating the observations. Must /// be at least the same length as the total number of uncompressed floats in the observations /// The WriteAdapter that will be used to write the ISensor data to the observations - /// A list of observations outputs. This argument will be modified by this method.// + /// A list of observations outputs. This argument will be modified by this method.// public static void GenerateSensorData(List sensors, float[] buffer, WriteAdapter adapter, List observations) { int floatsWritten = 0; diff --git a/UnitySDK/Assets/ML-Agents/Scripts/DemonstrationRecorder.cs b/UnitySDK/Assets/ML-Agents/Scripts/DemonstrationRecorder.cs index 33881f7102..3f61b6c815 100644 --- a/UnitySDK/Assets/ML-Agents/Scripts/DemonstrationRecorder.cs +++ b/UnitySDK/Assets/ML-Agents/Scripts/DemonstrationRecorder.cs @@ -70,9 +70,9 @@ public static string SanitizeName(string demoName, int maxNameLength) /// /// Forwards AgentInfo to Demonstration Store. /// - public void WriteExperience(AgentInfo info, List observations) + public void WriteExperience(AgentInfo info, List sensors) { - m_DemoStore.Record(info, observations); + m_DemoStore.Record(info, sensors); } public void Close() diff --git a/UnitySDK/Assets/ML-Agents/Scripts/DemonstrationStore.cs b/UnitySDK/Assets/ML-Agents/Scripts/DemonstrationStore.cs index 0e42ffa764..de1f8b2c04 100644 --- a/UnitySDK/Assets/ML-Agents/Scripts/DemonstrationStore.cs +++ b/UnitySDK/Assets/ML-Agents/Scripts/DemonstrationStore.cs @@ -20,6 +20,7 @@ public class DemonstrationStore DemonstrationMetaData m_MetaData; Stream m_Writer; float m_CumulativeReward; + WriteAdapter m_WriteAdapter = new WriteAdapter(); public DemonstrationStore(IFileSystem fileSystem) { @@ -92,7 +93,7 @@ void WriteBrainParameters(string brainName, BrainParameters brainParameters) /// /// Write AgentInfo experience to file. /// - public void Record(AgentInfo info, List observations) + public void Record(AgentInfo info, List sensors) { // Increment meta-data counters. m_MetaData.numberExperiences++; @@ -102,8 +103,13 @@ public void Record(AgentInfo info, List observations) EndEpisode(); } - // Write AgentInfo to file. - var agentProto = info.ToInfoActionPairProto(observations); + // Generate observations and add AgentInfo to file. + var agentProto = info.ToInfoActionPairProto(); + foreach (var sensor in sensors) + { + agentProto.AgentInfo.Observations.Add(sensor.GetObservationProto(m_WriteAdapter)); + } + agentProto.WriteDelimitedTo(m_Writer); } diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs index d2226fc4fc..0e44f8deb6 100644 --- a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs +++ b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs @@ -16,9 +16,9 @@ public static class GrpcExtensions /// Converts a AgentInfo to a protobuf generated AgentInfoActionPairProto /// /// The protobuf version of the AgentInfoActionPairProto. - public static AgentInfoActionPairProto ToInfoActionPairProto(this AgentInfo ai, List observations) + public static AgentInfoActionPairProto ToInfoActionPairProto(this AgentInfo ai) { - var agentInfoProto = ai.ToAgentInfoProto(observations); + var agentInfoProto = ai.ToAgentInfoProto(); var agentActionProto = new AgentActionProto { @@ -36,7 +36,7 @@ public static AgentInfoActionPairProto ToInfoActionPairProto(this AgentInfo ai, /// Converts a AgentInfo to a protobuf generated AgentInfoProto /// /// The protobuf version of the AgentInfo. - public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai, List observations) + public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai) { var agentInfoProto = new AgentInfoProto { @@ -51,14 +51,6 @@ public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai, List + /// Generate an ObservationProto for the sensor using the provided WriteAdapter. + /// This is equivalent to producing an Observation and calling Observation.ToProto(), + /// but avoid some intermediate memory allocations. + /// + /// + /// + /// + public static ObservationProto GetObservationProto(this ISensor sensor, WriteAdapter writeAdapter) + { + var shape = sensor.GetObservationShape(); + ObservationProto observationProto = null; + if (sensor.GetCompressionType() == SensorCompressionType.None) + { + var numFloats = sensor.ObservationSize(); + var floatDataProto = new ObservationProto.Types.FloatData(); + // Resize the float array + // TODO upgrade protobuf versions so that we can set the Capacity directly - see https://github.com/protocolbuffers/protobuf/pull/6530 + for (var i = 0; i < numFloats; i++) + { + floatDataProto.Data.Add(0.0f); + } + + writeAdapter.SetTarget(floatDataProto.Data, sensor.GetObservationShape(), 0); + sensor.Write(writeAdapter); + + observationProto = new ObservationProto + { + FloatData = floatDataProto, + CompressionType = (CompressionTypeProto)SensorCompressionType.None, + }; + } + else + { + observationProto = new ObservationProto + { + CompressedData = ByteString.CopyFrom(sensor.GetCompressedObservation()), + CompressionType = (CompressionTypeProto)sensor.GetCompressionType(), + }; + } + observationProto.Shape.AddRange(shape); + return observationProto; + } } } diff --git a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs index e54a52b6a9..b2b2eb29f1 100644 --- a/UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs +++ b/UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs @@ -35,8 +35,6 @@ public struct IdCallbackPair List m_BehaviorNames = new List(); bool m_NeedCommunicateThisStep; - float[] m_VectorObservationBuffer = new float[0]; - List m_ObservationBuffer = new List(); WriteAdapter m_WriteAdapter = new WriteAdapter(); Dictionary m_SensorShapeValidators = new Dictionary(); Dictionary> m_ActionCallbacks = new Dictionary>(); @@ -239,18 +237,12 @@ public void DecideBatch() } /// - /// Sends the observations of one Agent. + /// Sends the observations of one Agent. /// /// Batch Key. /// Agent info. public void PutObservations(string brainKey, AgentInfo info, List sensors, Action action) { - int numFloatObservations = sensors.GetSensorFloatObservationSize(); - if (m_VectorObservationBuffer.Length < numFloatObservations) - { - m_VectorObservationBuffer = new float[numFloatObservations]; - } - # if DEBUG if (!m_SensorShapeValidators.ContainsKey(brainKey)) { @@ -259,16 +251,21 @@ public void PutObservations(string brainKey, AgentInfo info, List senso m_SensorShapeValidators[brainKey].ValidateSensors(sensors); #endif - using (TimerStack.Instance.Scoped("GenerateSensorData")) - { - Agent.GenerateSensorData(sensors, m_VectorObservationBuffer, m_WriteAdapter, m_ObservationBuffer); - } using (TimerStack.Instance.Scoped("AgentInfo.ToProto")) { - var agentInfoProto = info.ToAgentInfoProto(m_ObservationBuffer); + var agentInfoProto = info.ToAgentInfoProto(); + + using (TimerStack.Instance.Scoped("GenerateSensorData")) + { + foreach (var sensor in sensors) + { + var obsProto = sensor.GetObservationProto(m_WriteAdapter); + agentInfoProto.Observations.Add(obsProto); + } + } m_CurrentUnityRlOutput.AgentInfos[brainKey].Value.Add(agentInfoProto); } - m_ObservationBuffer.Clear(); + m_NeedCommunicateThisStep = true; if (!m_ActionCallbacks.ContainsKey(brainKey)) { @@ -451,7 +448,7 @@ void UpdateSentBrainParameters(UnityRLInitializationOutputProto output) #region Handling side channels /// - /// Registers a side channel to the communicator. The side channel will exchange + /// Registers a side channel to the communicator. The side channel will exchange /// messages with its Python equivalent. /// /// The side channel to be registered.