Skip to content

Commit 6650fdb

Browse files
author
Chris Elion
committed
Merge remote-tracking branch 'origin/master' into develop-BehaviorParams-public
2 parents bf95e43 + 04d2cc9 commit 6650fdb

File tree

39 files changed

+526
-373
lines changed

39 files changed

+526
-373
lines changed

Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public class Ball3DAgent : Agent
1010
Rigidbody m_BallRb;
1111
FloatPropertiesChannel m_ResetParams;
1212

13-
public override void InitializeAgent()
13+
public override void Initialize()
1414
{
1515
m_BallRb = ball.GetComponent<Rigidbody>();
1616
m_ResetParams = Academy.Instance.FloatProperties;
@@ -25,7 +25,7 @@ public override void CollectObservations(VectorSensor sensor)
2525
sensor.AddObservation(m_BallRb.velocity);
2626
}
2727

28-
public override void AgentAction(float[] vectorAction)
28+
public override void OnActionReceived(float[] vectorAction)
2929
{
3030
var actionZ = 2f * Mathf.Clamp(vectorAction[0], -1f, 1f);
3131
var actionX = 2f * Mathf.Clamp(vectorAction[1], -1f, 1f);
@@ -46,15 +46,15 @@ public override void AgentAction(float[] vectorAction)
4646
Mathf.Abs(ball.transform.position.z - gameObject.transform.position.z) > 3f)
4747
{
4848
SetReward(-1f);
49-
Done();
49+
EndEpisode();
5050
}
5151
else
5252
{
5353
SetReward(0.1f);
5454
}
5555
}
5656

57-
public override void AgentReset()
57+
public override void OnEpisodeBegin()
5858
{
5959
gameObject.transform.rotation = new Quaternion(0f, 0f, 0f, 0f);
6060
gameObject.transform.Rotate(new Vector3(1, 0, 0), Random.Range(-10f, 10f));

Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public class Ball3DHardAgent : Agent
1010
Rigidbody m_BallRb;
1111
FloatPropertiesChannel m_ResetParams;
1212

13-
public override void InitializeAgent()
13+
public override void Initialize()
1414
{
1515
m_BallRb = ball.GetComponent<Rigidbody>();
1616
m_ResetParams = Academy.Instance.FloatProperties;
@@ -24,7 +24,7 @@ public override void CollectObservations(VectorSensor sensor)
2424
sensor.AddObservation((ball.transform.position - gameObject.transform.position));
2525
}
2626

27-
public override void AgentAction(float[] vectorAction)
27+
public override void OnActionReceived(float[] vectorAction)
2828
{
2929
var actionZ = 2f * Mathf.Clamp(vectorAction[0], -1f, 1f);
3030
var actionX = 2f * Mathf.Clamp(vectorAction[1], -1f, 1f);
@@ -45,15 +45,15 @@ public override void AgentAction(float[] vectorAction)
4545
Mathf.Abs(ball.transform.position.z - gameObject.transform.position.z) > 3f)
4646
{
4747
SetReward(-1f);
48-
Done();
48+
EndEpisode();
4949
}
5050
else
5151
{
5252
SetReward(0.1f);
5353
}
5454
}
5555

56-
public override void AgentReset()
56+
public override void OnEpisodeBegin()
5757
{
5858
gameObject.transform.rotation = new Quaternion(0f, 0f, 0f, 0f);
5959
gameObject.transform.Rotate(new Vector3(1, 0, 0), Random.Range(-10f, 10f));

Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicController.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,14 @@ public void ApplyAction(float[] vectorAction)
5959
if (m_Position == k_SmallGoalPosition)
6060
{
6161
m_Agent.AddReward(0.1f);
62-
m_Agent.Done();
62+
m_Agent.EndEpisode();
6363
ResetAgent();
6464
}
6565

6666
if (m_Position == k_LargeGoalPosition)
6767
{
6868
m_Agent.AddReward(1f);
69-
m_Agent.Done();
69+
m_Agent.EndEpisode();
7070
ResetAgent();
7171
}
7272
}

Project/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public class BouncerAgent : Agent
1717

1818
FloatPropertiesChannel m_ResetParams;
1919

