Tree-based featurization #3812

Merged
merged 27 commits into from
Jun 26, 2019
Changes from 3 commits
233 changes: 233 additions & 0 deletions src/Microsoft.ML.FastTree/TreeEnsembleFeaturizationEstimator.cs
@@ -0,0 +1,233 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;

namespace Microsoft.ML.Trainers.FastTree
{
/// <summary>
/// This class encapsulates the common behavior of all tree-based featurizers such as <see cref="FastTreeBinaryFeaturizationEstimator"/>,
/// <see cref="FastForestBinaryFeaturizationEstimator"/>, <see cref="FastTreeRegressionFeaturizationEstimator"/>,
/// <see cref="FastForestRegressionFeaturizationEstimator"/>, and <see cref="PretrainedTreeFeaturizationEstimator"/>.
/// All tree-based featurizers share the same output schema computed by <see cref="GetOutputSchema(SchemaShape)"/>. All tree-based featurizers
/// require an input feature column name and a suffix for all output columns. The <see cref="ITransformer"/> returned by <see cref="Fit(IDataView)"/>
/// produces three columns: (1) the prediction values of all trees, (2) the IDs of the leaves the input feature vector falls into, and (3)
/// the binary vector which encodes the paths to those destination leaves.
/// </summary>
public abstract class FeaturizationEstimatorBase : IEstimator<TreeEnsembleFeaturizationTransformer>
{
/// <summary>
/// The common options of tree-based featurizations such as <see cref="FastTreeBinaryFeaturizationEstimator"/>, <see cref="FastForestBinaryFeaturizationEstimator"/>,
/// <see cref="FastTreeRegressionFeaturizationEstimator"/>, <see cref="FastForestRegressionFeaturizationEstimator"/>, and <see cref="PretrainedTreeFeaturizationEstimator"/>.
/// </summary>
public class CommonOptions
{
/// <summary>
/// The name of the feature column in the <see cref="IDataView"/> passed to <see cref="Fit(IDataView)"/>.
/// The column type must be a vector of <see cref="System.Single"/>.
/// </summary>
public string InputColumnName;

/// <summary>
/// The estimator has three output columns. Their names would be "Trees" + <see cref="OutputColumnsSuffix"/>,
/// "Leaves" + <see cref="OutputColumnsSuffix"/>, and "Paths" + <see cref="OutputColumnsSuffix"/>. If <see cref="OutputColumnsSuffix"/>
/// is <see langword="null"/>, the output names would be "Trees", "Leaves", and "Paths".
/// </summary>
public string OutputColumnsSuffix;
Contributor

@justinormont justinormont Jun 3, 2019

We went away from magic strings in the TextTransform. Previously with tokens=+, we produced a new column named {OutputColName}_TokenizedText.

For the estimators API we have users directly enter the column name for the tokens. We may want to do the same for the Trees/Leaves/Paths of the TreeFeat.

Perhaps:
OutputColumnTreeName, OutputColumnLeavesName, OutputColumnPathsName.

#Resolved

Contributor

@Ivanidzo4ka Ivanidzo4ka Jun 4, 2019

And don't add a column if its name is empty or null.
That way you can actually configure which parts of the tree structure you want. #Resolved
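
A minimal sketch of the suggestion in this thread, assuming hypothetical names: `TreeOutputNamesSketch`, `OutputColumnSketch`, and the three field names are illustrative only and are not taken from the merged ML.NET API. Each output gets its own explicit column name, and a null or blank name means that output is skipped.

```csharp
using System.Collections.Generic;

// Hypothetical options type; the field names are illustrative and are not
// taken from the merged ML.NET API.
public sealed class TreeOutputNamesSketch
{
    public string TreesColumnName;   // null or blank: do not emit tree outputs
    public string LeavesColumnName;  // null or blank: do not emit leaf IDs
    public string PathsColumnName;   // null or blank: do not emit path encodings
}

public static class OutputColumnSketch
{
    // Returns the names of the columns that would be produced, in the fixed
    // order trees, leaves, paths. Names that are null or blank are skipped.
    public static List<string> GetOutputColumns(TreeOutputNamesSketch names)
    {
        var result = new List<string>();
        var candidates = new[] { names.TreesColumnName, names.LeavesColumnName, names.PathsColumnName };
        foreach (var name in candidates)
        {
            if (!string.IsNullOrWhiteSpace(name))
                result.Add(name);
        }
        return result;
    }
}
```

With only `TreesColumnName` set, the sketch reports a single output column, which is the "configure which parts you want" behavior described above.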

Contributor

@justinormont justinormont Jun 5, 2019

As further background on the PR I was referencing...

Conversation about TextTransform:
via @rogancarr in #2957 PR

When using OutputTokens=true, FeaturizeText creates a new column called ${OutputColumnName}_TransformedText. This isn't really well documented anywhere, and it's odd behavior. I suggest that we make the tokenized text column name explicit in the API.

My suggestion would be the following:

  • Change OutputTokens = [bool] to OutputTokensColumn = [string], where string.IsNullOrWhiteSpace(OutputTokensColumn) signifies that this column will not be created. #Resolved

Member Author

@wschin wschin Jun 5, 2019

Now we can do optional output columns and custom output column names. Please see tests for examples (or wait for formal API samples). #Resolved

};

/// <summary>
/// Feature column to which tree-based featurization is applied. Note that <see cref="FeatureColumnName"/> does not need to be the same as
/// the feature column used to train the tree model.
/// </summary>
private protected readonly string FeatureColumnName;

/// <summary>
/// See <see cref="CommonOptions.OutputColumnsSuffix"/>.
/// </summary>
private protected readonly string OutputColumnSuffix;

/// <summary>
/// Environment of this instance. It controls error throwing and other environment settings.
/// </summary>
private protected readonly IHostEnvironment Env;

private protected FeaturizationEstimatorBase(IHostEnvironment env, CommonOptions options)
{
Env = env;
FeatureColumnName = options.InputColumnName;
OutputColumnSuffix = options.OutputColumnsSuffix;
}

/// <summary>
/// Every derived class should implement <see cref="PrepareModel(IDataView)"/> to specify how a <see cref="TreeEnsembleModelParameters"/>
/// is obtained from <paramref name="input"/> and the parameters stored in this or the derived class.
/// </summary>
/// <param name="input">Data used to train a tree model.</param>
/// <returns>The trees used in <see cref="TreeEnsembleFeaturizationTransformer"/>.</returns>
private protected abstract TreeEnsembleModelParameters PrepareModel(IDataView input);

/// <summary>
/// Produce a <see cref="TreeEnsembleFeaturizationTransformer"/> which maps the column called <see cref="CommonOptions.InputColumnName"/> in <paramref name="input"/>
/// to three output columns.
/// </summary>
public TreeEnsembleFeaturizationTransformer Fit(IDataView input)
{
var model = PrepareModel(input);
return new TreeEnsembleFeaturizationTransformer(Env, input.Schema,
input.Schema[FeatureColumnName], model, OutputColumnSuffix);
}

/// <summary>
/// A tree-based featurizer adds three float-vector columns to <paramref name="inputSchema"/>.
/// Given a feature vector column, the added columns are the prediction values of all trees, the leaf IDs the feature
/// vector falls into, and the paths to those leaves.
/// </summary>
/// <param name="inputSchema">A schema which contains a feature column. Note that the feature column name can be specified
/// by <see cref="CommonOptions.InputColumnName"/>.</param>
/// <returns>Output <see cref="SchemaShape"/> produced by the tree-based featurizer.</returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Env.CheckValue(inputSchema, nameof(inputSchema));

if (!inputSchema.TryFindColumn(FeatureColumnName, out var col))
throw Env.ExceptSchemaMismatch(nameof(inputSchema), "input", FeatureColumnName);

var result = inputSchema.ToDictionary(x => x.Name);

// Output column names are "Trees", "Leaves", and "Paths" followed by the optional suffix,
// as documented in CommonOptions.OutputColumnsSuffix.
var treeColumnName = OutputColumnSuffix != null ? "Trees" + OutputColumnSuffix : "Trees";
result[treeColumnName] = new SchemaShape.Column(treeColumnName,
SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);

var leafColumnName = OutputColumnSuffix != null ? "Leaves" + OutputColumnSuffix : "Leaves";
result[leafColumnName] = new SchemaShape.Column(leafColumnName,
SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);

var pathColumnName = OutputColumnSuffix != null ? "Paths" + OutputColumnSuffix : "Paths";
result[pathColumnName] = new SchemaShape.Column(pathColumnName,
SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);

return new SchemaShape(result.Values);
}
}

/// <summary>
/// An <see cref="IEstimator{TTransformer}"/> which takes a trained <see cref="TreeEnsembleModelParameters"/>. Calling its
/// <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> produces a featurizer based on the trained model.
/// </summary>
public sealed class PretrainedTreeFeaturizationEstimator : FeaturizationEstimatorBase
{
public sealed class Options : FeaturizationEstimatorBase.CommonOptions
{
public TreeEnsembleModelParameters ModelParameters;
};

private TreeEnsembleModelParameters _modelParameters;
Contributor

Should this be readonly or something similar to make sure it is not altered?
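
As a sketch of what `readonly` would and would not buy here: it stops the field from being reassigned after construction, but it does not stop anyone from mutating the object the field points to. `HolderSketch` and its list field are hypothetical stand-ins; whether `TreeEnsembleModelParameters` itself is immutable is not established by this diff.

```csharp
using System.Collections.Generic;

// Hypothetical stand-in for an estimator that stores a model passed in by the caller.
public sealed class HolderSketch
{
    // readonly: can only be assigned in the declaration or in a constructor.
    private readonly List<float> _weights;

    public HolderSketch(List<float> weights)
    {
        _weights = weights;
    }

    public int Count => _weights.Count;

    // Reassigning the field here would not compile:
    //     _weights = new List<float>();
    // Mutating the referenced object is still allowed:
    public void Append(float value) => _weights.Add(value);
}
```

If the caller keeps its own reference to the list, changes made through that reference are also visible through the holder, so `readonly` alone is not a guarantee against alteration.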


internal PretrainedTreeFeaturizationEstimator(IHostEnvironment env, Options options) : base(env, options)
{
_modelParameters = options.ModelParameters;
}

/// <summary>
/// Produce the <see cref="TreeEnsembleModelParameters"/> for tree-based feature engineering. This function does not
/// invoke any training procedure; it just returns the pre-trained model passed in via <see cref="Options.ModelParameters"/>.
/// </summary>
private protected override TreeEnsembleModelParameters PrepareModel(IDataView input) => _modelParameters;
}

public sealed class FastTreeBinaryFeaturizationEstimator : FeaturizationEstimatorBase
{
private readonly FastTreeBinaryTrainer.Options _trainerOptions;

public sealed class Options : CommonOptions
{
public FastTreeBinaryTrainer.Options TrainerOptions;
}

internal FastTreeBinaryFeaturizationEstimator(IHostEnvironment env, Options options)
: base(env, options)
{
_trainerOptions = options.TrainerOptions;
}

private protected override TreeEnsembleModelParameters PrepareModel(IDataView input)
{
var trainer = new FastTreeBinaryTrainer(Env, _trainerOptions);
var trained = trainer.Fit(input);
return trained.Model.SubModel;
}
}

public sealed class FastTreeRegressionFeaturizationEstimator : FeaturizationEstimatorBase
{
private readonly FastTreeRegressionTrainer.Options _trainerOptions;

public sealed class Options : CommonOptions
{
public FastTreeRegressionTrainer.Options TrainerOptions;
}

internal FastTreeRegressionFeaturizationEstimator(IHostEnvironment env, Options options)
: base(env, options)
{
_trainerOptions = options.TrainerOptions;
}

private protected override TreeEnsembleModelParameters PrepareModel(IDataView input)
{
var trainer = new FastTreeRegressionTrainer(Env, _trainerOptions);
var trained = trainer.Fit(input);
return trained.Model;
}
}

public sealed class FastForestBinaryFeaturizationEstimator : FeaturizationEstimatorBase
{
private readonly FastForestBinaryTrainer.Options _trainerOptions;

public sealed class Options : CommonOptions
{
public FastForestBinaryTrainer.Options TrainerOptions;
}

internal FastForestBinaryFeaturizationEstimator(IHostEnvironment env, Options options)
: base(env, options)
{
_trainerOptions = options.TrainerOptions;
}

private protected override TreeEnsembleModelParameters PrepareModel(IDataView input)
{
var trainer = new FastForestBinaryTrainer(Env, _trainerOptions);
var trained = trainer.Fit(input);
return trained.Model;
}
}

public sealed class FastForestRegressionFeaturizationEstimator : FeaturizationEstimatorBase
{
private readonly FastForestRegressionTrainer.Options _trainerOptions;

public sealed class Options : CommonOptions
{
public FastForestRegressionTrainer.Options TrainerOptions;
}

internal FastForestRegressionFeaturizationEstimator(IHostEnvironment env, Options options)
: base(env, options)
{
_trainerOptions = options.TrainerOptions;
}

private protected override TreeEnsembleModelParameters PrepareModel(IDataView input)
{
var trainer = new FastForestRegressionTrainer(Env, _trainerOptions);
var trained = trainer.Fit(input);
return trained.Model;
}
}
}
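
The estimators in this file follow a template-method shape: the base class owns `Fit`, and each derived class only decides where the tree model comes from. Below is a toy sketch of that shape, with no ML.NET dependency. All type names (`ModelSketch`, `FeaturizerEstimatorSketch`, `PretrainedSketch`, `TrainingSketch`) are hypothetical, and the "transformer" is reduced to a delegate for brevity.

```csharp
using System;

// Stand-in for TreeEnsembleModelParameters; purely illustrative.
public sealed class ModelSketch
{
    public readonly string Origin;
    public ModelSketch(string origin) { Origin = origin; }
}

public abstract class FeaturizerEstimatorSketch
{
    // Derived classes decide where the tree model comes from.
    protected abstract ModelSketch PrepareModel(float[][] data);

    // Shared Fit logic: obtain the model, then wrap it. Here the "transformer"
    // is just a function reporting the model origin and the input length.
    public Func<float[], string> Fit(float[][] data)
    {
        var model = PrepareModel(data);
        return features => $"{model.Origin}:{features.Length}";
    }
}

public sealed class PretrainedSketch : FeaturizerEstimatorSketch
{
    private readonly ModelSketch _model;
    public PretrainedSketch(ModelSketch model) { _model = model; }

    // No training: return the model that was passed in.
    protected override ModelSketch PrepareModel(float[][] data) => _model;
}

public sealed class TrainingSketch : FeaturizerEstimatorSketch
{
    // "Training" is faked; a real featurizer would fit a tree trainer on the data.
    protected override ModelSketch PrepareModel(float[][] data)
        => new ModelSketch($"trained-on-{data.Length}-rows");
}
```

The design keeps the output-schema and transformer-construction logic in one place, so adding a featurizer for another tree trainer only requires a new `PrepareModel` override.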
138 changes: 138 additions & 0 deletions src/Microsoft.ML.FastTree/TreeEnsembleFeaturizationTransformer.cs
@@ -0,0 +1,138 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.FastTree;

[assembly: LoadableClass(typeof(TreeEnsembleFeaturizationTransformer), typeof(TreeEnsembleFeaturizationTransformer),
null, typeof(SignatureLoadModel), "", TreeEnsembleFeaturizationTransformer.LoaderSignature)]

namespace Microsoft.ML.Trainers.FastTree
{
public sealed class TreeEnsembleFeaturizationTransformer : PredictionTransformerBase<TreeEnsembleModelParameters>
{
internal const string LoaderSignature = "TreeEnseFeat";
private readonly TreeEnsembleFeaturizerBindableMapper.Arguments _scorerArgs;
private readonly DataViewSchema.DetachedColumn _featureDetachedColumn;
private readonly string _outputColumnSuffix;

/// <summary>
/// Check if <see cref="_featureDetachedColumn"/> is compatible with <paramref name="inspectedFeatureColumn"/>.
/// </summary>
/// <param name="inspectedFeatureColumn">A column checked against <see cref="_featureDetachedColumn"/>.</param>
private void CheckFeatureColumnCompatibility(DataViewSchema.Column inspectedFeatureColumn)
{
string nameErrorMessage = $"The column called {inspectedFeatureColumn.Name} does not match the expected " +
$"feature column with name {_featureDetachedColumn.Name} and type {_featureDetachedColumn.Type}. " +
$"Please rename your column by calling CopyColumns defined in TransformExtensionsCatalog.";
// Check if column names are the same.
Host.Check(_featureDetachedColumn.Name == inspectedFeatureColumn.Name, nameErrorMessage);

string typeErrorMessage = $"The column called {inspectedFeatureColumn.Name} has a type {inspectedFeatureColumn.Type}, " +
$"which does not match the expected feature column with name {_featureDetachedColumn.Name} and type {_featureDetachedColumn.Type}. " +
$"Please make sure your feature column type is {_featureDetachedColumn.Type}.";
// Check if column types are identical.
Host.Check(_featureDetachedColumn.Type.Equals(inspectedFeatureColumn.Type), typeErrorMessage);
}

/// <summary>
/// Create <see cref="RoleMappedSchema"/> from <paramref name="schema"/> by using <see cref="_featureDetachedColumn"/> as the feature role.
/// </summary>
/// <param name="schema">The original schema to be mapped.</param>
private RoleMappedSchema MakeFeatureRoleMappedSchema(DataViewSchema schema)
{
var roles = new List<KeyValuePair<RoleMappedSchema.ColumnRole, string>>();
roles.Add(new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Feature, _featureDetachedColumn.Name));
return new RoleMappedSchema(schema, roles);
}

internal TreeEnsembleFeaturizationTransformer(IHostEnvironment env, DataViewSchema inputSchema,
DataViewSchema.Column featureColumn, TreeEnsembleModelParameters modelParameters, string outputColumnNameSuffix=null) :
base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TreeEnsembleFeaturizationTransformer)), modelParameters, inputSchema)
{
// Store featureColumn as a detached column because a fitted transformer can be applied to different IDataViews, and each
// IDataView may have a different schema.
_featureDetachedColumn = new DataViewSchema.DetachedColumn(featureColumn);
// Check if featureColumn matches a column in inputSchema. The answer is yes if they have the same name and type.
// The indexed column, inputSchema[featureColumn.Index], should match the detached column, _featureDetachedColumn.
CheckFeatureColumnCompatibility(inputSchema[featureColumn.Index]);
// Store outputColumnNameSuffix so that this transformer can be saved into a file later.
_outputColumnSuffix = outputColumnNameSuffix;
// Create an argument, _scorerArgs, to pass the suffix of output column names to the underlying scorer.
_scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments { Suffix = _outputColumnSuffix };
// Create a bindable mapper. It provides the core computation and can be attached to any IDataView and produce
// a transformed IDataView.
BindableMapper = new TreeEnsembleFeaturizerBindableMapper(env, _scorerArgs, modelParameters);
// Create a scorer.
var roleMappedSchema = MakeFeatureRoleMappedSchema(inputSchema);
Scorer = new GenericScorer(Host, _scorerArgs, new EmptyDataView(Host, inputSchema), BindableMapper.Bind(Host, roleMappedSchema), roleMappedSchema);
}

private TreeEnsembleFeaturizationTransformer(IHostEnvironment host, ModelLoadContext ctx)
: base(Contracts.CheckRef(host, nameof(host)).Register(nameof(TreeEnsembleFeaturizationTransformer)), ctx)
{
// *** Binary format ***
// <base info>
// string: feature column's name.
// string: output columns' suffix.

string featureColumnName = ctx.LoadString();
_featureDetachedColumn = new DataViewSchema.DetachedColumn(TrainSchema[featureColumnName]);
_outputColumnSuffix = ctx.LoadStringOrNull();

// Rebuild the scorer arguments and the mapper the same way the training-time constructor does, so that
// the loaded suffix is applied to the output column names.
_scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments { Suffix = _outputColumnSuffix };
BindableMapper = new TreeEnsembleFeaturizerBindableMapper(Host, _scorerArgs, Model);

var schema = MakeFeatureRoleMappedSchema(TrainSchema);
Scorer = new GenericScorer(Host, _scorerArgs, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
}

public override DataViewSchema GetOutputSchema(DataViewSchema inputSchema) => Transform(new EmptyDataView(Host, inputSchema)).Schema;

private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());

// *** Binary format ***
// model: prediction model.
// stream: empty data view that contains train schema.
// string: feature column's name.
// string: output columns' suffix (may be null).

ctx.SaveModel(Model, DirModel);
ctx.SaveBinaryStream(DirTransSchema, writer =>
{
using (var ch = Host.Start("Saving train schema"))
{
var saver = new BinarySaver(Host, new BinarySaver.Arguments { Silent = true });
DataSaverUtils.SaveDataView(ch, saver, new EmptyDataView(Host, TrainSchema), writer.BaseStream);
}
});

ctx.SaveString(_featureDetachedColumn.Name);
ctx.SaveStringOrNull(_outputColumnSuffix);
}

private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "TREEFEAT", // "TREE" ensemble "FEAT"urizer.
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(TreeEnsembleFeaturizationTransformer).Assembly.FullName);
}

private static TreeEnsembleFeaturizationTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
=> new TreeEnsembleFeaturizationTransformer(env, ctx);
}
}
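
`SaveModel` and the loading constructor have to agree on field order and on how a null suffix is represented. The sketch below shows that symmetry with a plain `BinaryWriter`/`BinaryReader` pair. It is a stand-in under stated assumptions, not ML.NET's `ModelSaveContext` format: the presence flag used for the nullable string and the `SuffixRoundTripSketch` name are assumptions made for illustration.

```csharp
using System.IO;

public static class SuffixRoundTripSketch
{
    // Writes the feature column name, then a presence flag and the optional suffix.
    public static byte[] Save(string featureColumnName, string suffix)
    {
        using (var stream = new MemoryStream())
        {
            using (var writer = new BinaryWriter(stream))
            {
                writer.Write(featureColumnName);
                writer.Write(suffix != null);
                if (suffix != null)
                    writer.Write(suffix);
            }
            // ToArray remains valid after the writer has closed the stream.
            return stream.ToArray();
        }
    }

    // Reads the fields back in exactly the order Save wrote them.
    public static (string FeatureColumnName, string Suffix) Load(byte[] bytes)
    {
        using (var reader = new BinaryReader(new MemoryStream(bytes)))
        {
            string name = reader.ReadString();
            string suffix = reader.ReadBoolean() ? reader.ReadString() : null;
            return (name, suffix);
        }
    }
}
```

Round-tripping both a null and a non-null suffix through `Save` and `Load` is a cheap check that the two sides have not drifted apart.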