Skip to content

Tree estimators #855

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Sep 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Api/TypedCursor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ public ICursor GetRootCursor()
/// </summary>
public static class CursoringUtils
{
private const string NeedEnvObsoleteMessage = "This method is obsolete. Please use the overload that takes an additional 'env' argument. An environment can be created via new TlcEnvironment().";
private const string NeedEnvObsoleteMessage = "This method is obsolete. Please use the overload that takes an additional 'env' argument. An environment can be created via new LocalEnvironment().";

/// <summary>
/// Generate a strongly-typed cursorable wrapper of the <see cref="IDataView"/>.
Expand Down
57 changes: 57 additions & 0 deletions src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
[assembly: LoadableClass(typeof(RegressionPredictionTransformer<IPredictorProducing<float>>), typeof(RegressionPredictionTransformer), null, typeof(SignatureLoadModel),
"", RegressionPredictionTransformer.LoaderSignature)]

[assembly: LoadableClass(typeof(RankingPredictionTransformer<IPredictorProducing<float>>), typeof(RankingPredictionTransformer), null, typeof(SignatureLoadModel),
"", RankingPredictionTransformer.LoaderSignature)]

namespace Microsoft.ML.Runtime.Data
{
public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer<TModel>, ICanSaveModel
Expand Down Expand Up @@ -301,6 +304,52 @@ private static VersionInfo GetVersionInfo()
}
}

public sealed class RankingPredictionTransformer<TModel> : PredictionTransformerBase<TModel>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RankingPredictionTransformer [](start = 24, length = 28)

Is the reason why we have two types that are identical in practically everything but name, so we can identify ranking estimators vs. regression estimators in a statically typed way?

Copy link
Contributor

@Zruty0 Zruty0 Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this transformer should also expose the group ID column name, at least that would be my belief


In reply to: 218214277 [](ancestors = 218214277)

Copy link
Contributor

@TomFinley TomFinley Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I thought about this: like labels, group IDs are only needed for training, right? So for prediction I don't think they should be required.


In reply to: 218216192 [](ancestors = 218216192,218214277)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So keep it, or make the Regression one Generic and use it for both?


In reply to: 218216839 [](ancestors = 218216839,218216192,218214277)

where TModel : class, IPredictorProducing<float>
{
private readonly GenericScorer _scorer;

public RankingPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn)
    : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RankingPredictionTransformer<TModel>)), model, inputSchema, featureColumn)
{
    // Map only the feature column; the label role is left unmapped (null) for scoring.
    var roleMappedSchema = new RoleMappedSchema(inputSchema, null, featureColumn);
    var boundMapper = BindableMapper.Bind(Host, roleMappedSchema);
    // Scorer operates over an empty view carrying the input schema; real data arrives via Transform.
    _scorer = new GenericScorer(Host, new GenericScorer.Arguments(), new EmptyDataView(Host, inputSchema), boundMapper, roleMappedSchema);
}

internal RankingPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
    : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RankingPredictionTransformer<TModel>)), ctx)
{
    // Deserialization path: rebuild the scorer from the training-time schema restored by the base ctor.
    var roleMappedSchema = new RoleMappedSchema(TrainSchema, null, FeatureColumn);
    var boundMapper = BindableMapper.Bind(Host, roleMappedSchema);
    _scorer = new GenericScorer(Host, new GenericScorer.Arguments(), new EmptyDataView(Host, TrainSchema), boundMapper, roleMappedSchema);
}

/// <summary>
/// Scores <paramref name="input"/> by delegating to the underlying generic scorer.
/// </summary>
public override IDataView Transform(IDataView input)
{
    Host.CheckValue(input, nameof(input));
    var scored = _scorer.ApplyToData(Host, input);
    return scored;
}

/// <summary>
/// Writes this transformer's state to <paramref name="ctx"/>. The version header is stamped
/// first; the remaining serialization is delegated to the base class.
/// </summary>
protected override void SaveCore(ModelSaveContext ctx)
{
    Contracts.AssertValue(ctx);
    // Record signature/version first so the loader can dispatch to the matching reader.
    ctx.SetVersionInfo(GetVersionInfo());

    // *** Binary format ***
    // <base info>
    base.SaveCore(ctx);
}

// Version and signature metadata used when (de)serializing this transformer.
private static VersionInfo GetVersionInfo()
    => new VersionInfo(
        modelSignature: "MC RANK",
        verWrittenCur: 0x00010001, // Initial version.
        verReadableCur: 0x00010001,
        verWeCanReadBack: 0x00010001,
        loaderSignature: RankingPredictionTransformer.LoaderSignature);
}