20-
public override void InitializeAgent()
20+
public override void Initialize()
2121
{
2222
m_Rb = gameObject.GetComponent<Rigidbody>();
2323
m_LookDir = Vector3.zero;
@@ -33,7 +33,7 @@ public override void CollectObservations(VectorSensor sensor)
3333
sensor.AddObservation(target.transform.localPosition);
3434
}
3535

36-
public override void AgentAction(float[] vectorAction)
36+
public override void OnActionReceived(float[] vectorAction)
3737
{
3838
for (var i = 0; i < vectorAction.Length; i++)
3939
{
@@ -52,7 +52,7 @@ public override void AgentAction(float[] vectorAction)
5252
m_LookDir = new Vector3(x, y, z);
5353
}
5454

55-
public override void AgentReset()
55+
public override void OnEpisodeBegin()
5656
{
5757
gameObject.transform.localPosition = new Vector3(
5858
(1 - 2 * Random.value) * 5, 2, (1 - 2 * Random.value) * 5);
@@ -85,20 +85,20 @@ void FixedUpdate()
8585
if (gameObject.transform.position.y < -1)
8686
{
8787
AddReward(-1);
88-
Done();
88+
EndEpisode();
8989
return;
9090
}
9191

9292
if (gameObject.transform.localPosition.x < -19 || gameObject.transform.localPosition.x > 19
9393
|| gameObject.transform.localPosition.z < -19 || gameObject.transform.localPosition.z > 19)
9494
{
9595
AddReward(-1);
96-
Done();
96+
EndEpisode();
9797
return;
9898
}
9999
if (m_JumpLeft == 0)
100100
{
101-
Done();
101+
EndEpisode();
102102
}
103103
}
104104

Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public class CrawlerAgent : Agent
5252
Quaternion m_LookRotation;
5353
Matrix4x4 m_TargetDirMatrix;
5454

55-
public override void InitializeAgent()
55+
public override void Initialize()
5656
{
5757
m_JdController = GetComponent<JointDriveController>();
5858
m_DirToTarget = target.position - body.position;
@@ -147,7 +147,7 @@ public void GetRandomTargetPos()
147147
target.position = newTargetPos + ground.position;
148148
}
149149

150-
public override void AgentAction(float[] vectorAction)
150+
public override void OnActionReceived(float[] vectorAction)
151151
{
152152
// The dictionary with all the body parts in it are in the jdController
153153
var bpDict = m_JdController.bodyPartsDict;
@@ -251,7 +251,7 @@ void RewardFunctionTimePenalty()
251251
/// <summary>
252252
/// Loop over body parts and reset them to initial conditions.
253253
/// </summary>
254-
public override void AgentReset()
254+
public override void OnEpisodeBegin()
255255
{
256256
if (m_DirToTarget != Vector3.zero)
257257
{

Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,8 @@ public class FoodCollectorAgent : Agent
2929
public bool useVectorObs;
3030

3131

32-
public override void InitializeAgent()
32+
public override void Initialize()
3333
{
34-
base.InitializeAgent();
3534
m_AgentRb = GetComponent<Rigidbody>();
3635
m_MyArea = area.GetComponent<FoodCollectorArea>();
3736
m_FoodCollecterSettings = FindObjectOfType<FoodCollectorSettings>();
@@ -202,7 +201,7 @@ void Unsatiate()
202201
gameObject.GetComponentInChildren<Renderer>().material = normalMaterial;
203202
}
204203

205-
public override void AgentAction(float[] vectorAction)
204+
public override void OnActionReceived(float[] vectorAction)
206205
{
207206
MoveAgent(vectorAction);
208207
}
@@ -230,7 +229,7 @@ public override float[] Heuristic()
230229
return action;
231230
}
232231

233-
public override void AgentReset()
232+
public override void OnEpisodeBegin()
234233
{
235234
Unfreeze();
236235
Unpoison();
3.25 MB
Binary file not shown.

Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,6 @@ public class GridAgent : Agent
2828
const int k_Left = 3;
2929
const int k_Right = 4;
3030

31-
public override void InitializeAgent()
32-
{
33-
}
34-
3531
public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
3632
{
3733
// Mask the necessary actions if selected by the user.
@@ -65,7 +61,7 @@ public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMaske
6561
}
6662

6763
// to be implemented by the developer
68-
public override void AgentAction(float[] vectorAction)
64+
public override void OnActionReceived(float[] vectorAction)
6965
{
7066
AddReward(-0.01f);
7167
var action = Mathf.FloorToInt(vectorAction[0]);
@@ -101,12 +97,12 @@ public override void AgentAction(float[] vectorAction)
10197
if (hit.Where(col => col.gameObject.CompareTag("goal")).ToArray().Length == 1)
10298
{
10399
SetReward(1f);
104-
Done();
100+
EndEpisode();
105101
}
106102
else if (hit.Where(col => col.gameObject.CompareTag("pit")).ToArray().Length == 1)
107103
{
108104
SetReward(-1f);
109-
Done();
105+
EndEpisode();
110106
}
111107
}
112108
}
@@ -133,7 +129,7 @@ public override float[] Heuristic()
133129
}
134130

135131
// to be implemented by the developer
136-
public override void AgentReset()
132+
public override void OnEpisodeBegin()
137133
{
138134
area.AreaReset();
139135
}

Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@ public class HallwayAgent : Agent
1818
HallwaySettings m_HallwaySettings;
1919
int m_Selection;
2020

21-
public override void InitializeAgent()
21+
public override void Initialize()
2222
{
23-
base.InitializeAgent();
2423
m_HallwaySettings = FindObjectOfType<HallwaySettings>();
2524
m_AgentRb = GetComponent<Rigidbody>();
2625
m_GroundRenderer = ground.GetComponent<Renderer>();
@@ -67,7 +66,7 @@ public void MoveAgent(float[] act)
6766
m_AgentRb.AddForce(dirToGo * m_HallwaySettings.agentRunSpeed, ForceMode.VelocityChange);
6867
}
6968

70-
public override void AgentAction(float[] vectorAction)
69+
public override void OnActionReceived(float[] vectorAction)
7170
{
7271
AddReward(-1f / maxStep);
7372
MoveAgent(vectorAction);
@@ -88,7 +87,7 @@ void OnCollisionEnter(Collision col)
8887
SetReward(-0.1f);
8988
StartCoroutine(GoalScoredSwapGroundMaterial(m_HallwaySettings.failMaterial, 0.5f));
9089
}
91-
Done();
90+
EndEpisode();
9291
}
9392
}
9493

