Skip to content

Moving domain randomization to C# #4065

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 49 commits into from
Jun 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
8d0244b
passing sampler configs to c#
andrewcoh Jun 1, 2020
eb0c495
ignoring commit checks
andrewcoh Jun 3, 2020
e9d8350
use settings.py to check PR config
andrewcoh Jun 3, 2020
a1c771e
Merge branch 'master' into develop-sampler-refactor
andrewcoh Jun 3, 2020
e856f7b
some cleanups/ interval error checking
andrewcoh Jun 3, 2020
daa5688
using validator to check settings
andrewcoh Jun 4, 2020
9756a2c
rename min value function check
andrewcoh Jun 4, 2020
ec2493f
type checks for parameter randomization settings/enforces float encoding
andrewcoh Jun 4, 2020
aa4ebd9
tests for settings
andrewcoh Jun 4, 2020
460a2ea
error properly when a keyword is not followed by a valid config in yaml
andrewcoh Jun 5, 2020
54b6959
seed each sampler individually
andrewcoh Jun 5, 2020
b4469ca
using to_float for encoding
andrewcoh Jun 5, 2020
3d26047
use run_seed if no seed specified in yaml
andrewcoh Jun 5, 2020
f18da6a
Merge branch 'master' into develop-sampler-refactor
andrewcoh Jun 5, 2020
7f116cd
add docstring for maybe_add_samplers
andrewcoh Jun 5, 2020
46f6491
added multirange uniform distr
andrewcoh Jun 7, 2020
5286675
fix variable name case
andrewcoh Jun 7, 2020
9dbcc4b
from set_sampler_params => set_{samplertype}_params
andrewcoh Jun 7, 2020
d3e0d9c
cleaned up sampler
andrewcoh Jun 7, 2020
4c111e4
fix tests
andrewcoh Jun 7, 2020
38f48f1
dummy doc update to trigger CI
andrewcoh Jun 7, 2020
988452a
Test Circle-CI
cypres Jun 8, 2020
cd06ce7
clean up
andrewcoh Jun 8, 2020
e40951f
Merge branch 'develop-sampler-refactor' of https://github.com/Unity-T…
andrewcoh Jun 8, 2020
1b9f2d5
remove square brackets
andrewcoh Jun 8, 2020
aff9c00
fix side channel helper functions/add offset to seed
andrewcoh Jun 8, 2020
63a24cc
restructure yaml config
andrewcoh Jun 9, 2020
da3cb2d
fix tests and markdown
andrewcoh Jun 9, 2020
fae0ca3
some doc updates
andrewcoh Jun 9, 2020
681c7ea
doc updates
andrewcoh Jun 10, 2020
c40c4d0
remove "adding your own sampler" section from docs
andrewcoh Jun 10, 2020
87a3cfc
Merge branch 'master' into develop-sampler-refactor
andrewcoh Jun 11, 2020
320233b
update changelog
andrewcoh Jun 11, 2020
94e1d20
update upgrade_config.py
andrewcoh Jun 11, 2020
9d18e7b
sampler C# tests
andrewcoh Jun 12, 2020
38e1115
flatten intervals just before sending
andrewcoh Jun 12, 2020
de6ba28
update comment
andrewcoh Jun 12, 2020
bf52d2f
update settings default
andrewcoh Jun 12, 2020
915d102
fix conversion script and test
andrewcoh Jun 12, 2020
af11e36
moved sampler type checking from env_manager to env_side_channel
andrewcoh Jun 12, 2020
61e3165
update simple env manager
andrewcoh Jun 12, 2020
49ee082
Merge branch 'master' into develop-sampler-refactor
andrewcoh Jun 12, 2020
cf838d9
fix C# unit test
andrewcoh Jun 12, 2020
fd5420f
fix C# tests for cloud
andrewcoh Jun 12, 2020
4afc7b1
typing of intervals fix
andrewcoh Jun 12, 2020
4c6cf57
made sampler static class
andrewcoh Jun 12, 2020
9088b4c
fixed comment
andrewcoh Jun 12, 2020
c5da1c3
fix docs
andrewcoh Jun 12, 2020
1b45bab
sampler settings apply themselves to env channel
andrewcoh Jun 12, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to
### Major Changes
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- The Parameter Randomization feature has been refactored to enable sampling of new parameters per episode to improve robustness. The
`resampling-interval` parameter has been removed and the config structure updated. More information [here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-ML-Agents.md). (#4065)

### Minor Changes
#### com.unity.ml-agents (C#)
Expand Down
70 changes: 70 additions & 0 deletions com.unity.ml-agents/Runtime/Sampler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
using System;
using System.Collections.Generic;
using Unity.MLAgents.Inference.Utils;
using UnityEngine;
using Random=System.Random;

namespace Unity.MLAgents
{

/// <summary>
/// Takes a list of floats that encode a sampling distribution and returns the sampling function.
/// </summary>
internal static class SamplerFactory
{

public static Func<float> CreateUniformSampler(float min, float max, int seed)
{
Random distr = new Random(seed);
return () => min + (float)distr.NextDouble() * (max - min);
}

public static Func<float> CreateGaussianSampler(float mean, float stddev, int seed)
{
RandomNormal distr = new RandomNormal(seed, mean, stddev);
return () => (float)distr.NextDouble();
}

public static Func<float> CreateMultiRangeUniformSampler(IList<float> intervals, int seed)
{
//RNG
Random distr = new Random(seed);
// Will be used to normalize intervalFuncs
float sumIntervalSizes = 0;
//The number of intervals
int numIntervals = (int)(intervals.Count/2);
// List that will store interval lengths
float[] intervalSizes = new float[numIntervals];
// List that will store uniform distributions
IList<Func<float>> intervalFuncs = new Func<float>[numIntervals];
// Collect all intervals and store as uniform distrus
// Collect all interval sizes
for(int i = 0; i < numIntervals; i++)
{
var min = intervals[2 * i];
var max = intervals[2 * i + 1];
var intervalSize = max - min;
sumIntervalSizes += intervalSize;
intervalSizes[i] = intervalSize;
intervalFuncs[i] = () => min + (float)distr.NextDouble() * intervalSize;
}
// Normalize interval lengths
for(int i = 0; i < numIntervals; i++)
{
intervalSizes[i] = intervalSizes[i] / sumIntervalSizes;
}
// Build cmf for intervals
for(int i = 1; i < numIntervals; i++)
{
intervalSizes[i] += intervalSizes[i - 1];
}
Multinomial intervalDistr = new Multinomial(seed + 1);
float MultiRange()
{
int sampledInterval = intervalDistr.Sample(intervalSizes);
return intervalFuncs[sampledInterval].Invoke();
}
return MultiRange;
}
}
}
11 changes: 11 additions & 0 deletions com.unity.ml-agents/Runtime/Sampler.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,30 @@ namespace Unity.MLAgents.SideChannels
/// </summary>
internal enum EnvironmentDataTypes
{
Float = 0
Float = 0,
Sampler = 1
}

/// <summary>
/// The types of distributions from which to sample reset parameters.
/// </summary>
internal enum SamplerType
{
/// <summary>
/// Samples a reset parameter from a uniform distribution.
/// </summary>
Uniform = 0,

/// <summary>
/// Samples a reset parameter from a Gaussian distribution.
/// </summary>
Gaussian = 1,

/// <summary>
/// Samples a reset parameter from a MultiRangeUniform distribution.
/// </summary>
MultiRangeUniform = 2

}

/// <summary>
Expand All @@ -18,7 +41,7 @@ internal enum EnvironmentDataTypes
/// </summary>
internal class EnvironmentParametersChannel : SideChannel
{
Dictionary<string, float> m_Parameters = new Dictionary<string, float>();
Dictionary<string, Func<float>> m_Parameters = new Dictionary<string, Func<float>>();
Dictionary<string, Action<float>> m_RegisteredActions =
new Dictionary<string, Action<float>>();

Expand All @@ -42,12 +65,40 @@ protected override void OnMessageReceived(IncomingMessage msg)
{
var value = msg.ReadFloat32();

m_Parameters[key] = value;
m_Parameters[key] = () => value;

Action<float> action;
m_RegisteredActions.TryGetValue(key, out action);
action?.Invoke(value);
}
else if ((int)EnvironmentDataTypes.Sampler == type)
{
int seed = msg.ReadInt32();
int samplerType = msg.ReadInt32();
Func<float> sampler = () => 0.0f;
if ((int)SamplerType.Uniform == samplerType)
{
float min = msg.ReadFloat32();
float max = msg.ReadFloat32();
sampler = SamplerFactory.CreateUniformSampler(min, max, seed);
}
else if ((int)SamplerType.Gaussian == samplerType)
{
float mean = msg.ReadFloat32();
float stddev = msg.ReadFloat32();

sampler = SamplerFactory.CreateGaussianSampler(mean, stddev, seed);
}
else if ((int)SamplerType.MultiRangeUniform == samplerType)
{
IList<float> intervals = msg.ReadFloatList();
sampler = SamplerFactory.CreateMultiRangeUniformSampler(intervals, seed);
}
else{
Debug.LogWarning("EnvironmentParametersChannel received an unknown data type.");
}
m_Parameters[key] = sampler;
}
else
{
Debug.LogWarning("EnvironmentParametersChannel received an unknown data type.");
Expand All @@ -63,9 +114,9 @@ protected override void OnMessageReceived(IncomingMessage msg)
/// <returns></returns>
public float GetWithDefault(string key, float defaultValue)
{
float valueOut;
Func<float> valueOut;
bool hasKey = m_Parameters.TryGetValue(key, out valueOut);
return hasKey ? valueOut : defaultValue;
return hasKey ? valueOut.Invoke() : defaultValue;
}

/// <summary>
Expand Down
109 changes: 109 additions & 0 deletions com.unity.ml-agents/Tests/Editor/SamplerTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
using System;
using NUnit.Framework;
using System.IO;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.SideChannels;

namespace Unity.MLAgents.Tests
{
public class SamplerTests
{
const int k_Seed = 1337;
const double k_Epsilon = 0.0001;
EnvironmentParametersChannel m_Channel;

public SamplerTests()
{
m_Channel = SideChannelsManager.GetSideChannel<EnvironmentParametersChannel>();
// if running test on its own
if (m_Channel == null)
{
m_Channel = new EnvironmentParametersChannel();
SideChannelsManager.RegisterSideChannel(m_Channel);
}
}
[Test]
public void UniformSamplerTest()
{
float min_value = 1.0f;
float max_value = 2.0f;
string parameter = "parameter1";
using (var outgoingMsg = new OutgoingMessage())
{
outgoingMsg.WriteString(parameter);
// 1 indicates this meessage is a Sampler
outgoingMsg.WriteInt32(1);
outgoingMsg.WriteInt32(k_Seed);
outgoingMsg.WriteInt32((int)SamplerType.Uniform);
outgoingMsg.WriteFloat32(min_value);
outgoingMsg.WriteFloat32(max_value);
byte[] message = GetByteMessage(m_Channel, outgoingMsg);
SideChannelsManager.ProcessSideChannelData(message);
}
Assert.AreEqual(1.208888f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
Assert.AreEqual(1.118017f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
}

[Test]
public void GaussianSamplerTest()
{
float mean = 3.0f;
float stddev = 0.2f;
string parameter = "parameter2";
using (var outgoingMsg = new OutgoingMessage())
{
outgoingMsg.WriteString(parameter);
// 1 indicates this meessage is a Sampler
outgoingMsg.WriteInt32(1);
outgoingMsg.WriteInt32(k_Seed);
outgoingMsg.WriteInt32((int)SamplerType.Gaussian);
outgoingMsg.WriteFloat32(mean);
outgoingMsg.WriteFloat32(stddev);
byte[] message = GetByteMessage(m_Channel, outgoingMsg);
SideChannelsManager.ProcessSideChannelData(message);
}
Assert.AreEqual(2.936162f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
Assert.AreEqual(2.951348f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
}

[Test]
public void MultiRangeUniformSamplerTest()
{
float[] intervals = new float[4];
intervals[0] = 1.2f;
intervals[1] = 2f;
intervals[2] = 3.2f;
intervals[3] = 4.1f;
string parameter = "parameter3";
using (var outgoingMsg = new OutgoingMessage())
{
outgoingMsg.WriteString(parameter);
// 1 indicates this meessage is a Sampler
outgoingMsg.WriteInt32(1);
outgoingMsg.WriteInt32(k_Seed);
outgoingMsg.WriteInt32((int)SamplerType.MultiRangeUniform);
outgoingMsg.WriteFloatList(intervals);
byte[] message = GetByteMessage(m_Channel, outgoingMsg);
SideChannelsManager.ProcessSideChannelData(message);
}
Assert.AreEqual(3.387999f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
Assert.AreEqual(1.294413f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
}

internal static byte[] GetByteMessage(SideChannel sideChannel, OutgoingMessage msg)
{
byte[] message = msg.ToByteArray();
using (var memStream = new MemoryStream())
{
using (var binaryWriter = new BinaryWriter(memStream))
{
binaryWriter.Write(sideChannel.ChannelId.ToByteArray());
binaryWriter.Write(message.Length);
binaryWriter.Write(message);
}
return memStream.ToArray();
}
}
}
}
11 changes: 11 additions & 0 deletions com.unity.ml-agents/Tests/Editor/SamplerTests.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 8 additions & 11 deletions config/ppo/3DBall_randomize.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,13 @@ behaviors:
threaded: true

parameter_randomization:
resampling-interval: 5000
mass:
sampler-type: uniform
min_value: 0.5
max_value: 10
gravity:
sampler-type: uniform
min_value: 7
max_value: 12
sampler_type: uniform
sampler_parameters:
min_value: 0.5
max_value: 10
scale:
sampler-type: uniform
min_value: 0.75
max_value: 3
sampler_type: uniform
sampler_parameters:
min_value: 0.75
max_value: 3
Loading