internal static class BinaryPredictionTransformer
{
public const string LoaderSignature = "BinaryPredXfer";
Expand All @@ -324,4 +373,12 @@ internal static class RegressionPredictionTransformer
public static RegressionPredictionTransformer<IPredictorProducing<float>> Create(IHostEnvironment env, ModelLoadContext ctx)
=> new RegressionPredictionTransformer<IPredictorProducing<float>>(env, ctx);
}

// Non-generic companion holding the loader signature and the factory the model loader invokes.
internal static class RankingPredictionTransformer
{
    // Signature under which saved ranking transformers are registered for loading.
    public const string LoaderSignature = "RankingPredXfer";

    // Factory invoked via SignatureLoadModel to rehydrate a saved ranking transformer.
    public static RankingPredictionTransformer<IPredictorProducing<float>> Create(IHostEnvironment env, ModelLoadContext ctx)
    {
        return new RankingPredictionTransformer<IPredictorProducing<float>>(env, ctx);
    }
}
}
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public Arguments()
env => new Ova(env, new Ova.Arguments()
{
PredictorType = ComponentFactoryUtils.CreateFromFunction(
e => new AveragedPerceptronTrainer(e, new AveragedPerceptronTrainer.Arguments()))
e => new FastTreeBinaryClassificationTrainer(e, DefaultColumnNames.Label, DefaultColumnNames.Features))
Copy link
Contributor

@TomFinley TomFinley Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FastTreeBinaryClassificationTrainer [](start = 37, length = 35)

I'd really rather we didn't. This seems to fit into the same bucket as the discussion on #682. That ensembling should have a dependency on FastTree merely because we have a default does not make sense to me. If someone wants to use stacking, that's great, but they need to specify the learners. #Pending

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But maybe we can hold off for right now.


In reply to: 218215145 [](ancestors = 218215145)

Copy link
Member Author

@sfilipi sfilipi Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, let's do that separately, when we shape the ensembles to take in the arguments in the constructor.


In reply to: 218215323 [](ancestors = 218215323,218215145)

}));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerF
public Arguments()
{
BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
env => new FastTreeRegressionTrainer(env, new FastTreeRegressionTrainer.Arguments()));
env => new FastTreeRegressionTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features));
}

public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new RegressionStacking(env, this);
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.FastTree;
Expand Down Expand Up @@ -46,7 +47,7 @@ public sealed class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFacto
public Arguments()
{
BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
env => new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments()));
env => new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features));
}

public IBinaryOutputCombiner CreateComponent(IHostEnvironment env) => new Stacking(env, this);
Expand Down
14 changes: 11 additions & 3 deletions src/Microsoft.ML.FastTree/BoostingFastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,25 @@

using System;
using System.Linq;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.FastTree.Internal;
using Microsoft.ML.Runtime.Internal.Internallearn;

namespace Microsoft.ML.Runtime.FastTree
{
public abstract class BoostingFastTreeTrainerBase<TArgs, TPredictor> : FastTreeTrainerBase<TArgs, TPredictor>
public abstract class BoostingFastTreeTrainerBase<TArgs, TTransformer, TModel> : FastTreeTrainerBase<TArgs, TTransformer, TModel>
where TTransformer : IPredictionTransformer<TModel>
where TArgs : BoostedTreeArgs, new()
where TPredictor : IPredictorProducing<Float>
where TModel : IPredictorProducing<Float>
{
public BoostingFastTreeTrainerBase(IHostEnvironment env, TArgs args) : base(env, args)
protected BoostingFastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label) : base(env, args, label)
{
}

protected BoostingFastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
string weightColumn = null, string groupIdColumn = null, Action<TArgs> advancedSettings = null)
: base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings)
{
}

Expand Down
98 changes: 76 additions & 22 deletions src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using System.IO;
using System.Linq;
using System.Text;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.Conversion;
Expand Down Expand Up @@ -43,10 +44,11 @@ internal static class FastTreeShared
public static readonly object TrainLock = new object();
}

