Skip to content

Replace all ML.Transforms SubComponent usages with IComponentFactory. #700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 24, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ public static IComponentFactory<TArg1, TComponent> CreateFromFunction<TArg1, TCo
return new SimpleComponentFactory<TArg1, TComponent>(factory);
}

/// <summary>
/// Builds an <see cref="IComponentFactory{TArg1, TArg2, TComponent}"/> around a delegate.
/// The delegate receives an <see cref="IHostEnvironment"/> plus two extra arguments and
/// is responsible for creating the component.
/// </summary>
/// <typeparam name="TArg1">Type of the first extra argument passed to the delegate.</typeparam>
/// <typeparam name="TArg2">Type of the second extra argument passed to the delegate.</typeparam>
/// <typeparam name="TComponent">Type of the component the delegate creates.</typeparam>
/// <param name="factory">The delegate that the returned factory wraps.</param>
/// <returns>A component factory that wraps <paramref name="factory"/>.</returns>
public static IComponentFactory<TArg1, TArg2, TComponent> CreateFromFunction<TArg1, TArg2, TComponent>(Func<IHostEnvironment, TArg1, TArg2, TComponent> factory)
    => new SimpleComponentFactory<TArg1, TArg2, TComponent>(factory);

/// <summary>
/// Creates a component factory when we take three extra parameters (and an
/// <see cref="IHostEnvironment"/>) that simply wraps a delegate which creates the component.
Expand Down Expand Up @@ -124,6 +133,26 @@ public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1)
}
}

/// <summary>
/// A class for creating a component when we take two extra parameters
/// (and an <see cref="IHostEnvironment"/>) that simply wraps a delegate which
/// creates the component.
/// </summary>
/// <typeparam name="TArg1">Type of the first extra argument forwarded to the delegate.</typeparam>
/// <typeparam name="TArg2">Type of the second extra argument forwarded to the delegate.</typeparam>
/// <typeparam name="TComponent">Type of the component the delegate creates.</typeparam>
private sealed class SimpleComponentFactory<TArg1, TArg2, TComponent> : IComponentFactory<TArg1, TArg2, TComponent>
{
    // The wrapped delegate; it is invoked on every CreateComponent call.
    private readonly Func<IHostEnvironment, TArg1, TArg2, TComponent> _factory;

    /// <summary>
    /// Stores the delegate that will be used to create components.
    /// </summary>
    /// <param name="factory">The delegate to wrap.</param>
    public SimpleComponentFactory(Func<IHostEnvironment, TArg1, TArg2, TComponent> factory)
    {
        _factory = factory;
    }

    /// <summary>
    /// Creates the component by forwarding the environment and both arguments,
    /// unchanged and in order, to the wrapped delegate.
    /// </summary>
    /// <param name="env">The host environment passed to the delegate.</param>
    /// <param name="argument1">The first extra argument passed to the delegate.</param>
    /// <param name="argument2">The second extra argument passed to the delegate.</param>
    /// <returns>Whatever the wrapped delegate returns.</returns>
    public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2)
    {
        return _factory(env, argument1, argument2);
    }
}

/// <summary>
/// A class for creating a component when we take three extra parameters
/// (and an <see cref="IHostEnvironment"/>) that simply wraps a delegate which
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ private FoldResult RunFold(int fold)
}

// Train.
var predictor = TrainUtils.Train(host, ch, trainData, trainer, _trainer.Kind, validData,
var predictor = TrainUtils.Train(host, ch, trainData, trainer, validData,
_calibrator, _maxCalibrationExamples, _cacheData, _inputPredictor);

// Score.
Expand Down
13 changes: 6 additions & 7 deletions src/Microsoft.ML.Data/Commands/TrainCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ private void RunCore(IChannel ch, string cmd)
}
}

var predictor = TrainUtils.Train(Host, ch, data, trainer, _info.LoadNames[0], validData,
var predictor = TrainUtils.Train(Host, ch, data, trainer, validData,
Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor);

using (var file = Host.CreateOutputFile(Args.OutputModelFile))
Expand Down Expand Up @@ -228,28 +228,27 @@ public static string MatchNameOrDefaultOrNull(IExceptionContext ectx, ISchema sc
#pragma warning restore MSML_ContractsNameUsesNameof
}

public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name,
public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer,
ICalibratorTrainerFactory calibrator, int maxCalibrationExamples)
{
var caliTrainer = calibrator?.CreateComponent(env);
return TrainCore(env, ch, data, trainer, name, null, caliTrainer, maxCalibrationExamples, false);
return TrainCore(env, ch, data, trainer, null, caliTrainer, maxCalibrationExamples, false);
}

public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData,
public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, RoleMappedData validData,
IComponentFactory<ICalibratorTrainer> calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null)
{
ICalibratorTrainer caliTrainer = calibrator?.CreateComponent(env);
return TrainCore(env, ch, data, trainer, name, validData, caliTrainer, maxCalibrationExamples, cacheData, inputPredictor);
return TrainCore(env, ch, data, trainer, validData, caliTrainer, maxCalibrationExamples, cacheData, inputPredictor);
}

