Skip to content

Creation of components through MLContext and cleanup (GcnNorm, LpNorm, RandomFourier, CustomStopWords, VectorWhiten, PCA) #2366

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions src/Microsoft.ML.Data/Transforms/BootstrapSamplingTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
using Microsoft.ML.Model;
using Microsoft.ML.Transforms;

[assembly: LoadableClass(BootstrapSamplingTransformer.Summary, typeof(BootstrapSamplingTransformer), typeof(BootstrapSamplingTransformer.Arguments), typeof(SignatureDataTransform),
[assembly: LoadableClass(BootstrapSamplingTransformer.Summary, typeof(BootstrapSamplingTransformer), typeof(BootstrapSamplingTransformer.Options), typeof(SignatureDataTransform),
BootstrapSamplingTransformer.UserName, "BootstrapSampleTransform", "BootstrapSample")]

[assembly: LoadableClass(BootstrapSamplingTransformer.Summary, typeof(BootstrapSamplingTransformer), null, typeof(SignatureLoadDataTransform),
Expand All @@ -36,7 +36,7 @@ internal static class Defaults
public const int PoolSize = 1000;
}

public sealed class Arguments : TransformInputBase
public sealed class Options : TransformInputBase
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether this is the out-of-bag sample, that is, all those rows that are not selected by the transform.",
ShortName = "comp")]
Expand Down Expand Up @@ -76,16 +76,16 @@ private static VersionInfo GetVersionInfo()
private readonly bool _shuffleInput;
private readonly int _poolSize;

public BootstrapSamplingTransformer(IHostEnvironment env, Arguments args, IDataView input)
public BootstrapSamplingTransformer(IHostEnvironment env, Options options, IDataView input)
: base(env, RegistrationName, input)
{
Host.CheckValue(args, nameof(args));
Host.CheckUserArg(args.PoolSize >= 0, nameof(args.PoolSize), "Cannot be negative");
Host.CheckValue(options, nameof(options));
Host.CheckUserArg(options.PoolSize >= 0, nameof(options.PoolSize), "Cannot be negative");

_complement = args.Complement;
_state = new TauswortheHybrid.State(args.Seed ?? (uint)Host.Rand.Next());
_shuffleInput = args.ShuffleInput;
_poolSize = args.PoolSize;
_complement = options.Complement;
_state = new TauswortheHybrid.State(options.Seed ?? (uint)Host.Rand.Next());
_shuffleInput = options.ShuffleInput;
_poolSize = options.PoolSize;
}

/// <summary>
Expand All @@ -103,7 +103,7 @@ public BootstrapSamplingTransformer(IHostEnvironment env,
uint? seed = null,
bool shuffleInput = Defaults.ShuffleInput,
int poolSize = Defaults.PoolSize)
: this(env, new Arguments() { Complement = complement, Seed = seed, ShuffleInput = shuffleInput, PoolSize = poolSize }, input)
: this(env, new Options() { Complement = complement, Seed = seed, ShuffleInput = shuffleInput, PoolSize = poolSize }, input)
{
}

Expand Down Expand Up @@ -242,7 +242,7 @@ protected override bool MoveNextCore()
internal static class BootstrapSample
{
[TlcModule.EntryPoint(Name = "Transforms.ApproximateBootstrapSampler", Desc = BootstrapSamplingTransformer.Summary, UserName = BootstrapSamplingTransformer.UserName, ShortName = BootstrapSamplingTransformer.RegistrationName)]
public static CommonOutputs.TransformOutput GetSample(IHostEnvironment env, BootstrapSamplingTransformer.Arguments input)
public static CommonOutputs.TransformOutput GetSample(IHostEnvironment env, BootstrapSamplingTransformer.Options input)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(input, nameof(input));
Expand Down
11 changes: 6 additions & 5 deletions src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
using Microsoft.ML.Model;
using Microsoft.ML.Transforms;

[assembly: LoadableClass(RowShufflingTransformer.Summary, typeof(RowShufflingTransformer), typeof(RowShufflingTransformer.Arguments), typeof(SignatureDataTransform),
[assembly: LoadableClass(RowShufflingTransformer.Summary, typeof(RowShufflingTransformer), typeof(RowShufflingTransformer.Options), typeof(SignatureDataTransform),
"Shuffle Transform", "ShuffleTransform", "Shuffle", "shuf")]

[assembly: LoadableClass(RowShufflingTransformer.Summary, typeof(RowShufflingTransformer), null, typeof(SignatureLoadDataTransform),
Expand All @@ -30,7 +30,8 @@ namespace Microsoft.ML.Transforms
/// rows in the input cursor, and then, successively, the output cursor will yield one
/// of these rows and replace it with another row from the input.
/// </summary>
public sealed class RowShufflingTransformer : RowToRowTransformBase
[BestFriend]
internal sealed class RowShufflingTransformer : RowToRowTransformBase
{
private static class Defaults
{
Expand All @@ -39,7 +40,7 @@ private static class Defaults
public const bool ForceShuffle = false;
}

public sealed class Arguments
public sealed class Options
{
// REVIEW: A more intelligent heuristic, based on the expected size of the inputs, perhaps?
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The pool will have this many rows", ShortName = "rows")]
Expand Down Expand Up @@ -99,14 +100,14 @@ public RowShufflingTransformer(IHostEnvironment env,
int poolRows = Defaults.PoolRows,
bool poolOnly = Defaults.PoolOnly,
bool forceShuffle = Defaults.ForceShuffle)
: this(env, new Arguments() { PoolRows = poolRows, PoolOnly = poolOnly, ForceShuffle = forceShuffle }, input)
: this(env, new Options() { PoolRows = poolRows, PoolOnly = poolOnly, ForceShuffle = forceShuffle }, input)
{
}

/// <summary>
/// Public constructor corresponding to SignatureDataTransform.
/// </summary>
public RowShufflingTransformer(IHostEnvironment env, Arguments args, IDataView input)
public RowShufflingTransformer(IHostEnvironment env, Options args, IDataView input)
: base(env, RegistrationName, input)
{
Host.CheckValue(args, nameof(args));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public override IEnumerable<Subset> GetSubsets(Batch batch, Random rand)
for (int i = 0; i < Size; i++)
{
// REVIEW: Consider ways to reintroduce "balanced" samples.
var viewTrain = new BootstrapSamplingTransformer(Host, new BootstrapSamplingTransformer.Arguments(), Data.Data);
var viewTrain = new BootstrapSamplingTransformer(Host, new BootstrapSamplingTransformer.Options(), Data.Data);
var dataTrain = new RoleMappedData(viewTrain, Data.Schema.GetColumnRoleNames());
yield return FeatureSelector.SelectFeatures(dataTrain, rand);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
{
Contracts.Assert(toOutput.Length == 1);

var infos = new VectorWhiteningTransformer.ColumnInfo[toOutput.Length];
var infos = new VectorWhiteningEstimator.ColumnInfo[toOutput.Length];
for (int i = 0; i < toOutput.Length; i++)
infos[i] = new VectorWhiteningTransformer.ColumnInfo(outputNames[toOutput[i]], inputNames[((OutPipelineColumn)toOutput[i]).Input], _kind, _eps, _maxRows, _pcaNum);
infos[i] = new VectorWhiteningEstimator.ColumnInfo(outputNames[toOutput[i]], inputNames[((OutPipelineColumn)toOutput[i]).Input], _kind, _eps, _maxRows, _pcaNum);

return new VectorWhiteningEstimator(env, infos);
}
Expand All @@ -63,18 +63,18 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
/// <param name="maxRows">Maximum number of rows used to train the transform.</param>
/// <param name="pcaNum">In case of PCA whitening, indicates the number of components to retain.</param>
public static Vector<float> PcaWhitening(this Vector<float> input,
float eps = VectorWhiteningTransformer.Defaults.Eps,
int maxRows = VectorWhiteningTransformer.Defaults.MaxRows,
int pcaNum = VectorWhiteningTransformer.Defaults.PcaNum)
float eps = VectorWhiteningEstimator.Defaults.Eps,
int maxRows = VectorWhiteningEstimator.Defaults.MaxRows,
int pcaNum = VectorWhiteningEstimator.Defaults.PcaNum)
=> new OutPipelineColumn(input, WhiteningKind.Pca, eps, maxRows, pcaNum);

/// <include file='../Microsoft.ML.HalLearners/doc.xml' path='doc/members/member[@name="Whitening"]/*'/>
/// <param name="input">The column to which the transform will be applied.</param>
/// <param name="eps">Whitening constant, prevents division by zero.</param>
/// <param name="maxRows">Maximum number of rows used to train the transform.</param>
public static Vector<float> ZcaWhitening(this Vector<float> input,
float eps = VectorWhiteningTransformer.Defaults.Eps,
int maxRows = VectorWhiteningTransformer.Defaults.MaxRows)
=> new OutPipelineColumn(input, WhiteningKind.Zca, eps, maxRows, VectorWhiteningTransformer.Defaults.PcaNum);
float eps = VectorWhiteningEstimator.Defaults.Eps,
int maxRows = VectorWhiteningEstimator.Defaults.MaxRows)
=> new OutPipelineColumn(input, WhiteningKind.Zca, eps, maxRows, VectorWhiteningEstimator.Defaults.PcaNum);
}
}
10 changes: 5 additions & 5 deletions src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(
/// </format>
/// </example>
public static VectorWhiteningEstimator VectorWhiten(this TransformsCatalog.ProjectionTransforms catalog, string outputColumnName, string inputColumnName = null,
WhiteningKind kind = VectorWhiteningTransformer.Defaults.Kind,
float eps = VectorWhiteningTransformer.Defaults.Eps,
int maxRows = VectorWhiteningTransformer.Defaults.MaxRows,
int pcaNum = VectorWhiteningTransformer.Defaults.PcaNum)
WhiteningKind kind = VectorWhiteningEstimator.Defaults.Kind,
float eps = VectorWhiteningEstimator.Defaults.Eps,
int maxRows = VectorWhiteningEstimator.Defaults.MaxRows,
int pcaNum = VectorWhiteningEstimator.Defaults.PcaNum)
=> new VectorWhiteningEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, kind, eps, maxRows, pcaNum);

/// <summary>
Expand All @@ -124,7 +124,7 @@ public static VectorWhiteningEstimator VectorWhiten(this TransformsCatalog.Proje
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="columns">Describes the parameters of the whitening process for each column pair.</param>
public static VectorWhiteningEstimator VectorWhiten(this TransformsCatalog.ProjectionTransforms catalog, params VectorWhiteningTransformer.ColumnInfo[] columns)
public static VectorWhiteningEstimator VectorWhiten(this TransformsCatalog.ProjectionTransforms catalog, params VectorWhiteningEstimator.ColumnInfo[] columns)
=> new VectorWhiteningEstimator(CatalogUtils.GetEnvironment(catalog), columns);

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ private RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedDa
idvToFeedTrain = idvToShuffle;
else
{
var shuffleArgs = new RowShufflingTransformer.Arguments
var shuffleArgs = new RowShufflingTransformer.Options
{
PoolOnly = false,
ForceShuffle = _options.Shuffle
Expand Down
Loading