Skip to content

Replace SubComponent with IComponentFactory in ML.Ensemble #681

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 17, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 81 additions & 44 deletions src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,41 +37,6 @@ public interface IComponentFactory<in TArg1, out TComponent> : IComponentFactory
TComponent CreateComponent(IHostEnvironment env, TArg1 argument1);
}

/// <summary>
/// A class for creating a component when we take one extra parameter
/// (and an <see cref="IHostEnvironment"/>) that simply wraps a delegate which
/// creates the component.
/// </summary>
/// <typeparam name="TArg1">The type of the extra argument forwarded to the delegate.</typeparam>
/// <typeparam name="TComponent">The type of component the delegate produces.</typeparam>
public class SimpleComponentFactory<TArg1, TComponent> : IComponentFactory<TArg1, TComponent>
{
    // Assigned only in the constructor, so it is readonly to prevent accidental reassignment.
    private readonly Func<IHostEnvironment, TArg1, TComponent> _factory;

    /// <summary>
    /// Stores <paramref name="factory"/> for later invocation. The delegate is not validated here.
    /// </summary>
    /// <param name="factory">The delegate that creates the component.</param>
    public SimpleComponentFactory(Func<IHostEnvironment, TArg1, TComponent> factory)
    {
        _factory = factory;
    }

    /// <summary>
    /// Invokes the wrapped delegate with <paramref name="env"/> and <paramref name="argument1"/>.
    /// </summary>
    /// <param name="env">The host environment passed through to the delegate.</param>
    /// <param name="argument1">The extra argument passed through to the delegate.</param>
    /// <returns>Whatever the wrapped delegate returns.</returns>
    public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1)
    {
        return _factory(env, argument1);
    }
}

/// <summary>
/// A class for creating a component with no extra parameters (other than an
/// <see cref="IHostEnvironment"/>) that simply wraps a delegate which creates the component.
/// </summary>
/// <typeparam name="TComponent">The type of component the delegate produces.</typeparam>
public class SimpleComponentFactory<TComponent> : IComponentFactory<TComponent>
{
    // Assigned only in the constructor, so it is readonly to prevent accidental reassignment.
    private readonly Func<IHostEnvironment, TComponent> _factory;

    /// <summary>
    /// Stores <paramref name="factory"/> for later invocation. The delegate is not validated here.
    /// </summary>
    /// <param name="factory">The delegate that creates the component.</param>
    public SimpleComponentFactory(Func<IHostEnvironment, TComponent> factory)
    {
        _factory = factory;
    }

    /// <summary>
    /// Invokes the wrapped delegate with <paramref name="env"/>.
    /// </summary>
    /// <param name="env">The host environment passed through to the delegate.</param>
    /// <returns>Whatever the wrapped delegate returns.</returns>
    public TComponent CreateComponent(IHostEnvironment env)
    {
        return _factory(env);
    }
}

/// <summary>
/// An interface for creating a component when we take two extra parameters (and an <see cref="IHostEnvironment"/>).
/// </summary>
Expand All @@ -89,22 +54,94 @@ public interface IComponentFactory<in TArg1, in TArg2, in TArg3, out TComponent>
}

