-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Tree-based featurization #3812
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Tree-based featurization #3812
Changes from 3 commits
17248d3
a2f1d6c
33d0ee0
9658991
f529f1d
9c4d801
e7b84dd
5d8215a
ce4378f
618179d
49fe1d7
2197391
b00be93
bbeb17f
4906d0b
f7ab9ab
a8c0c6e
dbd5dac
1f261c5
241b3ad
6850b8e
d337aa5
4ea7bf6
7b2d654
b8a3ba8
d1d6813
cc2d531
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,233 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System.Linq; | ||
using Microsoft.ML.Data; | ||
using Microsoft.ML.Runtime; | ||
|
||
namespace Microsoft.ML.Trainers.FastTree | ||
{ | ||
/// <summary> | ||
/// This class encapsulates the common behavior of all tree-based featurizers such as <see cref="FastTreeBinaryFeaturizationEstimator"/>, | ||
/// <see cref="FastForestBinaryFeaturizationEstimator"/>, <see cref="FastTreeRegressionFeaturizationEstimator"/>, | ||
/// <see cref="FastForestRegressionFeaturizationEstimator"/>, and <see cref="PretrainedTreeFeaturizationEstimator"/>. | ||
/// All tree-based featurizers share the same output schema computed by <see cref="GetOutputSchema(SchemaShape)"/>. All tree-based featurizers | ||
/// requires an input feature column name and a suffix for all output columns. The <see cref="ITransformer"/> returned by <see cref="Fit(IDataView)"/> | ||
/// produces three columns: (1) the prediction values of all trees, (2) the IDs of leaves the input feature vector falling into, and (3) | ||
/// the binary vector which encodes the paths to those destination leaves. | ||
/// </summary> | ||
public abstract class FeaturizationEstimatorBase : IEstimator<TreeEnsembleFeaturizationTransformer> | ||
{ | ||
/// <summary> | ||
/// The common options of tree-based featurizations such as <see cref="FastTreeBinaryFeaturizationEstimator"/>, <see cref="FastForestBinaryFeaturizationEstimator"/>, | ||
/// <see cref="FastTreeRegressionFeaturizationEstimator"/>, <see cref="FastForestRegressionFeaturizationEstimator"/>, and <see cref="PretrainedTreeFeaturizationEstimator"/>. | ||
/// </summary> | ||
public class CommonOptions | ||
wschin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
/// <summary> | ||
/// The name of feature column in the <see cref="IDataView"/> when calling <see cref="Fit(IDataView)"/>. | ||
/// The column type must be a vector of <see cref="System.Single"/>. | ||
/// </summary> | ||
public string InputColumnName; | ||
|
||
/// <summary> | ||
/// The estimator has three output columns. Their names would be "Trees" + <see cref="OutputColumnsSuffix"/>, | ||
/// "Leaves" + <see cref="OutputColumnsSuffix"/>, and "Paths" + <see cref="OutputColumnsSuffix"/>. If <see cref="OutputColumnsSuffix"/> | ||
/// is <see langword="null"/>, the output names would be "Trees", "Leaves", and "Paths". | ||
/// </summary> | ||
public string OutputColumnsSuffix; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We went away from magic strings in the TextTransform. Previously with For the estimators API we have users directly enter the column name for the tokens. We may want to do the same for the Trees/Leaves/Paths of the TreeFeat. Perhaps: #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And don't add that column if they empty or equal to null. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As further background on the PR I was referencing... Conversation about TextTransform:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now we can do optional output columns and custom output column names. Please see tests for examples (or wait for formal API samples). #Resolved |
||
}; | ||
|
||
/// <summary> | ||
/// Feature column to apply tree-based featurization. Note that <see cref="FeatureColumnName"/> is not necessary to be the same as | ||
/// the feature column used to train the tree model. | ||
/// </summary> | ||
private protected readonly string FeatureColumnName; | ||
|
||
/// <summary> | ||
/// See <see cref="CommonOptions.OutputColumnsSuffix"/>. | ||
/// </summary> | ||
private protected readonly string OutputColumnSuffix; | ||
|
||
/// <summary> | ||
/// Environment of this instance. It controls error throwing and other enviroment settings. | ||
/// </summary> | ||
private protected readonly IHostEnvironment Env; | ||
|
||
private protected FeaturizationEstimatorBase(IHostEnvironment env, CommonOptions options) | ||
{ | ||
Env = env; | ||
FeatureColumnName = options.InputColumnName; | ||
OutputColumnSuffix = options.OutputColumnsSuffix; | ||
} | ||
|
||
/// <summary> | ||
/// All derived class should implement <see cref="PrepareModel(IDataView)"/> to tell how to get a <see cref="TreeEnsembleModelParameters"/> | ||
/// out from <paramref name="input"/> and parameters inside this or derived classes. | ||
/// </summary> | ||
/// <param name="input">Data used to train a tree model.</param> | ||
/// <returns>The trees used in <see cref="TreeEnsembleFeaturizationTransformer"/>.</returns> | ||
private protected abstract TreeEnsembleModelParameters PrepareModel(IDataView input); | ||
|
||
/// <summary> | ||
/// Produce a <see cref="TreeEnsembleModelParameters"/> which maps the column called <see cref="CommonOptions.InputColumnName"/> in <paramref name="input"/> | ||
/// to three output columns. | ||
/// </summary> | ||
public TreeEnsembleFeaturizationTransformer Fit(IDataView input) | ||
{ | ||
var model = PrepareModel(input); | ||
return new TreeEnsembleFeaturizationTransformer(Env, input.Schema, | ||
input.Schema[FeatureColumnName], model, OutputColumnSuffix); | ||
} | ||
|
||
/// <summary> | ||
/// <see cref="PretrainedTreeFeaturizationEstimator"/> adds three float-vector columns into <paramref name="inputSchema"/>. | ||
/// Given a feature vector column, the added columns are the prediction values of all trees, the leaf IDs the feature | ||
/// vector falls into, and the paths to those leaves. | ||
/// </summary> | ||
/// <param name="inputSchema">A schema which contains a feature column. Note that feature column name can be specified | ||
/// by <see cref="CommonOptions.InputColumnName"/>.</param> | ||
/// <returns>Output <see cref="SchemaShape"/> produced by <see cref="PretrainedTreeFeaturizationEstimator"/>.</returns> | ||
public SchemaShape GetOutputSchema(SchemaShape inputSchema) | ||
{ | ||
Env.CheckValue(inputSchema, nameof(inputSchema)); | ||
|
||
if (!inputSchema.TryFindColumn(FeatureColumnName, out var col)) | ||
throw Env.ExceptSchemaMismatch(nameof(inputSchema), "input", FeatureColumnName); | ||
|
||
var result = inputSchema.ToDictionary(x => x.Name); | ||
|
||
var treeColumnName = OutputColumnSuffix != null ? OutputColumnSuffix + "Trees" : "Trees"; | ||
result[treeColumnName] = new SchemaShape.Column(treeColumnName, | ||
SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false); | ||
|
||
var leafColumnName = OutputColumnSuffix != null ? OutputColumnSuffix + "Leaves" : "Leaves"; | ||
result[leafColumnName] = new SchemaShape.Column(leafColumnName, | ||
SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false); | ||
|
||
var pathColumnName = OutputColumnSuffix != null ? OutputColumnSuffix + "Paths" : "Paths"; | ||
result[pathColumnName] = new SchemaShape.Column(pathColumnName, | ||
SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false); | ||
|
||
return new SchemaShape(result.Values); | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// A <see cref="IEstimator{TTransformer}"/> which takes a trained <see cref="TreeEnsembleModelParameters"/> and calling its | ||
/// <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> produces a featurizer based on the trained model. | ||
/// </summary> | ||
public sealed class PretrainedTreeFeaturizationEstimator : FeaturizationEstimatorBase | ||
{ | ||
public sealed class Options : FeaturizationEstimatorBase.CommonOptions | ||
{ | ||
public TreeEnsembleModelParameters ModelParameters; | ||
}; | ||
|
||
private TreeEnsembleModelParameters _modelParameters; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Should this be readonly or something similar to make sure it is not altered? |
||
|
||
internal PretrainedTreeFeaturizationEstimator(IHostEnvironment env, Options options) : base(env, options) | ||
{ | ||
_modelParameters = options.ModelParameters; | ||
} | ||
|
||
/// <summary> | ||
/// Produce the <see cref="TreeEnsembleModelParameters"/> for tree-based feature engineering. This function does not | ||
/// invoke training procedure and just returns the pre-trained model passed in via <see cref="Options.ModelParameters"/>. | ||
/// </summary> | ||
private protected override TreeEnsembleModelParameters PrepareModel(IDataView input) => _modelParameters; | ||
} | ||
|
||
public sealed class FastTreeBinaryFeaturizationEstimator : FeaturizationEstimatorBase | ||
{ | ||
private readonly FastTreeBinaryTrainer.Options _trainerOptions; | ||
|
||
public sealed class Options : CommonOptions | ||
{ | ||
public FastTreeBinaryTrainer.Options TrainerOptions; | ||
} | ||
|
||
internal FastTreeBinaryFeaturizationEstimator(IHostEnvironment env, Options options) | ||
: base(env, options) | ||
{ | ||
_trainerOptions = options.TrainerOptions; | ||
} | ||
|
||
private protected override TreeEnsembleModelParameters PrepareModel(IDataView input) | ||
{ | ||
var trainer = new FastTreeBinaryTrainer(Env, _trainerOptions); | ||
var trained = trainer.Fit(input); | ||
return trained.Model.SubModel; | ||
} | ||
} | ||
|
||
public sealed class FastTreeRegressionFeaturizationEstimator : FeaturizationEstimatorBase | ||
{ | ||
private readonly FastTreeRegressionTrainer.Options _trainerOptions; | ||
|
||
public sealed class Options : CommonOptions | ||
{ | ||
public FastTreeRegressionTrainer.Options TrainerOptions; | ||
} | ||
|
||
internal FastTreeRegressionFeaturizationEstimator(IHostEnvironment env, Options options) | ||
: base(env, options) | ||
{ | ||
_trainerOptions = options.TrainerOptions; | ||
} | ||
|
||
private protected override TreeEnsembleModelParameters PrepareModel(IDataView input) | ||
{ | ||
var trainer = new FastTreeRegressionTrainer(Env, _trainerOptions); | ||
var trained = trainer.Fit(input); | ||
return trained.Model; | ||
} | ||
} | ||
|
||
public sealed class FastForestBinaryFeaturizationEstimator : FeaturizationEstimatorBase | ||
{ | ||
private readonly FastForestBinaryTrainer.Options _trainerOptions; | ||
|
||
public sealed class Options : CommonOptions | ||
{ | ||
public FastForestBinaryTrainer.Options TrainerOptions; | ||
} | ||
|
||
internal FastForestBinaryFeaturizationEstimator(IHostEnvironment env, Options options) | ||
: base(env, options) | ||
{ | ||
_trainerOptions = options.TrainerOptions; | ||
} | ||
|
||
private protected override TreeEnsembleModelParameters PrepareModel(IDataView input) | ||
{ | ||
var trainer = new FastForestBinaryTrainer(Env, _trainerOptions); | ||
var trained = trainer.Fit(input); | ||
return trained.Model; | ||
} | ||
} | ||
|
||
public sealed class FastForestRegressionFeaturizationEstimator : FeaturizationEstimatorBase | ||
{ | ||
private readonly FastForestRegressionTrainer.Options _trainerOptions; | ||
|
||
public sealed class Options : CommonOptions | ||
{ | ||
public FastForestRegressionTrainer.Options TrainerOptions; | ||
} | ||
|
||
internal FastForestRegressionFeaturizationEstimator(IHostEnvironment env, Options options) | ||
: base(env, options) | ||
{ | ||
_trainerOptions = options.TrainerOptions; | ||
} | ||
|
||
private protected override TreeEnsembleModelParameters PrepareModel(IDataView input) | ||
{ | ||
var trainer = new FastForestRegressionTrainer(Env, _trainerOptions); | ||
var trained = trainer.Fit(input); | ||
return trained.Model; | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System.Collections.Generic; | ||
using Microsoft.ML; | ||
using Microsoft.ML.Data; | ||
using Microsoft.ML.Data.IO; | ||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Trainers.FastTree; | ||
|
||
[assembly: LoadableClass(typeof(TreeEnsembleFeaturizationTransformer), typeof(TreeEnsembleFeaturizationTransformer), | ||
null, typeof(SignatureLoadModel), "", TreeEnsembleFeaturizationTransformer.LoaderSignature)] | ||
|
||
namespace Microsoft.ML.Trainers.FastTree | ||
{ | ||
public sealed class TreeEnsembleFeaturizationTransformer : PredictionTransformerBase<TreeEnsembleModelParameters> | ||
{ | ||
internal const string LoaderSignature = "TreeEnseFeat"; | ||
private readonly TreeEnsembleFeaturizerBindableMapper.Arguments _scorerArgs; | ||
private readonly DataViewSchema.DetachedColumn _featureDetachedColumn; | ||
private readonly string _outputColumnSuffix; | ||
|
||
/// <summary> | ||
/// Check if <see cref="_featureDetachedColumn"/> is compatible with <paramref name="inspectedFeatureColumn"/>. | ||
/// </summary> | ||
/// <param name="inspectedFeatureColumn">A column checked against <see cref="_featureDetachedColumn"/>.</param> | ||
private void CheckFeatureColumnCompatibility(DataViewSchema.Column inspectedFeatureColumn) | ||
{ | ||
string nameErrorMessage = $"The column called {inspectedFeatureColumn.Name} does not match the expected " + | ||
$"feature column with name {_featureDetachedColumn.Name} and type {_featureDetachedColumn.Type}. " + | ||
$"Please rename your column by calling CopyColumns defined in TransformExtensionsCatalog"; | ||
// Check if column names are the same. | ||
Host.Check(_featureDetachedColumn.Name == inspectedFeatureColumn.Name, nameErrorMessage); | ||
|
||
string typeErrorMessage = $"The column called {inspectedFeatureColumn.Name} has a type {inspectedFeatureColumn.Type}, " + | ||
$"which does not match the expected feature column with name {_featureDetachedColumn.Name} and type {_featureDetachedColumn.Type}. " + | ||
$"Please make sure your feature column type is {_featureDetachedColumn.Type}."; | ||
// Check if column types are identical. | ||
Host.Check(_featureDetachedColumn.Type.Equals(inspectedFeatureColumn.Type), typeErrorMessage); | ||
} | ||
|
||
/// <summary> | ||
/// Create <see cref="RoleMappedSchema"/> from <paramref name="schema"/> by using <see cref="_featureDetachedColumn"/> as the feature role. | ||
/// </summary> | ||
/// <param name="schema">The original schema to be mapped.</param> | ||
private RoleMappedSchema MakeFeatureRoleMappedSchema(DataViewSchema schema) | ||
{ | ||
var roles = new List<KeyValuePair<RoleMappedSchema.ColumnRole, string>>(); | ||
roles.Add(new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Feature, _featureDetachedColumn.Name)); | ||
return new RoleMappedSchema(schema, roles); | ||
} | ||
|
||
internal TreeEnsembleFeaturizationTransformer(IHostEnvironment env, DataViewSchema inputSchema, | ||
DataViewSchema.Column featureColumn, TreeEnsembleModelParameters modelParameters, string outputColumnNameSuffix=null) : | ||
base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TreeEnsembleFeaturizationTransformer)), modelParameters, inputSchema) | ||
{ | ||
// Store featureColumn as a detached column because a fitted transformer can be applied to different IDataViews and different | ||
// IDataView may have different schemas. | ||
_featureDetachedColumn = new DataViewSchema.DetachedColumn(featureColumn); | ||
// Check if featureColumn matches a column in inputSchema. The answer is yes if they have the same name and type. | ||
// The indexed column, inputSchema[featureColumn.Index], should match the detached column, _featureDetachedColumn. | ||
CheckFeatureColumnCompatibility(inputSchema[featureColumn.Index]); | ||
// Store outputColumnNameSuffix so that this transformer can be saved into a file later. | ||
_outputColumnSuffix = outputColumnNameSuffix; | ||
// Create an argument, _scorerArgs, to pass the suffix of output column names to the underlying scorer. | ||
_scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments { Suffix = _outputColumnSuffix }; | ||
// Create a bindable mapper. It provides the core computation and can be attached to any IDataView and produce | ||
// a transformed IDataView. | ||
BindableMapper = new TreeEnsembleFeaturizerBindableMapper(env, _scorerArgs, modelParameters); | ||
// Create a scorer. | ||
var roleMappedSchema = MakeFeatureRoleMappedSchema(inputSchema); | ||
Scorer = new GenericScorer(Host, _scorerArgs, new EmptyDataView(Host, inputSchema), BindableMapper.Bind(Host, roleMappedSchema), roleMappedSchema); | ||
} | ||
|
||
private TreeEnsembleFeaturizationTransformer(IHostEnvironment host, ModelLoadContext ctx) | ||
: base(Contracts.CheckRef(host, nameof(host)).Register(nameof(TreeEnsembleFeaturizationTransformer)), ctx) | ||
{ | ||
// *** Binary format *** | ||
// <base info> | ||
// string: feature column's name. | ||
// string: output columns' suffix. | ||
|
||
string featureColumnName = ctx.LoadString(); | ||
_featureDetachedColumn = new DataViewSchema.DetachedColumn(TrainSchema[featureColumnName]); | ||
_outputColumnSuffix = ctx.LoadStringOrNull(); | ||
|
||
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model); | ||
|
||
var args = new GenericScorer.Arguments { Suffix = "" }; | ||
var schema = MakeFeatureRoleMappedSchema(TrainSchema); | ||
Scorer = new GenericScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); | ||
} | ||
|
||
public override DataViewSchema GetOutputSchema(DataViewSchema inputSchema) => Transform(new EmptyDataView(Host, inputSchema)).Schema; | ||
|
||
private protected override void SaveModel(ModelSaveContext ctx) | ||
{ | ||
Host.CheckValue(ctx, nameof(ctx)); | ||
ctx.CheckAtModel(); | ||
ctx.SetVersionInfo(GetVersionInfo()); | ||
|
||
// *** Binary format *** | ||
// model: prediction model. | ||
// stream: empty data view that contains train schema. | ||
// ids of strings: feature columns. | ||
// float: scorer threshold | ||
// id of string: scorer threshold column | ||
|
||
ctx.SaveModel(Model, DirModel); | ||
ctx.SaveBinaryStream(DirTransSchema, writer => | ||
{ | ||
using (var ch = Host.Start("Saving train schema")) | ||
{ | ||
var saver = new BinarySaver(Host, new BinarySaver.Arguments { Silent = true }); | ||
DataSaverUtils.SaveDataView(ch, saver, new EmptyDataView(Host, TrainSchema), writer.BaseStream); | ||
} | ||
}); | ||
|
||
ctx.SaveString(_featureDetachedColumn.Name); | ||
ctx.SaveStringOrNull(_outputColumnSuffix); | ||
} | ||
|
||
private static VersionInfo GetVersionInfo() | ||
{ | ||
return new VersionInfo( | ||
modelSignature: "TREEFEAT", // "TREE" ensemble "FEAT"urizer. | ||
verWrittenCur: 0x00010001, // Initial | ||
verReadableCur: 0x00010001, | ||
verWeCanReadBack: 0x00010001, | ||
loaderSignature: LoaderSignature, | ||
loaderAssemblyName: typeof(TreeEnsembleFeaturizationTransformer).Assembly.FullName); | ||
} | ||
|
||
private static TreeEnsembleFeaturizationTransformer Create(IHostEnvironment env, ModelLoadContext ctx) | ||
=> new TreeEnsembleFeaturizationTransformer(env, ctx); | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.