-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Moving domain randomization to C# #4065
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
49 commits
Select commit
Hold shift + click to select a range
8d0244b
passing sampler configs to c#
andrewcoh eb0c495
ignoring commit checks
andrewcoh e9d8350
use settings.py to check PR config
andrewcoh a1c771e
Merge branch 'master' into develop-sampler-refactor
andrewcoh e856f7b
some cleanups/ interval error checking
andrewcoh daa5688
using validator to check settings
andrewcoh 9756a2c
rename min value function check
andrewcoh ec2493f
type checks for parameter randomization settings/enforces float encoding
andrewcoh aa4ebd9
tests for settings
andrewcoh 460a2ea
error properly when a keyword is not followed by a valid config in yaml
andrewcoh 54b6959
seed each sampler individually
andrewcoh b4469ca
using to_float for encoding
andrewcoh 3d26047
use run_seed if no seed specified in yaml
andrewcoh f18da6a
Merge branch 'master' into develop-sampler-refactor
andrewcoh 7f116cd
add docstring for maybe_add_samplers
andrewcoh 46f6491
added multirange uniform distr
andrewcoh 5286675
fix variable name case
andrewcoh 9dbcc4b
from set_sampler_params => set_{samplertype}_params
andrewcoh d3e0d9c
cleaned up sampler
andrewcoh 4c111e4
fix tests
andrewcoh 38f48f1
dummy doc update to trigger CI
andrewcoh 988452a
Test Circle-CI
cypres cd06ce7
clean up
andrewcoh e40951f
Merge branch 'develop-sampler-refactor' of https://github.com/Unity-T…
andrewcoh 1b9f2d5
remove square brackets
andrewcoh aff9c00
fix side channel helper functions/add offset to seed
andrewcoh 63a24cc
restructure yaml config
andrewcoh da3cb2d
fix tests and markdown
andrewcoh fae0ca3
some doc updates
andrewcoh 681c7ea
doc updates
andrewcoh c40c4d0
remove "adding your own sampler" section from docs
andrewcoh 87a3cfc
Merge branch 'master' into develop-sampler-refactor
andrewcoh 320233b
update changelog
andrewcoh 94e1d20
update upgrade_config.py
andrewcoh 9d18e7b
sampler C# tests
andrewcoh 38e1115
flatten intervals just before sending
andrewcoh de6ba28
update comment
andrewcoh bf52d2f
update settings default
andrewcoh 915d102
fix conversion script and test
andrewcoh af11e36
moved sampler type checking from env_manager to env_side_channel
andrewcoh 61e3165
update simple env manager
andrewcoh 49ee082
Merge branch 'master' into develop-sampler-refactor
andrewcoh cf838d9
fix C# unit test
andrewcoh fd5420f
fix C# tests for cloud
andrewcoh 4afc7b1
typing of intervals fix
andrewcoh 4c6cf57
made sampler static class
andrewcoh 9088b4c
fixed comment
andrewcoh c5da1c3
fix docs
andrewcoh 1b45bab
sampler settings apply themselves to env channel
andrewcoh File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using Unity.MLAgents.Inference.Utils; | ||
using UnityEngine; | ||
using Random=System.Random; | ||
|
||
namespace Unity.MLAgents | ||
{ | ||
|
||
/// <summary> | ||
/// Takes a list of floats that encode a sampling distribution and returns the sampling function. | ||
/// </summary> | ||
internal static class SamplerFactory | ||
{ | ||
|
||
/// <summary>
/// Builds a sampling function that returns <c>min + u * (max - min)</c>, where <c>u</c> is the
/// next value of a <see cref="System.Random"/> created from <paramref name="seed"/>.
/// </summary>
/// <param name="min">Value returned when the underlying draw is 0.</param>
/// <param name="max">Bound that the draw is scaled towards.</param>
/// <param name="seed">Seed of the generator owned by the returned sampler.</param>
/// <returns>A function that produces a new sample every time it is invoked.</returns>
public static Func<float> CreateUniformSampler(float min, float max, int seed)
{
    var rng = new Random(seed);

    float NextUniform()
    {
        // One draw per call; the generator state lives as long as the returned delegate.
        var fraction = (float)rng.NextDouble();
        return min + fraction * (max - min);
    }

    return NextUniform;
}
|
||
/// <summary>
/// Builds a sampling function backed by a <c>RandomNormal</c> constructed from
/// <paramref name="seed"/>, <paramref name="mean"/> and <paramref name="stddev"/>.
/// </summary>
/// <param name="mean">Mean forwarded to the <c>RandomNormal</c> generator.</param>
/// <param name="stddev">Standard deviation forwarded to the <c>RandomNormal</c> generator.</param>
/// <param name="seed">Seed forwarded to the <c>RandomNormal</c> generator.</param>
/// <returns>A function that returns the generator's next value, narrowed to <c>float</c>.</returns>
public static Func<float> CreateGaussianSampler(float mean, float stddev, int seed)
{
    var normal = new RandomNormal(seed, mean, stddev);

    float NextGaussian()
    {
        return (float)normal.NextDouble();
    }

    return NextGaussian;
}
|
||
/// <summary>
/// Builds a sampling function over several intervals. Each call first picks one interval via
/// <c>Multinomial.Sample</c>, using the normalized cumulative interval lengths as weights, and
/// then draws a value inside that interval as <c>min + u * (max - min)</c>.
/// </summary>
/// <param name="intervals">
/// Flattened interval bounds: <c>[min0, max0, min1, max1, ...]</c>. A trailing unpaired value
/// is ignored.
/// </param>
/// <param name="seed">
/// Seed of the generator used for draws inside an interval. The interval selector is seeded
/// with <c>seed + 1</c>.
/// </param>
/// <returns>A function that produces a new sample every time it is invoked.</returns>
/// <exception cref="ArgumentNullException"><paramref name="intervals"/> is null.</exception>
/// <exception cref="ArgumentException">
/// <paramref name="intervals"/> holds fewer than two values, i.e. no complete interval.
/// </exception>
public static Func<float> CreateMultiRangeUniformSampler(IList<float> intervals, int seed)
{
    if (intervals == null)
    {
        throw new ArgumentNullException(nameof(intervals));
    }

    // Two consecutive floats describe one interval.
    var numIntervals = intervals.Count / 2;
    if (numIntervals == 0)
    {
        // Without this check the returned sampler could only fail when invoked, because
        // there would be no per-interval function to index into.
        throw new ArgumentException(
            "At least one (min, max) pair is required to build a multi-range uniform sampler.",
            nameof(intervals));
    }

    // Shared generator for the draw inside whichever interval gets selected.
    var distr = new Random(seed);
    // Starts out as raw interval lengths; becomes the normalized cumulative weights below.
    var cumulativeWeights = new float[numIntervals];
    // One uniform sampler per interval.
    var intervalFuncs = new Func<float>[numIntervals];
    // Total length of all intervals, used to normalize the weights.
    float sumIntervalSizes = 0;

    for (var i = 0; i < numIntervals; i++)
    {
        // Declared inside the loop so every closure captures its own bounds.
        var min = intervals[2 * i];
        var max = intervals[2 * i + 1];
        var intervalSize = max - min;
        sumIntervalSizes += intervalSize;
        cumulativeWeights[i] = intervalSize;
        intervalFuncs[i] = () => min + (float)distr.NextDouble() * intervalSize;
    }

    // Normalize each length and accumulate it into a running total in a single pass.
    // Entry i - 1 is already cumulative when entry i is processed.
    for (var i = 0; i < numIntervals; i++)
    {
        cumulativeWeights[i] = cumulativeWeights[i] / sumIntervalSizes;
        if (i > 0)
        {
            cumulativeWeights[i] += cumulativeWeights[i - 1];
        }
    }

    // Separate generator (different seed) for choosing the interval.
    var intervalDistr = new Multinomial(seed + 1);

    float MultiRange()
    {
        // Pick the interval first, then draw inside it; this order fixes the RNG sequence.
        var sampledInterval = intervalDistr.Sample(cumulativeWeights);
        return intervalFuncs[sampledInterval].Invoke();
    }

    return MultiRange;
}
} | ||
} |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
using System; | ||
using NUnit.Framework; | ||
using System.IO; | ||
using System.Collections.Generic; | ||
using UnityEngine; | ||
using Unity.MLAgents.SideChannels; | ||
|
||
namespace Unity.MLAgents.Tests | ||
{ | ||
public class SamplerTests | ||
{ | ||
const int k_Seed = 1337; | ||
const double k_Epsilon = 0.0001; | ||
EnvironmentParametersChannel m_Channel; | ||
|
||
/// <summary>
/// Obtains the <c>EnvironmentParametersChannel</c> used by the tests, reusing the one already
/// known to <c>SideChannelsManager</c> when there is one.
/// </summary>
public SamplerTests()
{
    var existing = SideChannelsManager.GetSideChannel<EnvironmentParametersChannel>();
    if (existing != null)
    {
        m_Channel = existing;
        return;
    }

    // Nothing registered yet (e.g. this fixture is running on its own): create and register one.
    m_Channel = new EnvironmentParametersChannel();
    SideChannelsManager.RegisterSideChannel(m_Channel);
}
/// <summary>
/// Sends a uniform sampler definition for "parameter1" through the side channel and checks
/// the first two values read back for that parameter.
/// </summary>
[Test]
public void UniformSamplerTest()
{
    const string parameterName = "parameter1";
    const float lowerBound = 1.0f;
    const float upperBound = 2.0f;

    using (var samplerMsg = new OutgoingMessage())
    {
        samplerMsg.WriteString(parameterName);
        // 1 indicates this message is a Sampler
        samplerMsg.WriteInt32(1);
        samplerMsg.WriteInt32(k_Seed);
        samplerMsg.WriteInt32((int)SamplerType.Uniform);
        samplerMsg.WriteFloat32(lowerBound);
        samplerMsg.WriteFloat32(upperBound);

        var framed = GetByteMessage(m_Channel, samplerMsg);
        SideChannelsManager.ProcessSideChannelData(framed);
    }

    // Expected values are tied to k_Seed; each read draws a fresh sample.
    var firstSample = m_Channel.GetWithDefault(parameterName, 1.0f);
    Assert.AreEqual(1.208888f, firstSample, k_Epsilon);
    var secondSample = m_Channel.GetWithDefault(parameterName, 1.0f);
    Assert.AreEqual(1.118017f, secondSample, k_Epsilon);
}
|
||
/// <summary>
/// Sends a Gaussian sampler definition for "parameter2" through the side channel and checks
/// the first two values read back for that parameter.
/// </summary>
[Test]
public void GaussianSamplerTest()
{
    const string parameterName = "parameter2";
    const float expectedMean = 3.0f;
    const float expectedStddev = 0.2f;

    using (var samplerMsg = new OutgoingMessage())
    {
        samplerMsg.WriteString(parameterName);
        // 1 indicates this message is a Sampler
        samplerMsg.WriteInt32(1);
        samplerMsg.WriteInt32(k_Seed);
        samplerMsg.WriteInt32((int)SamplerType.Gaussian);
        samplerMsg.WriteFloat32(expectedMean);
        samplerMsg.WriteFloat32(expectedStddev);

        var framed = GetByteMessage(m_Channel, samplerMsg);
        SideChannelsManager.ProcessSideChannelData(framed);
    }

    // Expected values are tied to k_Seed; each read draws a fresh sample.
    var firstSample = m_Channel.GetWithDefault(parameterName, 1.0f);
    Assert.AreEqual(2.936162f, firstSample, k_Epsilon);
    var secondSample = m_Channel.GetWithDefault(parameterName, 1.0f);
    Assert.AreEqual(2.951348f, secondSample, k_Epsilon);
}
|
||
/// <summary>
/// Sends a multi-range uniform sampler definition for "parameter3" through the side channel
/// and checks the first two values read back for that parameter.
/// </summary>
[Test]
public void MultiRangeUniformSamplerTest()
{
    const string parameterName = "parameter3";
    // Flattened bounds of two intervals: [1.2, 2] and [3.2, 4.1].
    var flattenedBounds = new[] { 1.2f, 2f, 3.2f, 4.1f };

    using (var samplerMsg = new OutgoingMessage())
    {
        samplerMsg.WriteString(parameterName);
        // 1 indicates this message is a Sampler
        samplerMsg.WriteInt32(1);
        samplerMsg.WriteInt32(k_Seed);
        samplerMsg.WriteInt32((int)SamplerType.MultiRangeUniform);
        samplerMsg.WriteFloatList(flattenedBounds);

        var framed = GetByteMessage(m_Channel, samplerMsg);
        SideChannelsManager.ProcessSideChannelData(framed);
    }

    // Expected values are tied to k_Seed; each read draws a fresh sample.
    var firstSample = m_Channel.GetWithDefault(parameterName, 1.0f);
    Assert.AreEqual(3.387999f, firstSample, k_Epsilon);
    var secondSample = m_Channel.GetWithDefault(parameterName, 1.0f);
    Assert.AreEqual(1.294413f, secondSample, k_Epsilon);
}
|
||
/// <summary>
/// Serializes a message into a single byte array laid out as: the channel id bytes, the
/// payload length, then the payload itself.
/// </summary>
/// <param name="sideChannel">Channel whose <c>ChannelId</c> prefixes the result.</param>
/// <param name="msg">Message whose bytes form the payload.</param>
/// <returns>The framed bytes, in the layout described above.</returns>
internal static byte[] GetByteMessage(SideChannel sideChannel, OutgoingMessage msg)
{
    var payload = msg.ToByteArray();
    var channelIdBytes = sideChannel.ChannelId.ToByteArray();

    using (var buffer = new MemoryStream())
    {
        using (var writer = new BinaryWriter(buffer))
        {
            writer.Write(channelIdBytes);
            writer.Write(payload.Length);
            writer.Write(payload);
        }

        // The writer is disposed before the bytes are copied out, as in the original layout.
        return buffer.ToArray();
    }
}
} | ||
} |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.