/// <summary>
/// A class for creating a component when we take three extra parameters
/// (and an <see cref="IHostEnvironment"/>) that simply wraps a delegate which
/// creates the component.
/// A utility class for creating <see cref="IComponentFactory"/> instances.
/// </summary>
public class SimpleComponentFactory<TArg1, TArg2, TArg3, TComponent> : IComponentFactory<TArg1, TArg2, TArg3, TComponent>
public static class ComponentFactoryUtils
{
private Func<IHostEnvironment, TArg1, TArg2, TArg3, TComponent> _factory;
/// <summary>
/// Creates a component factory with no extra parameters (other than an <see cref="IHostEnvironment"/>)
/// that simply wraps a delegate which creates the component.
/// </summary>
public static IComponentFactory<TComponent> CreateFromFunction<TComponent>(Func<IHostEnvironment, TComponent> factory)
{
return new SimpleComponentFactory<TComponent>(factory);
}

/// <summary>
/// Creates a component factory when we take one extra parameter (and an
/// <see cref="IHostEnvironment"/>) that simply wraps a delegate which creates the component.
/// </summary>
public static IComponentFactory<TArg1, TComponent> CreateFromFunction<TArg1, TComponent>(Func<IHostEnvironment, TArg1, TComponent> factory)
{
return new SimpleComponentFactory<TArg1, TComponent>(factory);
}

/// <summary>
/// Creates a component factory when we take three extra parameters (and an
/// <see cref="IHostEnvironment"/>) that simply wraps a delegate which creates the component.
/// </summary>
public static IComponentFactory<TArg1, TArg2, TArg3, TComponent> CreateFromFunction<TArg1, TArg2, TArg3, TComponent>(Func<IHostEnvironment, TArg1, TArg2, TArg3, TComponent> factory)
{
return new SimpleComponentFactory<TArg1, TArg2, TArg3, TComponent>(factory);
}

public SimpleComponentFactory(Func<IHostEnvironment, TArg1, TArg2, TArg3, TComponent> factory)
/// <summary>
/// A class for creating a component with no extra parameters (other than an <see cref="IHostEnvironment"/>)
/// that simply wraps a delegate which creates the component.
/// </summary>
private sealed class SimpleComponentFactory<TComponent> : IComponentFactory<TComponent>
{
_factory = factory;
private readonly Func<IHostEnvironment, TComponent> _factory;

public SimpleComponentFactory(Func<IHostEnvironment, TComponent> factory)
{
_factory = factory;
}

public TComponent CreateComponent(IHostEnvironment env)
{
return _factory(env);
}
}

public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3)
/// <summary>
/// A class for creating a component when we take one extra parameter
/// (and an <see cref="IHostEnvironment"/>) that simply wraps a delegate which
/// creates the component.
/// </summary>
private sealed class SimpleComponentFactory<TArg1, TComponent> : IComponentFactory<TArg1, TComponent>
{
return _factory(env, argument1, argument2, argument3);
private readonly Func<IHostEnvironment, TArg1, TComponent> _factory;

public SimpleComponentFactory(Func<IHostEnvironment, TArg1, TComponent> factory)
{
_factory = factory;
}

public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1)
{
return _factory(env, argument1);
}
}

