diff --git a/Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs b/Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs index 2e2d97dc2a..4fdcd74bfe 100644 --- a/Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs +++ b/Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs @@ -15,12 +15,12 @@ public override void InitializeAgent() SetResetParameters(); } - public override void CollectObservations() + public override void CollectObservations(VectorSensor sensor) { - AddVectorObs(gameObject.transform.rotation.z); - AddVectorObs(gameObject.transform.rotation.x); - AddVectorObs(ball.transform.position - gameObject.transform.position); - AddVectorObs(m_BallRb.velocity); + sensor.AddObservation(gameObject.transform.rotation.z); + sensor.AddObservation(gameObject.transform.rotation.x); + sensor.AddObservation(ball.transform.position - gameObject.transform.position); + sensor.AddObservation(m_BallRb.velocity); } public override void AgentAction(float[] vectorAction) diff --git a/Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs b/Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs index ef4cca677f..bc1777c4fb 100644 --- a/Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs +++ b/Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs @@ -15,11 +15,11 @@ public override void InitializeAgent() SetResetParameters(); } - public override void CollectObservations() + public override void CollectObservations(VectorSensor sensor) { - AddVectorObs(gameObject.transform.rotation.z); - AddVectorObs(gameObject.transform.rotation.x); - AddVectorObs((ball.transform.position - gameObject.transform.position)); + sensor.AddObservation(gameObject.transform.rotation.z); + sensor.AddObservation(gameObject.transform.rotation.x); + sensor.AddObservation((ball.transform.position - gameObject.transform.position)); } public override void AgentAction(float[] vectorAction) diff --git a/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs b/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs index 280731eeba..e4face4e73 100644 --- a/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs +++ b/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs @@ -18,9 +18,9 @@ public override void InitializeAgent() { } - public override void CollectObservations() + public override void CollectObservations(VectorSensor sensor) { - AddVectorObs(m_Position, 20); + sensor.AddOneHotObservation(m_Position, 20); } public override void AgentAction(float[] vectorAction) diff --git a/Project/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs b/Project/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs index 841d8810e3..bd8f19e4ff 100644 --- a/Project/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs +++ b/Project/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs @@ -25,10 +25,10 @@ public override void InitializeAgent() SetResetParameters(); } - public override void CollectObservations() + public override void CollectObservations(VectorSensor sensor) { - AddVectorObs(gameObject.transform.localPosition); - AddVectorObs(target.transform.localPosition); + sensor.AddObservation(gameObject.transform.localPosition); + sensor.AddObservation(target.transform.localPosition); } public override void AgentAction(float[] vectorAction) diff --git a/Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs b/Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs index e4b48186a8..8a3788eca5 100644 --- a/Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs +++ b/Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs @@ -72,29 +72,29 @@ public override void InitializeAgent() /// /// Add relevant information on each body part to observations. /// - public void CollectObservationBodyPart(BodyPart bp) + public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor) { var rb = bp.rb; - AddVectorObs(bp.groundContact.touchingGround ? 1 : 0); // Whether the bp touching the ground + sensor.AddObservation(bp.groundContact.touchingGround ? 1 : 0); // Whether the bp touching the ground var velocityRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(rb.velocity); - AddVectorObs(velocityRelativeToLookRotationToTarget); + sensor.AddObservation(velocityRelativeToLookRotationToTarget); var angularVelocityRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(rb.angularVelocity); - AddVectorObs(angularVelocityRelativeToLookRotationToTarget); + sensor.AddObservation(angularVelocityRelativeToLookRotationToTarget); if (bp.rb.transform != body) { var localPosRelToBody = body.InverseTransformPoint(rb.position); - AddVectorObs(localPosRelToBody); - AddVectorObs(bp.currentXNormalizedRot); // Current x rot - AddVectorObs(bp.currentYNormalizedRot); // Current y rot - AddVectorObs(bp.currentZNormalizedRot); // Current z rot - AddVectorObs(bp.currentStrength / m_JdController.maxJointForceLimit); + sensor.AddObservation(localPosRelToBody); + sensor.AddObservation(bp.currentXNormalizedRot); // Current x rot + sensor.AddObservation(bp.currentYNormalizedRot); // Current y rot + sensor.AddObservation(bp.currentZNormalizedRot); // Current z rot + sensor.AddObservation(bp.currentStrength / m_JdController.maxJointForceLimit); } } - public override void CollectObservations() + public override void CollectObservations(VectorSensor sensor) { m_JdController.GetCurrentJointForces(); @@ -106,21 +106,21 @@ public override void CollectObservations() RaycastHit hit; if (Physics.Raycast(body.position, Vector3.down, out hit, 10.0f)) { - AddVectorObs(hit.distance); + sensor.AddObservation(hit.distance); } else - AddVectorObs(10.0f); + sensor.AddObservation(10.0f); // Forward & up to help with orientation var bodyForwardRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(body.forward); - AddVectorObs(bodyForwardRelativeToLookRotationToTarget); + sensor.AddObservation(bodyForwardRelativeToLookRotationToTarget); var bodyUpRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(body.up); - AddVectorObs(bodyUpRelativeToLookRotationToTarget); + sensor.AddObservation(bodyUpRelativeToLookRotationToTarget); foreach (var bodyPart in m_JdController.bodyPartsDict.Values) { - CollectObservationBodyPart(bodyPart); + CollectObservationBodyPart(bodyPart, sensor); } } diff --git a/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs b/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs index cad3cb9757..48de78becc 100644 --- a/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs +++ b/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs @@ -38,15 +38,15 @@ public override void InitializeAgent() SetResetParameters(); } - public override void CollectObservations() + public override void CollectObservations(VectorSensor sensor) { if (useVectorObs) { var localVelocity = transform.InverseTransformDirection(m_AgentRb.velocity); - AddVectorObs(localVelocity.x); - AddVectorObs(localVelocity.z); - AddVectorObs(System.Convert.ToInt32(m_Frozen)); - AddVectorObs(System.Convert.ToInt32(m_Shoot)); + sensor.AddObservation(localVelocity.x); + sensor.AddObservation(localVelocity.z); + sensor.AddObservation(System.Convert.ToInt32(m_Frozen)); + sensor.AddObservation(System.Convert.ToInt32(m_Shoot)); } } diff --git a/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs b/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs index e470029c5c..7fe523e3b9 100644 --- a/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs +++ b/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs @@ -31,7 +31,7 @@ public override void InitializeAgent() { } - public override void CollectObservations() + public override void CollectObservations(VectorSensor sensor) { // There are no numeric observations to collect as this environment uses visual // observations. diff --git a/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs b/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs index a25add4ecb..6340bc0293 100644 --- a/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs +++ b/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs @@ -26,11 +26,11 @@ public override void InitializeAgent() m_GroundMaterial = m_GroundRenderer.material; } - public override void CollectObservations() + public override void CollectObservations(VectorSensor sensor) { if (useVectorObs) { - AddVectorObs(GetStepCount() / (float)maxStep); + sensor.AddObservation(GetStepCount() / (float)maxStep); } } diff --git a/Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs b/Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs index ef55288911..60b1260b44 100644 --- a/Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs +++ b/Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs @@ -21,12 +21,12 @@ public override void InitializeAgent() m_SwitchLogic = areaSwitch.GetComponent(); } - public override void CollectObservations() + public override void CollectObservations(VectorSensor sensor) { if (useVectorObs) { - AddVectorObs(m_SwitchLogic.GetState()); - AddVectorObs(transform.InverseTransformDirection(m_AgentRb.velocity)); + sensor.AddObservation(m_SwitchLogic.GetState()); + sensor.AddObservation(transform.InverseTransformDirection(m_AgentRb.velocity)); } } diff --git a/Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs b/Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs index c8837b1e96..25c11ff546 100644 --- a/Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs +++ b/Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs @@ -35,22 +35,22 @@ public override void InitializeAgent() /// We collect the normalized rotations, angularal velocities, and velocities of both /// limbs of the reacher as well as the relative position of the target and hand. /// - public override void CollectObservations() + public override void CollectObservations(VectorSensor sensor) { - AddVectorObs(pendulumA.transform.localPosition); - AddVectorObs(pendulumA.transform.rotation); - AddVectorObs(m_RbA.angularVelocity); - AddVectorObs(m_RbA.velocity); + sensor.AddObservation(pendulumA.transform.localPosition); + sensor.AddObservation(pendulumA.transform.rotation); + sensor.AddObservation(m_RbA.angularVelocity); + sensor.AddObservation(m_RbA.velocity); - AddVectorObs(pendulumB.transform.localPosition); - AddVectorObs(pendulumB.transform.rotation); - AddVectorObs(m_RbB.angularVelocity); - AddVectorObs(m_RbB.velocity); + sensor.AddObservation(pendulumB.transform.localPosition); + sensor.AddObservation(pendulumB.transform.rotation); + sensor.AddObservation(m_RbB.angularVelocity); + sensor.AddObservation(m_RbB.velocity); - AddVectorObs(goal.transform.localPosition); - AddVectorObs(hand.transform.localPosition); + sensor.AddObservation(goal.transform.localPosition); + sensor.AddObservation(hand.transform.localPosition); - AddVectorObs(m_GoalSpeed); + sensor.AddObservation(m_GoalSpeed); } /// diff --git a/Project/Assets/ML-Agents/Examples/Template/Scripts/TemplateAgent.cs b/Project/Assets/ML-Agents/Examples/Template/Scripts/TemplateAgent.cs index 4744777114..0c6b1a99d9 100644 --- a/Project/Assets/ML-Agents/Examples/Template/Scripts/TemplateAgent.cs +++ b/Project/Assets/ML-Agents/Examples/Template/Scripts/TemplateAgent.cs @@ -3,7 +3,7 @@ public class TemplateAgent : Agent { - public override void CollectObservations() + public override void CollectObservations(VectorSensor sensor) { } diff --git a/Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs b/Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs index 2ba6f1bfbc..fb6ec9920d 100644 --- a/Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs +++ b/Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs @@ -43,19 +43,19 @@ public override void InitializeAgent() SetResetParameters(); } - public override void CollectObservations() + public override void CollectObservations(VectorSensor sensor) { - AddVectorObs(m_InvertMult * (transform.position.x - myArea.transform.position.x)); - AddVectorObs(transform.position.y - myArea.transform.position.y); - AddVectorObs(m_InvertMult * m_AgentRb.velocity.x); - AddVectorObs(m_AgentRb.velocity.y); + sensor.AddObservation(m_InvertMult * (transform.position.x - myArea.transform.position.x)); + sensor.AddObservation(transform.position.y - myArea.transform.position.y); + sensor.AddObservation(m_InvertMult * m_AgentRb.velocity.x); + sensor.AddObservation(m_AgentRb.velocity.y); - AddVectorObs(m_InvertMult * (ball.transform.position.x - myArea.transform.position.x)); - AddVectorObs(ball.transform.position.y - myArea.transform.position.y); - AddVectorObs(m_InvertMult * m_BallRb.velocity.x); - AddVectorObs(m_BallRb.velocity.y); + sensor.AddObservation(m_InvertMult * (ball.transform.position.x - myArea.transform.position.x)); + sensor.AddObservation(ball.transform.position.y - myArea.transform.position.y); + sensor.AddObservation(m_InvertMult * m_BallRb.velocity.x); + sensor.AddObservation(m_BallRb.velocity.y); - AddVectorObs(m_InvertMult * gameObject.transform.rotation.z); + sensor.AddObservation(m_InvertMult * gameObject.transform.rotation.z); } public override void AgentAction(float[] vectorAction) diff --git a/Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs b/Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs index 6dd6d9eef3..44df4f8260 100644 --- a/Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs +++ b/Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs @@ -66,40 +66,40 @@ public override void InitializeAgent() /// /// Add relevant information on each body part to observations. /// - public void CollectObservationBodyPart(BodyPart bp) + public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor) { var rb = bp.rb; - AddVectorObs(bp.groundContact.touchingGround ? 1 : 0); // Is this bp touching the ground - AddVectorObs(rb.velocity); - AddVectorObs(rb.angularVelocity); + sensor.AddObservation(bp.groundContact.touchingGround ? 1 : 0); // Is this bp touching the ground + sensor.AddObservation(rb.velocity); + sensor.AddObservation(rb.angularVelocity); var localPosRelToHips = hips.InverseTransformPoint(rb.position); - AddVectorObs(localPosRelToHips); + sensor.AddObservation(localPosRelToHips); if (bp.rb.transform != hips && bp.rb.transform != handL && bp.rb.transform != handR && bp.rb.transform != footL && bp.rb.transform != footR && bp.rb.transform != head) { - AddVectorObs(bp.currentXNormalizedRot); - AddVectorObs(bp.currentYNormalizedRot); - AddVectorObs(bp.currentZNormalizedRot); - AddVectorObs(bp.currentStrength / m_JdController.maxJointForceLimit); + sensor.AddObservation(bp.currentXNormalizedRot); + sensor.AddObservation(bp.currentYNormalizedRot); + sensor.AddObservation(bp.currentZNormalizedRot); + sensor.AddObservation(bp.currentStrength / m_JdController.maxJointForceLimit); } } /// /// Loop over body parts to add them to observation. /// - public override void CollectObservations() + public override void CollectObservations(VectorSensor sensor) { m_JdController.GetCurrentJointForces(); - AddVectorObs(m_DirToTarget.normalized); - AddVectorObs(m_JdController.bodyPartsDict[hips].rb.position); - AddVectorObs(hips.forward); - AddVectorObs(hips.up); + sensor.AddObservation(m_DirToTarget.normalized); + sensor.AddObservation(m_JdController.bodyPartsDict[hips].rb.position); + sensor.AddObservation(hips.forward); + sensor.AddObservation(hips.up); foreach (var bodyPart in m_JdController.bodyPartsDict.Values) { - CollectObservationBodyPart(bodyPart); + CollectObservationBodyPart(bodyPart, sensor); } } diff --git a/Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs b/Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs index 7ad751ca8c..194b55e9af 100644 --- a/Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs +++ b/Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs @@ -132,12 +132,12 @@ void MoveTowards( } } - public override void CollectObservations() + public override void CollectObservations(VectorSensor sensor) { var agentPos = m_AgentRb.position - ground.transform.position; - AddVectorObs(agentPos / 20f); - AddVectorObs(DoGroundCheck(true) ? 1 : 0); + sensor.AddObservation(agentPos / 20f); + sensor.AddObservation(DoGroundCheck(true) ? 1 : 0); } /// diff --git a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs index 36d605522c..1f4fc7df00 100644 --- a/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs +++ b/com.unity.ml-agents/Editor/BehaviorParametersEditor.cs @@ -1,7 +1,6 @@ using UnityEngine; using UnityEditor; using Barracuda; -using MLAgents.Sensor; namespace MLAgents { diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 45433a07ce..65e477b999 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; using UnityEngine; using Barracuda; -using MLAgents.Sensor; using UnityEngine.Serialization; namespace MLAgents @@ -482,7 +481,7 @@ void SendInfoToBrain() UpdateSensors(); using (TimerStack.Instance.Scoped("CollectObservations")) { - CollectObservations(); + CollectObservations(collectObservationsSensor); } m_Info.actionMasks = m_ActionMasker.GetMask(); @@ -508,30 +507,31 @@ void UpdateSensors() } /// - /// Collects the (vector, visual) observations of the agent. + /// Collects the vector observations of the agent. /// The agent observation describes the current environment from the /// perspective of the agent. /// /// - /// Simply, an agents observation is any environment information that helps - /// the Agent acheive its goal. For example, for a fighting Agent, its + /// An agents observation is any environment information that helps + /// the Agent achieve its goal. For example, for a fighting Agent, its /// observation could include distances to friends or enemies, or the /// current level of ammunition at its disposal. /// Recall that an Agent may attach vector or visual observations. - /// Vector observations are added by calling the provided helper methods: - /// - - /// - - /// - - /// - + /// Vector observations are added by calling the provided helper methods + /// on the VectorSensor input: + /// - + /// - + /// - + /// - /// - /// AddVectorObs(float[]) /// /// - /// AddVectorObs(List{float}) /// - /// - - /// - - /// - + /// - + /// - + /// - /// Depending on your environment, any combination of these helpers can /// be used. They just need to be used in the exact same order each time /// this method is called and the resulting size of the vector observation @@ -539,7 +539,7 @@ void UpdateSensors() /// Visual observations are implicitly added from the cameras attached to /// the Agent. /// - public virtual void CollectObservations() + public virtual void CollectObservations(VectorSensor sensor) { } @@ -593,81 +593,6 @@ protected void SetActionMask(int branch, IEnumerable actionIndices) m_ActionMasker.SetActionMask(branch, actionIndices); } - /// - /// Adds a float observation to the vector observations of the agent. - /// Increases the size of the agents vector observation by 1. - /// - /// Observation. - protected void AddVectorObs(float observation) - { - collectObservationsSensor.AddObservation(observation); - } - - /// - /// Adds an integer observation to the vector observations of the agent. - /// Increases the size of the agents vector observation by 1. - /// - /// Observation. - protected void AddVectorObs(int observation) - { - collectObservationsSensor.AddObservation(observation); - } - - /// - /// Adds an Vector3 observation to the vector observations of the agent. - /// Increases the size of the agents vector observation by 3. - /// - /// Observation. - protected void AddVectorObs(Vector3 observation) - { - collectObservationsSensor.AddObservation(observation); - } - - /// - /// Adds an Vector2 observation to the vector observations of the agent. - /// Increases the size of the agents vector observation by 2. - /// - /// Observation. - protected void AddVectorObs(Vector2 observation) - { - collectObservationsSensor.AddObservation(observation); - } - - /// - /// Adds a collection of float observations to the vector observations of the agent. - /// Increases the size of the agents vector observation by size of the collection. - /// - /// Observation. - protected void AddVectorObs(IEnumerable observation) - { - collectObservationsSensor.AddObservation(observation); - } - - /// - /// Adds a quaternion observation to the vector observations of the agent. - /// Increases the size of the agents vector observation by 4. - /// - /// Observation. - protected void AddVectorObs(Quaternion observation) - { - collectObservationsSensor.AddObservation(observation); - } - - /// - /// Adds a boolean observation to the vector observation of the agent. - /// Increases the size of the agent's vector observation by 1. - /// - /// - protected void AddVectorObs(bool observation) - { - collectObservationsSensor.AddObservation(observation); - } - - protected void AddVectorObs(int observation, int range) - { - collectObservationsSensor.AddOneHotObservation(observation, range); - } - /// /// Specifies the agent behavior at every step based on the provided /// action. diff --git a/com.unity.ml-agents/Runtime/DecisionRequester.cs b/com.unity.ml-agents/Runtime/DecisionRequester.cs index 87ac77c1cc..062bacfb50 100644 --- a/com.unity.ml-agents/Runtime/DecisionRequester.cs +++ b/com.unity.ml-agents/Runtime/DecisionRequester.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; using UnityEngine; using Barracuda; -using MLAgents.Sensor; using UnityEngine.Serialization; namespace MLAgents diff --git a/com.unity.ml-agents/Runtime/DemonstrationRecorder.cs b/com.unity.ml-agents/Runtime/DemonstrationRecorder.cs index 62c28be592..621155caff 100644 --- a/com.unity.ml-agents/Runtime/DemonstrationRecorder.cs +++ b/com.unity.ml-agents/Runtime/DemonstrationRecorder.cs @@ -2,7 +2,6 @@ using System.Text.RegularExpressions; using UnityEngine; using System.Collections.Generic; -using MLAgents.Sensor; namespace MLAgents { diff --git a/com.unity.ml-agents/Runtime/DemonstrationStore.cs b/com.unity.ml-agents/Runtime/DemonstrationStore.cs index de1f8b2c04..86624b1fc6 100644 --- a/com.unity.ml-agents/Runtime/DemonstrationStore.cs +++ b/com.unity.ml-agents/Runtime/DemonstrationStore.cs @@ -2,7 +2,6 @@ using System.IO.Abstractions; using Google.Protobuf; using System.Collections.Generic; -using MLAgents.Sensor; namespace MLAgents { diff --git a/com.unity.ml-agents/Runtime/Grpc/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Grpc/GrpcExtensions.cs index 900ca774ba..04f377e5e5 100644 --- a/com.unity.ml-agents/Runtime/Grpc/GrpcExtensions.cs +++ b/com.unity.ml-agents/Runtime/Grpc/GrpcExtensions.cs @@ -4,7 +4,6 @@ using Google.Protobuf; using Google.Protobuf.Collections; using MLAgents.CommunicatorObjects; -using MLAgents.Sensor; using UnityEngine; using System.Runtime.CompilerServices; diff --git a/com.unity.ml-agents/Runtime/Grpc/RpcCommunicator.cs b/com.unity.ml-agents/Runtime/Grpc/RpcCommunicator.cs index 5cd18d4f90..21c289292b 100644 --- a/com.unity.ml-agents/Runtime/Grpc/RpcCommunicator.cs +++ b/com.unity.ml-agents/Runtime/Grpc/RpcCommunicator.cs @@ -11,7 +11,6 @@ using MLAgents.CommunicatorObjects; using System.IO; using Google.Protobuf; -using MLAgents.Sensor; namespace MLAgents { diff --git a/com.unity.ml-agents/Runtime/ICommunicator.cs b/com.unity.ml-agents/Runtime/ICommunicator.cs index 952f3997bd..0ff8a4bfba 100644 --- a/com.unity.ml-agents/Runtime/ICommunicator.cs +++ b/com.unity.ml-agents/Runtime/ICommunicator.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; using UnityEngine; using MLAgents.CommunicatorObjects; -using MLAgents.Sensor; namespace MLAgents { diff --git a/com.unity.ml-agents/Runtime/InferenceBrain/BarracudaModelParamLoader.cs b/com.unity.ml-agents/Runtime/InferenceBrain/BarracudaModelParamLoader.cs index 1d793ed9ec..8133ee15b0 100644 --- a/com.unity.ml-agents/Runtime/InferenceBrain/BarracudaModelParamLoader.cs +++ b/com.unity.ml-agents/Runtime/InferenceBrain/BarracudaModelParamLoader.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; using System.Linq; using Barracuda; -using MLAgents.Sensor; using UnityEngine; namespace MLAgents.InferenceBrain diff --git a/com.unity.ml-agents/Runtime/InferenceBrain/GeneratorImpl.cs b/com.unity.ml-agents/Runtime/InferenceBrain/GeneratorImpl.cs index 07077912ae..287a898a16 100644 --- a/com.unity.ml-agents/Runtime/InferenceBrain/GeneratorImpl.cs +++ b/com.unity.ml-agents/Runtime/InferenceBrain/GeneratorImpl.cs @@ -2,7 +2,6 @@ using System; using Barracuda; using MLAgents.InferenceBrain.Utils; -using MLAgents.Sensor; using UnityEngine; namespace MLAgents.InferenceBrain diff --git a/com.unity.ml-agents/Runtime/InferenceBrain/ModelRunner.cs b/com.unity.ml-agents/Runtime/InferenceBrain/ModelRunner.cs index fd172f1137..299accfd5a 100644 --- a/com.unity.ml-agents/Runtime/InferenceBrain/ModelRunner.cs +++ b/com.unity.ml-agents/Runtime/InferenceBrain/ModelRunner.cs @@ -2,7 +2,6 @@ using Barracuda; using UnityEngine.Profiling; using System; -using MLAgents.Sensor; namespace MLAgents.InferenceBrain { diff --git a/com.unity.ml-agents/Runtime/InferenceBrain/TensorGenerator.cs b/com.unity.ml-agents/Runtime/InferenceBrain/TensorGenerator.cs index 899dd9ce91..98ff7ede39 100644 --- a/com.unity.ml-agents/Runtime/InferenceBrain/TensorGenerator.cs +++ b/com.unity.ml-agents/Runtime/InferenceBrain/TensorGenerator.cs @@ -1,6 +1,5 @@ using System.Collections.Generic; using Barracuda; -using MLAgents.Sensor; namespace MLAgents.InferenceBrain { diff --git a/com.unity.ml-agents/Runtime/Policy/BarracudaPolicy.cs b/com.unity.ml-agents/Runtime/Policy/BarracudaPolicy.cs index 75e256f4fb..e77d40e816 100644 --- a/com.unity.ml-agents/Runtime/Policy/BarracudaPolicy.cs +++ b/com.unity.ml-agents/Runtime/Policy/BarracudaPolicy.cs @@ -3,7 +3,6 @@ using System.Collections.Generic; using MLAgents.InferenceBrain; using System; -using MLAgents.Sensor; namespace MLAgents { diff --git a/com.unity.ml-agents/Runtime/Policy/HeuristicPolicy.cs b/com.unity.ml-agents/Runtime/Policy/HeuristicPolicy.cs index b2bfcddf6c..77acd43845 100644 --- a/com.unity.ml-agents/Runtime/Policy/HeuristicPolicy.cs +++ b/com.unity.ml-agents/Runtime/Policy/HeuristicPolicy.cs @@ -1,4 +1,3 @@ -using MLAgents.Sensor; using System.Collections.Generic; using System; diff --git a/com.unity.ml-agents/Runtime/Policy/IPolicy.cs b/com.unity.ml-agents/Runtime/Policy/IPolicy.cs index 14e9154aab..bc3fdc62c2 100644 --- a/com.unity.ml-agents/Runtime/Policy/IPolicy.cs +++ b/com.unity.ml-agents/Runtime/Policy/IPolicy.cs @@ -1,6 +1,5 @@ using System; using System.Collections.Generic; -using MLAgents.Sensor; namespace MLAgents { diff --git a/com.unity.ml-agents/Runtime/Policy/RemotePolicy.cs b/com.unity.ml-agents/Runtime/Policy/RemotePolicy.cs index ff56d65471..71784e5ca5 100644 --- a/com.unity.ml-agents/Runtime/Policy/RemotePolicy.cs +++ b/com.unity.ml-agents/Runtime/Policy/RemotePolicy.cs @@ -1,6 +1,5 @@ using UnityEngine; using System.Collections.Generic; -using MLAgents.Sensor; using System; namespace MLAgents diff --git a/com.unity.ml-agents/Runtime/Sensor/CameraSensor.cs b/com.unity.ml-agents/Runtime/Sensor/CameraSensor.cs index b51e935f30..22c7b470eb 100644 --- a/com.unity.ml-agents/Runtime/Sensor/CameraSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensor/CameraSensor.cs @@ -1,7 +1,7 @@ using System; using UnityEngine; -namespace MLAgents.Sensor +namespace MLAgents { public class CameraSensor : ISensor { diff --git a/com.unity.ml-agents/Runtime/Sensor/CameraSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensor/CameraSensorComponent.cs index 4c754d5ec0..39a325e194 100644 --- a/com.unity.ml-agents/Runtime/Sensor/CameraSensorComponent.cs +++ b/com.unity.ml-agents/Runtime/Sensor/CameraSensorComponent.cs @@ -1,7 +1,7 @@ using System; using UnityEngine; -namespace MLAgents.Sensor +namespace MLAgents { [AddComponentMenu("ML Agents/Camera Sensor", (int)MenuGroup.Sensors)] public class CameraSensorComponent : SensorComponent diff --git a/com.unity.ml-agents/Runtime/Sensor/ISensor.cs b/com.unity.ml-agents/Runtime/Sensor/ISensor.cs index 4978e03629..1f49777eee 100644 --- a/com.unity.ml-agents/Runtime/Sensor/ISensor.cs +++ b/com.unity.ml-agents/Runtime/Sensor/ISensor.cs @@ -1,4 +1,4 @@ -namespace MLAgents.Sensor +namespace MLAgents { public enum SensorCompressionType { diff --git a/com.unity.ml-agents/Runtime/Sensor/Observation.cs b/com.unity.ml-agents/Runtime/Sensor/Observation.cs index fb3f9a7904..f3374d4732 100644 --- a/com.unity.ml-agents/Runtime/Sensor/Observation.cs +++ b/com.unity.ml-agents/Runtime/Sensor/Observation.cs @@ -1,7 +1,7 @@ using System; using UnityEngine; -namespace MLAgents.Sensor +namespace MLAgents { internal struct Observation { diff --git a/com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensor.cs b/com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensor.cs index 4afc3c4145..201d6ed36d 100644 --- a/com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensor.cs @@ -2,7 +2,7 @@ using System.Collections.Generic; using UnityEngine; -namespace MLAgents.Sensor +namespace MLAgents { public class RayPerceptionSensor : ISensor { diff --git a/com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponent2D.cs b/com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponent2D.cs index a183179e78..5385f146ff 100644 --- a/com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponent2D.cs +++ b/com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponent2D.cs @@ -1,6 +1,6 @@ using UnityEngine; -namespace MLAgents.Sensor +namespace MLAgents { [AddComponentMenu("ML Agents/Ray Perception Sensor 2D", (int)MenuGroup.Sensors)] public class RayPerceptionSensorComponent2D : RayPerceptionSensorComponentBase diff --git a/com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponent3D.cs b/com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponent3D.cs index eca147b89e..c3467b3d43 100644 --- a/com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponent3D.cs +++ b/com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponent3D.cs @@ -1,7 +1,7 @@ using System; using UnityEngine; -namespace MLAgents.Sensor +namespace MLAgents { [AddComponentMenu("ML Agents/Ray Perception Sensor 3D", (int)MenuGroup.Sensors)] public class RayPerceptionSensorComponent3D : RayPerceptionSensorComponentBase diff --git a/com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponentBase.cs b/com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponentBase.cs index e873526ac2..8b502946bc 100644 --- a/com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponentBase.cs +++ b/com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponentBase.cs @@ -2,7 +2,7 @@ using System.Collections.Generic; using UnityEngine; -namespace MLAgents.Sensor +namespace MLAgents { public abstract class RayPerceptionSensorComponentBase : SensorComponent { diff --git a/com.unity.ml-agents/Runtime/Sensor/RenderTextureSensor.cs b/com.unity.ml-agents/Runtime/Sensor/RenderTextureSensor.cs index 4a2fa0f689..a8298b25bf 100644 --- a/com.unity.ml-agents/Runtime/Sensor/RenderTextureSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensor/RenderTextureSensor.cs @@ -1,7 +1,7 @@ using System; using UnityEngine; -namespace MLAgents.Sensor +namespace MLAgents { public class RenderTextureSensor : ISensor { diff --git a/com.unity.ml-agents/Runtime/Sensor/RenderTextureSensorComponent.cs b/com.unity.ml-agents/Runtime/Sensor/RenderTextureSensorComponent.cs index d986e221a8..1e1ea148dd 100644 --- a/com.unity.ml-agents/Runtime/Sensor/RenderTextureSensorComponent.cs +++ b/com.unity.ml-agents/Runtime/Sensor/RenderTextureSensorComponent.cs @@ -1,7 +1,7 @@ using System; using UnityEngine; -namespace MLAgents.Sensor +namespace MLAgents { [AddComponentMenu("ML Agents/Render Texture Sensor", (int)MenuGroup.Sensors)] public class RenderTextureSensorComponent : SensorComponent diff --git a/com.unity.ml-agents/Runtime/Sensor/SensorBase.cs b/com.unity.ml-agents/Runtime/Sensor/SensorBase.cs index 9f4356d209..369140e5d6 100644 --- a/com.unity.ml-agents/Runtime/Sensor/SensorBase.cs +++ b/com.unity.ml-agents/Runtime/Sensor/SensorBase.cs @@ -1,6 +1,6 @@ using UnityEngine; -namespace MLAgents.Sensor +namespace MLAgents { public abstract class SensorBase : ISensor { diff --git a/com.unity.ml-agents/Runtime/Sensor/SensorComponent.cs b/com.unity.ml-agents/Runtime/Sensor/SensorComponent.cs index 63c1328377..cdc51fd5f4 100644 --- a/com.unity.ml-agents/Runtime/Sensor/SensorComponent.cs +++ b/com.unity.ml-agents/Runtime/Sensor/SensorComponent.cs @@ -1,7 +1,7 @@ using System; using UnityEngine; -namespace MLAgents.Sensor +namespace MLAgents { /// /// Editor components for creating Sensors. Generally an ISensor implementation should have a corresponding diff --git a/com.unity.ml-agents/Runtime/Sensor/SensorShapeValidator.cs b/com.unity.ml-agents/Runtime/Sensor/SensorShapeValidator.cs index 0b92725c6a..3624a5ac67 100644 --- a/com.unity.ml-agents/Runtime/Sensor/SensorShapeValidator.cs +++ b/com.unity.ml-agents/Runtime/Sensor/SensorShapeValidator.cs @@ -1,7 +1,7 @@ using System.Collections.Generic; using UnityEngine; -namespace MLAgents.Sensor +namespace MLAgents { internal class SensorShapeValidator { diff --git a/com.unity.ml-agents/Runtime/Sensor/StackingSensor.cs b/com.unity.ml-agents/Runtime/Sensor/StackingSensor.cs index bc825ec751..d230aa3e57 100644 --- a/com.unity.ml-agents/Runtime/Sensor/StackingSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensor/StackingSensor.cs @@ -1,4 +1,4 @@ -namespace MLAgents.Sensor +namespace MLAgents { /// /// Sensor that wraps around another Sensor to provide temporal stacking. diff --git a/com.unity.ml-agents/Runtime/Sensor/VectorSensor.cs b/com.unity.ml-agents/Runtime/Sensor/VectorSensor.cs index 0c1972a6c8..682da90efb 100644 --- a/com.unity.ml-agents/Runtime/Sensor/VectorSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensor/VectorSensor.cs @@ -1,7 +1,7 @@ using System.Collections.Generic; using UnityEngine; -namespace MLAgents.Sensor +namespace MLAgents { public class VectorSensor : ISensor { diff --git a/com.unity.ml-agents/Runtime/Sensor/WriteAdapter.cs b/com.unity.ml-agents/Runtime/Sensor/WriteAdapter.cs index 498e6dcd6a..3bc0df0aaa 100644 --- a/com.unity.ml-agents/Runtime/Sensor/WriteAdapter.cs +++ b/com.unity.ml-agents/Runtime/Sensor/WriteAdapter.cs @@ -3,7 +3,7 @@ using Barracuda; using MLAgents.InferenceBrain; -namespace MLAgents.Sensor +namespace MLAgents { /// /// Allows sensors to write to both TensorProxy and float arrays/lists. diff --git a/com.unity.ml-agents/Runtime/Utilities.cs b/com.unity.ml-agents/Runtime/Utilities.cs index b84008a9db..b828fe644c 100644 --- a/com.unity.ml-agents/Runtime/Utilities.cs +++ b/com.unity.ml-agents/Runtime/Utilities.cs @@ -1,6 +1,5 @@ using UnityEngine; using System.Collections.Generic; -using MLAgents.Sensor; namespace MLAgents { diff --git a/com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs b/com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs index df0bde9ac1..865f62788a 100644 --- a/com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs +++ b/com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs @@ -3,7 +3,6 @@ using System.IO.Abstractions.TestingHelpers; using System.Reflection; using MLAgents.CommunicatorObjects; -using MLAgents.Sensor; namespace MLAgents.Tests { @@ -71,12 +70,12 @@ public void TestStoreInitalize() public class ObservationAgent : TestAgent { - public override void CollectObservations() + public override void CollectObservations(VectorSensor sensor) { collectObservationsCalls += 1; - AddVectorObs(1f); - AddVectorObs(2f); - AddVectorObs(3f); + sensor.AddObservation(1f); + sensor.AddObservation(2f); + sensor.AddObservation(3f); } } diff --git a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs index 234e46bbc5..a380422f2f 100644 --- a/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs +++ b/com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs @@ -1,7 +1,6 @@ using UnityEngine; using NUnit.Framework; using System.Reflection; -using MLAgents.Sensor; using System.Collections.Generic; namespace MLAgents.Tests @@ -51,10 +50,10 @@ public override void InitializeAgent() sensors.Add(sensor1); } - public override void CollectObservations() + public override void CollectObservations(VectorSensor sensor) { collectObservationsCalls += 1; - AddVectorObs(0f); + sensor.AddObservation(0f); } public override void AgentAction(float[] vectorAction) diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs index 6465e31dd9..4c5595284b 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs @@ -1,6 +1,5 @@ using NUnit.Framework; using UnityEngine; -using MLAgents.Sensor; namespace MLAgents.Tests { diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs index 14ad8b01cd..22ce15b07e 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; using NUnit.Framework; using UnityEngine; -using MLAgents.Sensor; namespace MLAgents.Tests { diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs index 81c80193bc..77dc3d0243 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs @@ -1,6 +1,5 @@ using NUnit.Framework; using UnityEngine; -using MLAgents.Sensor; namespace MLAgents.Tests { diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs index f2439c5382..64140e241f 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs @@ -1,6 +1,5 @@ using NUnit.Framework; using UnityEngine; -using MLAgents.Sensor; namespace MLAgents.Tests { diff --git a/com.unity.ml-agents/Tests/Editor/Sensor/WriterAdapterTests.cs b/com.unity.ml-agents/Tests/Editor/Sensor/WriterAdapterTests.cs index 67a3650d73..4b6a39ba5e 100644 --- a/com.unity.ml-agents/Tests/Editor/Sensor/WriterAdapterTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Sensor/WriterAdapterTests.cs @@ -1,7 +1,5 @@ using NUnit.Framework; using UnityEngine; -using MLAgents.Sensor; - using Barracuda; using MLAgents.InferenceBrain; diff --git a/docs/Getting-Started-with-Balance-Ball.md b/docs/Getting-Started-with-Balance-Ball.md index b904a44035..dbffa34a51 100644 --- a/docs/Getting-Started-with-Balance-Ball.md +++ b/docs/Getting-Started-with-Balance-Ball.md @@ -71,11 +71,11 @@ The Ball3DAgent subclass defines the following methods: agent cube and ball. The function randomizes the reset values so that the training generalizes to more than a specific starting position and agent cube attitude. -* agent.CollectObservations() — Called every simulation step. Responsible for +* agent.CollectObservations(VectorSensor sensor) — Called every simulation step. Responsible for collecting the Agent's observations of the environment. Since the Behavior Parameters of the Agent are set with vector observation - space with a state size of 8, the `CollectObservations()` must call - `AddVectorObs` such that vector size adds up to 8. + space with a state size of 8, the `CollectObservations(VectorSensor sensor)` must call + `VectorSensor.AddObservation()` such that vector size adds up to 8. * agent.AgentAction() — Called every simulation step. Receives the action chosen by the Agent. The vector action spaces result in a small change in the agent cube's rotation at each step. The `AgentAction()` function @@ -102,7 +102,7 @@ This means that the feature vector containing the Agent's observations contains eight elements: the `x` and `z` components of the agent cube's rotation and the `x`, `y`, and `z` components of the ball's relative position and velocity. (The observation values are -defined in the Agent's `CollectObservations()` function.) +defined in the Agent's `CollectObservations(VectorSensor sensor)` method.) #### Behavior Parameters : Vector Action Space diff --git a/docs/Learning-Environment-Best-Practices.md b/docs/Learning-Environment-Best-Practices.md index 95e52b33ab..69b0665581 100644 --- a/docs/Learning-Environment-Best-Practices.md +++ b/docs/Learning-Environment-Best-Practices.md @@ -42,8 +42,8 @@ * Besides encoding non-numeric values, all inputs should be normalized to be in the range 0 to +1 (or -1 to 1). For example, the `x` position information of an agent where the maximum possible value is `maxValue` should be recorded as - `AddVectorObs(transform.position.x / maxValue);` rather than - `AddVectorObs(transform.position.x);`. See the equation below for one approach + `VectorSensor.AddObservation(transform.position.x / maxValue);` rather than + `VectorSensor.AddObservation(transform.position.x);`. See the equation below for one approach of normalization. * Positional information of relevant GameObjects should be encoded in relative coordinates wherever possible. This is often relative to the agent position. diff --git a/docs/Learning-Environment-Create-New.md b/docs/Learning-Environment-Create-New.md index e977924c6a..a32f9569fb 100644 --- a/docs/Learning-Environment-Create-New.md +++ b/docs/Learning-Environment-Create-New.md @@ -182,7 +182,7 @@ public class RollerAgent : Agent } ``` -Next, let's implement the `Agent.CollectObservations()` method. +Next, let's implement the `Agent.CollectObservations(VectorSensor sensor)` method. ### Observing the Environment @@ -198,13 +198,13 @@ In our case, the information our Agent collects includes: * Position of the target. ```csharp -AddVectorObs(Target.position); +sensor.AddObservation(Target.position); ``` * Position of the Agent itself. ```csharp -AddVectorObs(this.transform.position); +sensor.AddObservation(this.transform.position); ``` * The velocity of the Agent. This helps the Agent learn to control its speed so @@ -212,23 +212,23 @@ AddVectorObs(this.transform.position); ```csharp // Agent velocity -AddVectorObs(rBody.velocity.x); -AddVectorObs(rBody.velocity.z); +sensor.AddObservation(rBody.velocity.x); +sensor.AddObservation(rBody.velocity.z); ``` In total, the state observation contains 8 values and we need to use the continuous state space when we get around to setting the Brain properties: ```csharp -public override void CollectObservations() +public override void CollectObservations(VectorSensor sensor) { // Target and Agent positions - AddVectorObs(Target.position); - AddVectorObs(this.transform.position); + sensor.AddObservation(Target.position); + sensor.AddObservation(this.transform.position); // Agent velocity - AddVectorObs(rBody.velocity.x); - AddVectorObs(rBody.velocity.z); + sensor.AddObservation(rBody.velocity.x); + sensor.AddObservation(rBody.velocity.z); } ``` diff --git a/docs/Learning-Environment-Design-Agents.md b/docs/Learning-Environment-Design-Agents.md index 1aa563bce3..dcc92e44bb 100644 --- a/docs/Learning-Environment-Design-Agents.md +++ b/docs/Learning-Environment-Design-Agents.md @@ -35,6 +35,7 @@ should make its decisions every step of the simulation. On the other hand, an agent that only needs to make decisions when certain game or simulation events occur, should call `Agent.RequestDecision()` manually. + ## Observations To make decisions, an agent must observe its environment in order to infer the @@ -45,18 +46,18 @@ state of the world. A state observation can take the following forms: * **Visual Observations** — one or more camera images and/or render textures. When you use vector observations for an Agent, implement the -`Agent.CollectObservations()` method to create the feature vector. When you use +`Agent.CollectObservations(VectorSensor sensor)` method to create the feature vector. When you use **Visual Observations**, you only need to identify which Unity Camera objects or RenderTextures will provide images and the base Agent class handles the rest. -You do not need to implement the `CollectObservations()` method when your Agent +You do not need to implement the `CollectObservations(VectorSensor sensor)` method when your Agent uses visual observations (unless it also uses vector observations). ### Vector Observation Space: Feature Vectors For agents using a continuous state space, you create a feature vector to represent the agent's observation at each step of the simulation. The Policy -class calls the `CollectObservations()` method of each Agent. Your -implementation of this function must call `AddVectorObs` to add vector +class calls the `CollectObservations(VectorSensor sensor)` method of each Agent. Your +implementation of this function must call `VectorSensor.AddObservation` to add vector observations. The observation must include all the information an agents needs to accomplish @@ -78,16 +79,16 @@ noticeably worse. public GameObject ball; private List state = new List(); -public override void CollectObservations() +public override void CollectObservations(VectorSensor sensor) { - AddVectorObs(gameObject.transform.rotation.z); - AddVectorObs(gameObject.transform.rotation.x); - AddVectorObs((ball.transform.position.x - gameObject.transform.position.x)); - AddVectorObs((ball.transform.position.y - gameObject.transform.position.y)); - AddVectorObs((ball.transform.position.z - gameObject.transform.position.z)); - AddVectorObs(ball.transform.GetComponent().velocity.x); - AddVectorObs(ball.transform.GetComponent().velocity.y); - AddVectorObs(ball.transform.GetComponent().velocity.z); + sensor.AddObservation(gameObject.transform.rotation.z); + sensor.AddObservation(gameObject.transform.rotation.x); + sensor.AddObservation((ball.transform.position.x - gameObject.transform.position.x)); + sensor.AddObservation((ball.transform.position.y - gameObject.transform.position.y)); + sensor.AddObservation((ball.transform.position.z - gameObject.transform.position.z)); + sensor.AddObservation(ball.transform.GetComponent().velocity.x); + sensor.AddObservation(ball.transform.GetComponent().velocity.y); + sensor.AddObservation(ball.transform.GetComponent().velocity.z); } ``` @@ -106,7 +107,7 @@ properties to use a continuous vector observation: The observation feature vector is a list of floating point numbers, which means you must convert any other data types to a float or a list of floats. -The `AddVectorObs` method provides a number of overloads for adding common types +The `VectorSensor.AddObservation` method provides a number of overloads for adding common types of data to your observation vector. You can add Integers and booleans directly to the observation vector, as well as some common Unity data types such as `Vector2`, `Vector3`, and `Quaternion`. @@ -121,27 +122,27 @@ the feature vector. The following code example illustrates how to add. ```csharp enum CarriedItems { Sword, Shield, Bow, LastItem } private List state = new List(); -public override void CollectObservations() +public override void CollectObservations(VectorSensor sensor) { for (int ci = 0; ci < (int)CarriedItems.LastItem; ci++) { - AddVectorObs((int)currentItem == ci ? 1.0f : 0.0f); + sensor.AddObservation((int)currentItem == ci ? 1.0f : 0.0f); } } ``` -`AddVectorObs` also provides a two-argument version as a shortcut for _one-hot_ +`VectorSensor.AddObservation` also provides a two-argument version as a shortcut for _one-hot_ style observations. The following example is identical to the previous one. ```csharp enum CarriedItems { Sword, Shield, Bow, LastItem } const int NUM_ITEM_TYPES = (int)CarriedItems.LastItem; -public override void CollectObservations() +public override void CollectObservations(VectorSensor sensor) { // The first argument is the selection index; the second is the // number of possibilities - AddVectorObs((int)currentItem, NUM_ITEM_TYPES); + sensor.AddOneHotObservation((int)currentItem, NUM_ITEM_TYPES); } ``` diff --git a/docs/Learning-Environment-Design.md b/docs/Learning-Environment-Design.md index 81a1d0358d..3d7077901e 100644 --- a/docs/Learning-Environment-Design.md +++ b/docs/Learning-Environment-Design.md @@ -41,7 +41,7 @@ The ML-Agents Academy class orchestrates the agent simulation loop as follows: 1. Calls your Academy's `OnEnvironmentReset` delegate. 2. Calls the `AgentReset()` function for each Agent in the scene. -3. Calls the `CollectObservations()` function for each Agent in the scene. +3. Calls the `CollectObservations(VectorSensor sensor)` function for each Agent in the scene. 4. Uses each Agent's Policy to decide on the Agent's next action. 5. Calls the `AgentAction()` function for each Agent in the scene, passing in the action chosen by the Agent's Policy. (This function is not called if the @@ -50,7 +50,7 @@ The ML-Agents Academy class orchestrates the agent simulation loop as follows: Step` count or has otherwise marked itself as `done`. To create a training environment, extend the Agent class to -implement the above methods. The `Agent.CollectObservations()` and +implement the above methods. The `Agent.CollectObservations(VectorSensor sensor)` and `Agent.AgentAction()` functions are required; the other methods are optional — whether you need to implement them or not depends on your specific scenario. @@ -107,9 +107,9 @@ in a football game or a car object in a vehicle simulation. Every Agent must have appropriate `Behavior Parameters`. To create an Agent, extend the Agent class and implement the essential -`CollectObservations()` and `AgentAction()` methods: +`CollectObservations(VectorSensor sensor)` and `AgentAction()` methods: -* `CollectObservations()` — Collects the Agent's observation of its environment. +* `CollectObservations(VectorSensor sensor)` — Collects the Agent's observation of its environment. * `AgentAction()` — Carries out the action chosen by the Agent's Policy and assigns a reward to the current state. diff --git a/docs/Migrating.md b/docs/Migrating.md index 0fcba40451..765ae068a8 100644 --- a/docs/Migrating.md +++ b/docs/Migrating.md @@ -10,9 +10,14 @@ The versions can be found in ## Migrating from 0.14 to latest ### Important changes +* The `Agent.CollectObservations()` virtual method now takes as input a `VectorSensor` sensor as argument. The `Agent.AddVectorObs()` methods were removed. * The `Monitor` class has been moved to the Examples Project. (It was prone to errors during testing) +* The `MLAgents.Sensor` namespace has been removed. All sensors now belong to the `MLAgents` namespace. + ### Steps to Migrate +* Replace your Agent's implementation of `CollectObservations()` with `CollectObservations(VectorSensor sensor)`. In addition, replace all calls to `AddVectorObs()` with `sensor.AddObservation()` or `sensor.AddOneHotObservation()` on the `VectorSensor` passed as argument. + ## Migrating from 0.13 to 0.14