@@ -113,7 +112,7 @@ public override float[] Heuristic()
113112
return new float[] { 0 };
114113
}
115114

116-
public override void AgentReset()
115+
public override void OnEpisodeBegin()
117116
{
118117
var agentOffset = -15f;
119118
var blockOffset = 0f;

Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,8 @@ void Awake()
5353
m_PushBlockSettings = FindObjectOfType<PushBlockSettings>();
5454
}
5555

56-
public override void InitializeAgent()
56+
public override void Initialize()
5757
{
58-
base.InitializeAgent();
5958
goalDetect = block.GetComponent<GoalDetect>();
6059
goalDetect.agent = this;
6160

@@ -105,7 +104,7 @@ public void ScoredAGoal()
105104
AddReward(5f);
106105

107106
// By marking an agent as done AgentReset() will be called automatically.
108-
Done();
107+
EndEpisode();
109108

110109
// Swap ground material for a bit to indicate we scored.
111110
StartCoroutine(GoalScoredSwapGroundMaterial(m_PushBlockSettings.goalScoredMaterial, 0.5f));
@@ -161,7 +160,7 @@ public void MoveAgent(float[] act)
161160
/// <summary>
162161
/// Called every step of the engine. Here the agent takes an action.
163162
/// </summary>
164-
public override void AgentAction(float[] vectorAction)
163+
public override void OnActionReceived(float[] vectorAction)
165164
{
166165
// Move the agent using the action.
167166
MoveAgent(vectorAction);
@@ -210,7 +209,7 @@ void ResetBlock()
210209
/// In the editor, if "Reset On Done" is checked then AgentReset() will be
211210
/// called automatically anytime we mark done = true in an agent script.
212211
/// </summary>
213-
public override void AgentReset()
212+
public override void OnEpisodeBegin()
214213
{
215214
var rotation = Random.Range(0, 4);
216215
var rotationAngle = rotation * 90f;

0 commit comments

Comments
 (0)