/// <summary>
/// A class for creating a component when we take three extra parameters
/// (and an <see cref="IHostEnvironment"/>) that simply wraps a delegate which
/// creates the component.
/// </summary>
private sealed class SimpleComponentFactory<TArg1, TArg2, TArg3, TComponent> : IComponentFactory<TArg1, TArg2, TArg3, TComponent>
{
private readonly Func<IHostEnvironment, TArg1, TArg2, TArg3, TComponent> _factory;

public SimpleComponentFactory(Func<IHostEnvironment, TArg1, TArg2, TArg3, TComponent> factory)
{
_factory = factory;
}

public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3)
{
return _factory(env, argument1, argument2, argument3);
}
}
}
}
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ private void RunCore(IChannel ch, string cmd)
new[]
{
new KeyValuePair<string, IComponentFactory<IDataView, IDataTransform>>(
"", new SimpleComponentFactory<IDataView, IDataTransform>(
"", ComponentFactoryUtils.CreateFromFunction<IDataView, IDataTransform>(
(env, input) =>
{
var args = new GenerateNumberTransform.Arguments();
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Commands/ScoreCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ public static TScorerFactory GetScorerComponent(
};
}

return new SimpleComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>(factoryFunc);
return ComponentFactoryUtils.CreateFromFunction(factoryFunc);
}

/// <summary>
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
<ProjectReference Include="..\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
<ProjectReference Include="..\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj" />
<ProjectReference Include="..\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj" />
Copy link
Contributor

@TomFinley TomFinley Aug 15, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[](start = 4, length = 84)

We need this due to the usage of FastTree as the default trainer for some ensemble classes. This could be fine, but I wonder if using something from StandardLearners might make a bit more sense, if we envision that someday FastTree will be factored into a separate NuGet package. (Maybe we don't want that though, in which case this dependency is probably fine.) #Resolved

Copy link
Member Author

@eerhardt eerhardt Aug 15, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One advantage to this change is that it brings to light these dependencies. Before we had a weakly-typed dependency using strings and DI. So if we ever did refactor FastTree to be separate, we could have easily broken this dependency.

I've logged #682 for this issue. #Resolved

Copy link
Contributor

@TomFinley TomFinley Aug 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @eerhardt ... yes I expect we'll see some more of this "hidden dependency" becoming exposed as we move away from direct usage of subcomponents and dependency injection. I recall a similar issue in #446 with the normalizers being defined in Microsoft.ML.Transforms and used in Microsoft.ML.Data, despite the former project depending on the latter. 😄 #Resolved

</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
{
public abstract class BaseScalarStacking<TSigBase> : BaseStacking<Single, TSigBase>
public abstract class BaseScalarStacking : BaseStacking<Single>
{
internal BaseScalarStacking(IHostEnvironment env, string name, ArgumentsBase args)
: base(env, name, args)
Expand Down
16 changes: 7 additions & 9 deletions src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Threading.Tasks;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
Expand All @@ -15,7 +16,7 @@
namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
{
using ColumnRole = RoleMappedSchema.ColumnRole;
public abstract class BaseStacking<TOutput, TSigBase> : IStackingTrainer<TOutput>
public abstract class BaseStacking<TOutput> : IStackingTrainer<TOutput>
{
public abstract class ArgumentsBase
{
Expand All @@ -24,13 +25,10 @@ public abstract class ArgumentsBase
[TGUI(Label = "Validation Dataset Proportion")]
public Single ValidationDatasetProportion = 0.3f;

[Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
[TGUI(Label = "Base predictor")]
public SubComponent<ITrainer<IPredictorProducing<TOutput>>, TSigBase> BasePredictorType;
internal abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> GetPredictorFactory();
}

protected readonly SubComponent<ITrainer<IPredictorProducing<TOutput>>, TSigBase> BasePredictorType;
protected readonly IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> BasePredictorType;
protected readonly IHost Host;
protected IPredictorProducing<TOutput> Meta;

Expand All @@ -45,10 +43,10 @@ internal BaseStacking(IHostEnvironment env, string name, ArgumentsBase args)
Host.CheckUserArg(0 <= args.ValidationDatasetProportion && args.ValidationDatasetProportion < 1,
nameof(args.ValidationDatasetProportion),
"The validation proportion for stacking should be greater than or equal to 0 and less than 1");
Host.CheckUserArg(args.BasePredictorType.IsGood(), nameof(args.BasePredictorType));

ValidationDatasetProportion = args.ValidationDatasetProportion;
BasePredictorType = args.BasePredictorType;
BasePredictorType = args.GetPredictorFactory();
Host.CheckValue(BasePredictorType, nameof(BasePredictorType));
}

internal BaseStacking(IHostEnvironment env, string name, ModelLoadContext ctx)
Expand Down Expand Up @@ -187,7 +185,7 @@ public void Train(List<FeatureSubsetModel<IPredictorProducing<TOutput>>> models,
var view = bldr.GetDataView();
var rmd = new RoleMappedData(view, DefaultColumnNames.Label, DefaultColumnNames.Features);

var trainer = BasePredictorType.CreateInstance(host);
var trainer = BasePredictorType.CreateComponent(host);
if (trainer.Info.NeedNormalization)
ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this.");
Meta = trainer.Train(rmd);
Expand Down
20 changes: 17 additions & 3 deletions src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.FastTree;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Runtime.Model;

[assembly: LoadableClass(typeof(MultiStacking), typeof(MultiStacking.Arguments), typeof(SignatureCombiner),
Expand All @@ -20,7 +23,7 @@
namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
{
using TVectorPredictor = IPredictorProducing<VBuffer<Single>>;
public sealed class MultiStacking : BaseStacking<VBuffer<Single>, SignatureMultiClassClassifierTrainer>, ICanSaveModel, IMultiClassOutputCombiner
public sealed class MultiStacking : BaseStacking<VBuffer<Single>>, ICanSaveModel, IMultiClassOutputCombiner
{
public const string LoadName = "MultiStacking";
public const string LoaderSignature = "MultiStackingCombiner";
Expand All @@ -38,13 +41,24 @@ private static VersionInfo GetVersionInfo()
[TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)]
public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerFactory
{
[Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))]
[TGUI(Label = "Base predictor")]
public IComponentFactory<ITrainer<TVectorPredictor>> BasePredictorType;

internal override IComponentFactory<ITrainer<TVectorPredictor>> GetPredictorFactory() => BasePredictorType;

public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiStacking(env, this);

public Arguments()
{
// REVIEW: Perhaps we can have a better non-parametetric learner.
BasePredictorType = new SubComponent<ITrainer<TVectorPredictor>, SignatureMultiClassClassifierTrainer>(
"OVA", "p=FastTreeBinaryClassification");
BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
env => new Ova(env, new Ova.Arguments()
{
PredictorType = ComponentFactoryUtils.CreateFromFunction(
e => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments()))
}));
}
}

Expand Down
14 changes: 12 additions & 2 deletions src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.FastTree;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Model;

[assembly: LoadableClass(typeof(RegressionStacking), typeof(RegressionStacking.Arguments), typeof(SignatureCombiner),
Expand All @@ -19,7 +21,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
{
using TScalarPredictor = IPredictorProducing<Single>;

public sealed class RegressionStacking : BaseScalarStacking<SignatureRegressorTrainer>, IRegressionOutputCombiner, ICanSaveModel
Copy link
Contributor

@TomFinley TomFinley Aug 15, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SignatureRegressorTrainer [](start = 64, length = 25)

What are we going to do about this sort of restriction? I had suggested previously using something like ITrainer<IRegressionPredictor> or something like this, but I'm not sure what you thought of that idea. #Resolved

Copy link
Member Author

@eerhardt eerhardt Aug 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which restriction are you referring to? The Signature Type restriction, where we can't use typeof(TSig) in an attribute?

My hope is that someday we remove signature types altogether. I haven't thought of/discovered a good replacement for them yet. Maybe an approach similar to ComponentKind which is just a string "Kind".

However, that still wouldn't help in this case, because we need a different signature/kind/etc for each concrete class, but we want to define the argument on the base class. #Resolved

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant just restricting the DI and API usage to only consider actual, in this case, regression trainers, which is one of the functions of the signatures. Currently this is done in a fairly half-hearted fashion using this IPredictorProducing<float> and IPredictorProducing<VBuffer<float>>, but if we got rid of this and replaced it with, say, IRegressionPredictor and IBinaryClassificationPredictor and IMulticlassPredictor (or something like that), then we could just have this be a component factory producing ITrainer<IRegressionPredictor> in this case. I think MEF should work well with that restriction as well, from a DI perspective.


In reply to: 210463209 [](ancestors = 210463209)

public sealed class RegressionStacking : BaseScalarStacking, IRegressionOutputCombiner, ICanSaveModel
{
public const string LoadName = "RegressionStacking";
public const string LoaderSignature = "RegressionStacking";
Expand All @@ -37,9 +39,17 @@ private static VersionInfo GetVersionInfo()
[TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)]
public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerFactory
{
[Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))]
[TGUI(Label = "Base predictor")]
public IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorType;

internal override IComponentFactory<ITrainer<TScalarPredictor>> GetPredictorFactory() => BasePredictorType;

public Arguments()
{
BasePredictorType = new SubComponent<ITrainer<TScalarPredictor>, SignatureRegressorTrainer>("FastTreeRegression");
BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
env => new FastTreeRegressionTrainer(env, new FastTreeRegressionTrainer.Arguments()));
}

public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new RegressionStacking(env, this);
Expand Down
15 changes: 12 additions & 3 deletions src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
using System;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.FastTree;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Model;

[assembly: LoadableClass(typeof(Stacking), typeof(Stacking.Arguments), typeof(SignatureCombiner), Stacking.UserName, Stacking.LoadName)]
Expand All @@ -16,7 +17,7 @@
namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
{
using TScalarPredictor = IPredictorProducing<Single>;
public sealed class Stacking : BaseScalarStacking<SignatureBinaryClassifierTrainer>, IBinaryOutputCombiner, ICanSaveModel
public sealed class Stacking : BaseScalarStacking, IBinaryOutputCombiner, ICanSaveModel
{
public const string UserName = "Stacking";
public const string LoadName = "Stacking";
Expand All @@ -35,9 +36,17 @@ private static VersionInfo GetVersionInfo()
[TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
public sealed class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFactory
{
[Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
[TGUI(Label = "Base predictor")]
public IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorType;

internal override IComponentFactory<ITrainer<TScalarPredictor>> GetPredictorFactory() => BasePredictorType;

public Arguments()
{
BasePredictorType = new SubComponent<ITrainer<TScalarPredictor>, SignatureBinaryClassifierTrainer>("FastTreeBinaryClassification");
BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
env => new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments()));
}

public IBinaryOutputCombiner CreateComponent(IHostEnvironment env) => new Stacking(env, this);
Expand Down
Loading