Skip to content

Change IModelCombiner to not be generic, and add unit tests #1305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 23, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/Microsoft.ML.Core/Prediction/ITrainer.cs
Original file line number Diff line number Diff line change
@@ -95,10 +95,11 @@ public static TPredictor Train<TPredictor>(this ITrainer<TPredictor> trainer, Ro
=> trainer.Train(new TrainContext(trainData));
}

// A trainer can optionally implement this to indicate it can combine multiple models into a single predictor.
public interface IModelCombiner<TModel, TPredictor>
where TPredictor : IPredictor
/// <summary>
/// An interface that combines multiple predictors into a single predictor.
/// </summary>
public interface IModelCombiner
{
TPredictor CombineModels(IEnumerable<TModel> models);
IPredictor CombineModels(IEnumerable<IPredictor> models);
}
}
21 changes: 17 additions & 4 deletions src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
Original file line number Diff line number Diff line change
@@ -19,6 +19,9 @@
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) },
EnsembleTrainer.UserNameValue, EnsembleTrainer.LoadNameValue, "pe", "ParallelEnsemble")]

[assembly: LoadableClass(typeof(EnsembleTrainer), typeof(EnsembleTrainer.Arguments), typeof(SignatureModelCombiner),
"Binary Classification Ensemble Model Combiner", EnsembleTrainer.LoadNameValue, "pe", "ParallelEnsemble")]

namespace Microsoft.ML.Runtime.Ensemble
{
using TDistPredictor = IDistPredictorProducing<Single, Single>;
@@ -28,7 +31,7 @@ namespace Microsoft.ML.Runtime.Ensemble
/// </summary>
public sealed class EnsembleTrainer : EnsembleTrainerBase<Single, TScalarPredictor,
IBinarySubModelSelector, IBinaryOutputCombiner>,
IModelCombiner<TScalarPredictor, TScalarPredictor>
IModelCombiner
{
public const string LoadNameValue = "WeightedEnsemble";
public const string UserNameValue = "Parallel Ensemble (bagging, stacking, etc)";
@@ -70,6 +73,12 @@ public EnsembleTrainer(IHostEnvironment env, Arguments args)
Combiner = args.OutputCombiner.CreateComponent(Host);
}

private EnsembleTrainer(IHostEnvironment env, Arguments args, PredictionKind predictionKind)
: this(env, args)
{
Host.CheckParam(predictionKind == PredictionKind.BinaryClassification, nameof(PredictionKind));
}

public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

private protected override TScalarPredictor CreatePredictor(List<FeatureSubsetModel<TScalarPredictor>> models)
@@ -79,18 +88,22 @@ private protected override TScalarPredictor CreatePredictor(List<FeatureSubsetMo
return new EnsemblePredictor(Host, PredictionKind, CreateModels<TScalarPredictor>(models), Combiner);
}

public TScalarPredictor CombineModels(IEnumerable<TScalarPredictor> models)
public IPredictor CombineModels(IEnumerable<IPredictor> models)
{
Host.CheckValue(models, nameof(models));

var combiner = _outputCombiner.CreateComponent(Host);
var p = models.First();

if (p is TDistPredictor)
{
Host.CheckParam(models.All(m => m is TDistPredictor), nameof(models));
return new EnsembleDistributionPredictor(Host, p.PredictionKind,
models.Select(k => new FeatureSubsetModel<TDistPredictor>((TDistPredictor)k)).ToArray(), combiner);
}

Host.CheckParam(models.All(m => m is TScalarPredictor), nameof(models));
return new EnsemblePredictor(Host, p.PredictionKind,
models.Select(k => new FeatureSubsetModel<TScalarPredictor>(k)).ToArray(), combiner);
models.Select(k => new FeatureSubsetModel<TScalarPredictor>((TScalarPredictor)k)).ToArray(), combiner);
}
}
}
Original file line number Diff line number Diff line change
@@ -12,7 +12,6 @@
using Microsoft.ML.Runtime.Ensemble;
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
using Microsoft.ML.Runtime.Ensemble.Selector;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Learners;

@@ -22,6 +21,9 @@
MulticlassDataPartitionEnsembleTrainer.UserNameValue,
MulticlassDataPartitionEnsembleTrainer.LoadNameValue)]

[assembly: LoadableClass(typeof(MulticlassDataPartitionEnsembleTrainer), typeof(MulticlassDataPartitionEnsembleTrainer.Arguments),
typeof(SignatureModelCombiner), "Multiclass Classification Ensemble Model Combiner", MulticlassDataPartitionEnsembleTrainer.LoadNameValue)]

