Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,41 @@ Rigidbody:
m_Interpolate: 0
m_Constraints: 0
m_CollisionDetection: 0
--- !u!114 &8042564747579887
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 7516457449653310668}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 3c8f113a8b8d94967b1b1782c549be81, type: 3}
m_Name:
m_EditorClassIdentifier:
tagToDetect: agent
spawnRadius: 40
respawnIfTouched: 1
respawnIfFallsOffPlatform: 1
fallDistance: 5
onTriggerEnterEvent:
m_PersistentCalls:
m_Calls: []
onTriggerStayEvent:
m_PersistentCalls:
m_Calls: []
onTriggerExitEvent:
m_PersistentCalls:
m_Calls: []
onCollisionEnterEvent:
m_PersistentCalls:
m_Calls: []
onCollisionStayEvent:
m_PersistentCalls:
m_Calls: []
onCollisionExitEvent:
m_PersistentCalls:
m_Calls: []
--- !u!1001 &906401165941233076
PrefabInstance:
m_ObjectHideFlags: 0
Expand Down Expand Up @@ -93,6 +128,11 @@ PrefabInstance:
propertyPath: ground
value:
objectReference: {fileID: 7519759559437056804}
- target: {fileID: 6060305997946326746, guid: 3ebcde4cf2d5c4c029e2a5ce3d853aba,
type: 3}
propertyPath: m_TagString
value: agent
objectReference: {fileID: 0}
m_RemovedComponents: []
m_SourcePrefab: {fileID: 100100000, guid: 3ebcde4cf2d5c4c029e2a5ce3d853aba, type: 3}
--- !u!1001 &7202236613889278392
Expand Down Expand Up @@ -202,6 +242,12 @@ Transform:
type: 3}
m_PrefabInstance: {fileID: 7202236613889278392}
m_PrefabAsset: {fileID: 0}
--- !u!1 &7516457449653310668 stripped
GameObject:
m_CorrespondingSourceObject: {fileID: 845742365997159796, guid: d6fc96a99a9754f07b48abf1e0d55a5c,
type: 3}
m_PrefabInstance: {fileID: 7202236613889278392}
m_PrefabAsset: {fileID: 0}
--- !u!4 &7513373574146463010 stripped
Transform:
m_CorrespondingSourceObject: {fileID: 844321025358320794, guid: d6fc96a99a9754f07b48abf1e0d55a5c,
Expand Down
76 changes: 18 additions & 58 deletions Project/Assets/ML-Agents/Examples/Worm/Scripts/WormAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,21 @@
[RequireComponent(typeof(JointDriveController))] // Required to set joint forces
public class WormAgent : Agent
{
[Header("Target To Walk Towards")]
[Space(10)]
[Header("Target To Walk Towards")] [Space(10)]
public Transform target;

public Transform ground;
public bool detectTargets;
public bool targetIsStatic;
public bool respawnTargetWhenTouched;
public float targetSpawnRadius;

[Header("Body Parts")] [Space(10)]
public Transform bodySegment0;
[Header("Body Parts")] [Space(10)] public Transform bodySegment0;
public Transform bodySegment1;
public Transform bodySegment2;
public Transform bodySegment3;

[Header("Joint Settings")] [Space(10)]
JointDriveController m_JdController;
[Header("Joint Settings")] [Space(10)] JointDriveController m_JdController;
Vector3 m_DirToTarget;
float m_MovingTowardsDot;
float m_FacingDot;

[Header("Reward Functions To Use")]
[Space(10)]
[Header("Reward Functions To Use")] [Space(10)]
public bool rewardMovingTowardsTarget; // Agent should move towards target

public bool rewardFacingTarget; // Agent should face the target
Expand All @@ -50,21 +41,14 @@ public override void Initialize()
m_JdController.SetupBodyPart(bodySegment1);
m_JdController.SetupBodyPart(bodySegment2);
m_JdController.SetupBodyPart(bodySegment3);

//We only want the head to detect the target
//So we need to remove TargetContact from everything else
//This is a temp fix till we can redesign
DestroyImmediate(bodySegment1.GetComponent<TargetContact>());
DestroyImmediate(bodySegment2.GetComponent<TargetContact>());
DestroyImmediate(bodySegment3.GetComponent<TargetContact>());
}


//Get Joint Rotation Relative to the Connected Rigidbody
//We want to collect this info because it is the actual rotation, not the "target rotation"
public Quaternion GetJointRotation(ConfigurableJoint joint)
{
return(Quaternion.FromToRotation(joint.axis, joint.connectedBody.transform.rotation.eulerAngles));
return (Quaternion.FromToRotation(joint.axis, joint.connectedBody.transform.rotation.eulerAngles));
}

/// <summary>
Expand All @@ -78,7 +62,8 @@ public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor)
var velocityRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(rb.velocity);
sensor.AddObservation(velocityRelativeToLookRotationToTarget);

var angularVelocityRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(rb.angularVelocity);
var angularVelocityRelativeToLookRotationToTarget =
m_TargetDirMatrix.inverse.MultiplyVector(rb.angularVelocity);
sensor.AddObservation(angularVelocityRelativeToLookRotationToTarget);

if (bp.rb.transform != bodySegment0)
Expand All @@ -103,18 +88,19 @@ public override void CollectObservations(VectorSensor sensor)
float maxDist = 10;
if (Physics.Raycast(bodySegment0.position, Vector3.down, out hit, maxDist))
{
sensor.AddObservation(hit.distance/maxDist);
sensor.AddObservation(hit.distance / maxDist);
}
else
sensor.AddObservation(1);

foreach (var bodyPart in m_JdController.bodyPartsDict.Values)
foreach (var bodyPart in m_JdController.bodyPartsList)
{
CollectObservationBodyPart(bodyPart, sensor);
}

//Rotation delta between the matrix and the head
Quaternion headRotationDeltaFromMatrixRot = Quaternion.Inverse(m_TargetDirMatrix.rotation) * bodySegment0.rotation;
Quaternion headRotationDeltaFromMatrixRot =
Quaternion.Inverse(m_TargetDirMatrix.rotation) * bodySegment0.rotation;
sensor.AddObservation(headRotationDeltaFromMatrixRot);
}

Expand All @@ -124,20 +110,6 @@ public override void CollectObservations(VectorSensor sensor)
public void TouchedTarget()
{
AddReward(1f);
if (respawnTargetWhenTouched)
{
GetRandomTargetPos();
}
}

/// <summary>
/// Moves target to a random position within specified radius.
/// </summary>
public void GetRandomTargetPos()
{
var newTargetPos = Random.insideUnitSphere * targetSpawnRadius;
newTargetPos.y = 5;
target.position = newTargetPos + ground.position;
}

public override void OnActionReceived(float[] vectorAction)
Expand All @@ -156,25 +128,15 @@ public override void OnActionReceived(float[] vectorAction)
bpDict[bodySegment2].SetJointStrength(vectorAction[++i]);
bpDict[bodySegment3].SetJointStrength(vectorAction[++i]);

if (bodySegment0.position.y < ground.position.y -2)
// Detect if worm fell off/through platform
if (bodySegment0.position.y < ground.position.y - 2)
{
EndEpisode();
}
}

void FixedUpdate()
{
if (detectTargets)
{
foreach (var bodyPart in m_JdController.bodyPartsDict.Values)
{
if (bodyPart.targetContact && bodyPart.targetContact.touchingTarget)
{
TouchedTarget();
}
}
}

// Set reward for this step according to mixture of the following elements.
if (rewardMovingTowardsTarget)
{
Expand All @@ -197,7 +159,8 @@ void FixedUpdate()
/// </summary>
void RewardFunctionMovingTowards()
{
m_MovingTowardsDot = Vector3.Dot(m_JdController.bodyPartsDict[bodySegment0].rb.velocity, m_DirToTarget.normalized);
m_MovingTowardsDot =
Vector3.Dot(m_JdController.bodyPartsDict[bodySegment0].rb.velocity, m_DirToTarget.normalized);
AddReward(0.01f * m_MovingTowardsDot);
}

Expand All @@ -211,7 +174,7 @@ void RewardFunctionFacingTarget()
}

/// <summary>
/// Existential penalty for time-contrained tasks.
/// Existential penalty for time-constrained tasks.
/// </summary>
void RewardFunctionTimePenalty()
{
Expand All @@ -227,15 +190,12 @@ public override void OnEpisodeBegin()
{
bodyPart.Reset(bodyPart);
}

if (m_DirToTarget != Vector3.zero)
{
transform.rotation = Quaternion.LookRotation(m_DirToTarget);
}
transform.Rotate(Vector3.up, Random.Range(0.0f, 360.0f));

if (!targetIsStatic)
{
GetRandomTargetPos();
}
transform.Rotate(Vector3.up, Random.Range(0.0f, 360.0f));
}
}