public abstract class FastTreeTrainerBase<TArgs, TPredictor> :
TrainerBase<TPredictor>
public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
TrainerEstimatorBase<TTransformer, TModel>
where TTransformer: IPredictionTransformer<TModel>
where TArgs : TreeArgs, new()
where TPredictor : IPredictorProducing<Float>
where TModel : IPredictorProducing<Float>
{
protected readonly TArgs Args;
protected readonly bool AllowGC;
Expand Down Expand Up @@ -87,34 +89,53 @@ public abstract class FastTreeTrainerBase<TArgs, TPredictor> :

private protected virtual bool NeedCalibration => false;

private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args)
: base(env, RegisterName)
/// <summary>
/// Constructor to use when instantiating the classes deriving from here through the API.
/// </summary>
private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
string weightColumn = null, string groupIdColumn = null, Action<TArgs> advancedSettings = null)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), MakeFeatureColumn(featureColumn), label, MakeWeightColumn(weightColumn))
{
Args = new TArgs();

//apply the advanced args, if the user supplied any
advancedSettings?.Invoke(Args);
Args.LabelColumn = label.Name;

if (weightColumn != null)
Args.WeightColumn = weightColumn;

if (groupIdColumn != null)
Args.GroupIdColumn = groupIdColumn;

// The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
// Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
// Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration.
Info = new TrainerInfo(normalization: false, caching: false, calibration: NeedCalibration, supportValid: true);
// REVIEW: CLR 4.6 has a bug that is only exposed in Scope, and if we trigger GC.Collect in scope environment
// with memory consumption of more than 5GB, the GC gets stuck in an infinite loop. So for now let's call GC only if we call things from LocalEnvironment.
AllowGC = (env is HostEnvironmentBase<LocalEnvironment>);

Initialize(env);
}

/// <summary>
/// Legacy constructor that is used when invoking the classes deriving from this through maml.
/// </summary>
private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), MakeFeatureColumn(args.FeatureColumn), label, MakeWeightColumn(args.WeightColumn))
{
Host.CheckValue(args, nameof(args));
Args = args;
// The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
// Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
// Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration.
Info = new TrainerInfo(normalization: false, caching: false, calibration: NeedCalibration, supportValid: true);
int numThreads = Args.NumThreads ?? Environment.ProcessorCount;
if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor)
{
using (var ch = Host.Start("FastTreeTrainerBase"))
{
numThreads = Host.ConcurrencyFactor;
ch.Warning("The number of threads specified in trainer arguments is larger than the concurrency factor "
+ "setting of the environment. Using {0} training threads instead.", numThreads);
ch.Done();
}
}
ParallelTraining = Args.ParallelTrainer != null ? Args.ParallelTrainer.CreateComponent(env) : new SingleTrainer();
ParallelTraining.InitEnvironment();
// REVIEW: CLR 4.6 has a bug that is only exposed in Scope, and if we trigger GC.Collect in scope environment
// with memory consumption more than 5GB, GC get stuck in infinite loop. So for now let's call GC only if we call things from ConsoleEnvironment.
AllowGC = (env is HostEnvironmentBase<ConsoleEnvironment>);
Tests = new List<Test>();
// with memory consumption more than 5GB, GC get stuck in infinite loop. So for now let's call GC only if we call things from LocalEnvironment.
AllowGC = (env is HostEnvironmentBase<LocalEnvironment>);

InitializeThreads(numThreads);
Initialize(env);
}

protected abstract void PrepareLabels(IChannel ch);
Expand All @@ -133,6 +154,39 @@ protected virtual Float GetMaxLabel()
return Float.PositiveInfinity;
}

/// <summary>
/// Wraps an optional weight column name into a schema-shape column;
/// returns null when no weight column was supplied.
/// </summary>
private static SchemaShape.Column MakeWeightColumn(string weightColumn)
{
    if (weightColumn is null)
        return null;
    return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}

// The feature column is always declared as a vector of R4 values.
private static SchemaShape.Column MakeFeatureColumn(string featureColumn)
    => new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);

/// <summary>
/// Shared initialization for both constructor paths: resolves the training thread count,
/// sets up parallel training, and prepares the test list and worker threads.
/// </summary>
private void Initialize(IHostEnvironment env)
{
    // Explicit thread argument wins; otherwise use one thread per logical processor,
    // clamped to the environment's concurrency factor (with a warning if clamping occurs).
    int threadCount = Args.NumThreads ?? Environment.ProcessorCount;
    if (Host.ConcurrencyFactor > 0 && threadCount > Host.ConcurrencyFactor)
    {
        using (var ch = Host.Start("FastTreeTrainerBase"))
        {
            threadCount = Host.ConcurrencyFactor;
            ch.Warning("The number of threads specified in trainer arguments is larger than the concurrency factor "
                + "setting of the environment. Using {0} training threads instead.", threadCount);
            ch.Done();
        }
    }

    // Fall back to single-machine training when no parallel trainer was configured.
    ParallelTraining = Args.ParallelTrainer != null ? Args.ParallelTrainer.CreateComponent(env) : new SingleTrainer();
    ParallelTraining.InitEnvironment();

    Tests = new List<Test>();

    InitializeThreads(threadCount);
}

protected void ConvertData(RoleMappedData trainData)
{
trainData.Schema.Schema.TryGetColumnIndex(DefaultColumnNames.Features, out int featureIndex);
Expand Down
Loading