private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData,
private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, RoleMappedData validData,
ICalibratorTrainer calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ch, nameof(ch));
ch.CheckValue(data, nameof(data));
ch.CheckValue(trainer, nameof(trainer));
ch.CheckNonEmpty(name, nameof(name));
ch.CheckValueOrNull(validData);
ch.CheckValueOrNull(inputPredictor);

Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ private void RunCore(IChannel ch, string cmd)
}
}

var predictor = TrainUtils.Train(Host, ch, data, trainer, _info.LoadNames[0], validData,
var predictor = TrainUtils.Train(Host, ch, data, trainer, validData,
Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor);

IDataLoader testPipe;
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/EntryPoints/InputBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ public static TOut Train<TArg, TOut>(IHost host, TArg input,
cachedRoleMappedData = new RoleMappedData(cacheView, roleMappedData.Schema.GetColumnRoleNames());
}

var predictor = TrainUtils.Train(host, ch, cachedRoleMappedData, trainer, "Train", calibrator, maxCalibrationExamples);
var predictor = TrainUtils.Train(host, ch, cachedRoleMappedData, trainer, calibrator, maxCalibrationExamples);
var output = new TOut() { PredictorModel = new PredictorModel(host, roleMappedData, input.TrainingData, predictor) };

ch.Done();
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, ITrai
string feat;
string group;
var data = CreateDataFromArgs(ch, input, args, out feat, out group);
var predictor = TrainUtils.Train(host, ch, data, trainer, args.Trainer.Kind, null,
var predictor = TrainUtils.Train(host, ch, data, trainer, null,
args.Calibrator, args.MaxCalibrationExamples, null);

ch.Done();
Expand Down
3 changes: 1 addition & 2 deletions src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,7 @@ public static IEnumerable<KeyValuePair<ColumnRole, string>> LoadRoleMappingsOrNu
{
// REVIEW: Should really validate the schema here, and consider
// ignoring this stream if it isn't as expected.
var loaderSub = new SubComponent<IDataLoader, SignatureDataLoader>("Text");
var loader = loaderSub.CreateInstance(env,
var loader = new TextLoader(env, new TextLoader.Arguments(),
new RepositoryStreamWrapper(rep, DirTrainingInfo, RoleMappingFile));

using (var cursor = loader.GetRowCursor(c => true))
Expand Down
15 changes: 9 additions & 6 deletions src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Internal.Utilities;

[assembly: LoadableClass(LearnerFeatureSelectionTransform.Summary, typeof(IDataTransform), typeof(LearnerFeatureSelectionTransform), typeof(LearnerFeatureSelectionTransform.Arguments), typeof(SignatureDataTransform),
"Learner Feature Selection Transform", "LearnerFeatureSelectionTransform", "LearnerFeatureSelection")]
Expand All @@ -32,9 +33,11 @@ public sealed class Arguments
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of slots to preserve", ShortName = "topk", SortOrder = 1)]
public int? NumSlotsToKeep;

[Argument(ArgumentType.Multiple, HelpText = "Filter", ShortName = "f", SortOrder = 1)]
public SubComponent<ITrainer<IPredictorWithFeatureWeights<Single>>, SignatureFeatureScorerTrainer> Filter =
new SubComponent<ITrainer<IPredictorWithFeatureWeights<Single>>, SignatureFeatureScorerTrainer>("SDCA");
[Argument(ArgumentType.Multiple, HelpText = "Filter", ShortName = "f", SortOrder = 1, SignatureType = typeof(SignatureFeatureScorerTrainer))]
public IComponentFactory<ITrainer<IPredictorWithFeatureWeights<Single>>> Filter =
ComponentFactoryUtils.CreateFromFunction(env =>
// ML.Transforms doesn't have a direct reference to ML.StandardLearners, so use ComponentCatalog to create the Filter
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Aug 21, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// ML.Transforms doesn't have a direct reference to ML.StandardLearners, [](start = 20, length = 72)

Since it's a learner feature selection, maybe it makes sense to move it to the learners project? #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've opened #726 for this discussion. We can decide what to do in that issue. For now, keeping it to use DI (like it was before) is a decent approach.


In reply to: 211692955 [](ancestors = 211692955)

ComponentCatalog.CreateInstance<ITrainer<IPredictorWithFeatureWeights<Single>>>(env, typeof(SignatureFeatureScorerTrainer), "SDCA", options: null));

[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for features", ShortName = "feat,col", SortOrder = 3, Purpose = SpecialPurpose.ColumnName)]
public string FeatureColumn = DefaultColumnNames.Features;
Expand Down Expand Up @@ -283,7 +286,7 @@ private static void TrainCore(IHost host, IDataView input, Arguments args, ref V
using (var ch = host.Start("Train"))
{
ch.Trace("Constructing trainer");
ITrainer trainer = args.Filter.CreateInstance(host);
ITrainer trainer = args.Filter.CreateComponent(host);

IDataView view = input;

Expand All @@ -301,7 +304,7 @@ private static void TrainCore(IHost host, IDataView input, Arguments args, ref V
var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, args.CustomColumn);
var data = new RoleMappedData(view, label, feature, group, weight, name, customCols);

var predictor = TrainUtils.Train(host, ch, data, trainer, args.Filter.Kind, null,
var predictor = TrainUtils.Train(host, ch, data, trainer, null,
null, 0, args.CacheData);

var rfs = predictor as IPredictorWithFeatureWeights<Single>;
Expand Down
33 changes: 18 additions & 15 deletions src/Microsoft.ML.Transforms/RffTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.CpuMath;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
Expand Down Expand Up @@ -39,11 +40,12 @@ public sealed class Arguments
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of random Fourier features to create", ShortName = "dim")]
public int NewDim = Defaults.NewDim;

[Argument(ArgumentType.Multiple, HelpText = "which kernel to use?", ShortName = "kernel")]
public SubComponent<IFourierDistributionSampler, SignatureFourierDistributionSampler> MatrixGenerator =
new SubComponent<IFourierDistributionSampler, SignatureFourierDistributionSampler>(GaussianFourierSampler.LoadName);
[Argument(ArgumentType.Multiple, HelpText = "Which kernel to use?", ShortName = "kernel", SignatureType = typeof(SignatureFourierDistributionSampler))]
public IComponentFactory<Float, IFourierDistributionSampler> MatrixGenerator =
ComponentFactoryUtils.CreateFromFunction<Float, IFourierDistributionSampler>(
(env, avgDist) => new GaussianFourierSampler(env, new GaussianFourierSampler.Arguments(), avgDist));

[Argument(ArgumentType.AtMostOnce, HelpText = "create two features for every random Fourier frequency? (one for cos and one for sin)")]
[Argument(ArgumentType.AtMostOnce, HelpText = "Create two features for every random Fourier frequency? (one for cos and one for sin)")]
public bool UseSin = Defaults.UseSin;

[Argument(ArgumentType.LastOccurenceWins,
Expand All @@ -57,8 +59,8 @@ public sealed class Column : OneToOneColumn
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of random Fourier features to create", ShortName = "dim")]
public int? NewDim;

[Argument(ArgumentType.Multiple, HelpText = "which kernel to use?", ShortName = "kernel")]
public SubComponent<IFourierDistributionSampler, SignatureFourierDistributionSampler> MatrixGenerator;
[Argument(ArgumentType.Multiple, HelpText = "which kernel to use?", ShortName = "kernel", SignatureType = typeof(SignatureFourierDistributionSampler))]
public IComponentFactory<Float, IFourierDistributionSampler> MatrixGenerator;

[Argument(ArgumentType.AtMostOnce, HelpText = "create two features for every random Fourier frequency? (one for cos and one for sin)")]
public bool? UseSin;
Expand All @@ -81,7 +83,7 @@ public static Column Parse(string str)
public bool TryUnparse(StringBuilder sb)
{
Contracts.AssertValue(sb);
if (NewDim != null || MatrixGenerator.IsGood() || UseSin != null || Seed != null)
if (NewDim != null || MatrixGenerator != null || UseSin != null || Seed != null)
return false;
return TryUnparseCore(sb);
}
Expand Down Expand Up @@ -115,10 +117,10 @@ public TransformInfo(IHost host, Column item, Arguments args, int d, Float avgDi
_rand = seed.HasValue ? RandomUtils.Create(seed) : RandomUtils.Create(host.Rand);
_state = _rand.GetState();

var sub = item.MatrixGenerator;
if (!sub.IsGood())
sub = args.MatrixGenerator;
_matrixGenerator = sub.CreateInstance(host, avgDist);
var generator = item.MatrixGenerator;
if (generator == null)
generator = args.MatrixGenerator;
_matrixGenerator = generator.CreateComponent(host, avgDist);

int roundedUpD = RoundUp(NewDim, CfltAlign);
int roundedUpNumFeatures = RoundUp(SrcDim, CfltAlign);
Expand Down Expand Up @@ -417,12 +419,13 @@ private static Float[] Train(IHost host, ColInfo[] infos, Arguments args, IDataV
else
{
Float[] distances;

var sub = args.Column[iinfo].MatrixGenerator;
if (!sub.IsGood())
if (sub == null)
sub = args.MatrixGenerator;
var info = ComponentCatalog.GetLoadableClassInfo(sub);
bool gaussian = info != null && info.Type == typeof(GaussianFourierSampler);
// create a dummy generator in order to get its type.
// REVIEW this should be refactored. See https://github.com/dotnet/machinelearning/issues/699
var matrixGenerator = sub.CreateComponent(host, 1);
bool gaussian = matrixGenerator is GaussianFourierSampler;

// If the number of pairs is at most the maximum reservoir size / 2, go over all the pairs.
if (resLength < reservoirSize)
Expand Down
Loading