namespace Microsoft.ML.Runtime.Ensemble
{
using TVectorPredictor = IPredictorProducing<VBuffer<Single>>;
@@ -31,7 +33,7 @@ namespace Microsoft.ML.Runtime.Ensemble
public sealed class MulticlassDataPartitionEnsembleTrainer :
EnsembleTrainerBase<VBuffer<Single>, EnsembleMultiClassPredictor,
IMulticlassSubModelSelector, IMultiClassOutputCombiner>,
IModelCombiner<TVectorPredictor, TVectorPredictor>
IModelCombiner
{
public const string LoadNameValue = "WeightedEnsembleMulticlass";
public const string UserNameValue = "Multi-class Parallel Ensemble (bagging, stacking, etc)";
@@ -72,19 +74,28 @@ public MulticlassDataPartitionEnsembleTrainer(IHostEnvironment env, Arguments ar
Combiner = args.OutputCombiner.CreateComponent(Host);
}

private MulticlassDataPartitionEnsembleTrainer(IHostEnvironment env, Arguments args, PredictionKind predictionKind)
: this(env, args)
{
Host.CheckParam(predictionKind == PredictionKind.MultiClassClassification, nameof(PredictionKind));
}

public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification;

private protected override EnsembleMultiClassPredictor CreatePredictor(List<FeatureSubsetModel<TVectorPredictor>> models)
{
return new EnsembleMultiClassPredictor(Host, CreateModels<TVectorPredictor>(models), Combiner as IMultiClassOutputCombiner);
}

public TVectorPredictor CombineModels(IEnumerable<TVectorPredictor> models)
public IPredictor CombineModels(IEnumerable<IPredictor> models)
{
var predictor = new EnsembleMultiClassPredictor(Host,
models.Select(k => new FeatureSubsetModel<TVectorPredictor>(k)).ToArray(),
_outputCombiner.CreateComponent(Host));
Host.CheckValue(models, nameof(models));
Host.CheckParam(models.All(m => m is TVectorPredictor), nameof(models));

var combiner = _outputCombiner.CreateComponent(Host);
var predictor = new EnsembleMultiClassPredictor(Host,
models.Select(k => new FeatureSubsetModel<TVectorPredictor>((TVectorPredictor)k)).ToArray(),
combiner);
return predictor;
}
}
Original file line number Diff line number Diff line change
@@ -12,7 +12,6 @@
using Microsoft.ML.Runtime.Ensemble;
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
using Microsoft.ML.Runtime.Ensemble.Selector;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Learners;

@@ -21,12 +20,15 @@
RegressionEnsembleTrainer.UserNameValue,
RegressionEnsembleTrainer.LoadNameValue)]

[assembly: LoadableClass(typeof(RegressionEnsembleTrainer), typeof(RegressionEnsembleTrainer.Arguments), typeof(SignatureModelCombiner),
"Regression Ensemble Model Combiner", RegressionEnsembleTrainer.LoadNameValue)]

namespace Microsoft.ML.Runtime.Ensemble
{
using TScalarPredictor = IPredictorProducing<Single>;
public sealed class RegressionEnsembleTrainer : EnsembleTrainerBase<Single, TScalarPredictor,
IRegressionSubModelSelector, IRegressionOutputCombiner>,
IModelCombiner<TScalarPredictor, TScalarPredictor>
IModelCombiner
{
public const string LoadNameValue = "EnsembleRegression";
public const string UserNameValue = "Regression Ensemble (bagging, stacking, etc)";
@@ -66,20 +68,29 @@ public RegressionEnsembleTrainer(IHostEnvironment env, Arguments args)
Combiner = args.OutputCombiner.CreateComponent(Host);
}

private RegressionEnsembleTrainer(IHostEnvironment env, Arguments args, PredictionKind predictionKind)
: this(env, args)
{
Host.CheckParam(predictionKind == PredictionKind.Regression, nameof(PredictionKind));
}

public override PredictionKind PredictionKind => PredictionKind.Regression;

private protected override TScalarPredictor CreatePredictor(List<FeatureSubsetModel<TScalarPredictor>> models)
{
return new EnsemblePredictor(Host, PredictionKind, CreateModels<TScalarPredictor>(models), Combiner);
}

public TScalarPredictor CombineModels(IEnumerable<TScalarPredictor> models)
public IPredictor CombineModels(IEnumerable<IPredictor> models)
{
Host.CheckValue(models, nameof(models));
Host.CheckParam(models.All(m => m is TScalarPredictor), nameof(models));

var combiner = _outputCombiner.CreateComponent(Host);
var p = models.First();

var predictor = new EnsemblePredictor(Host, p.PredictionKind,
models.Select(k => new FeatureSubsetModel<TScalarPredictor>(k)).ToArray(), combiner);
models.Select(k => new FeatureSubsetModel<TScalarPredictor>((TScalarPredictor)k)).ToArray(), combiner);

return predictor;
}
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@

namespace Microsoft.ML.Runtime.FastTree.Internal
{
public sealed class TreeEnsembleCombiner : IModelCombiner<IPredictorProducing<float>, IPredictorProducing<float>>
public sealed class TreeEnsembleCombiner : IModelCombiner
{
private readonly IHost _host;
private readonly PredictionKind _kind;
@@ -32,7 +32,7 @@ public TreeEnsembleCombiner(IHostEnvironment env, PredictionKind kind)
}
}

public IPredictorProducing<float> CombineModels(IEnumerable<IPredictorProducing<float>> models)
public IPredictor CombineModels(IEnumerable<IPredictor> models)
{
_host.CheckValue(models, nameof(models));

Loading