Skip to content

Field-aware factorization machine to estimator #912

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Sep 20, 2018
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Core/Data/IEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -229,14 +229,14 @@ public interface IDataReaderEstimator<in TSource, out TReader>

/// <summary>
/// The transformer is a component that transforms data.
/// It also supports 'schema propagation' to answer the question of 'how the data with this schema look after you transform it?'.
/// It also supports 'schema propagation' to answer the question of 'how will the data with this schema look, after you transform it?'.
/// </summary>
public interface ITransformer
{
/// <summary>
/// Schema propagation for transformers.
/// Returns the output schema of the data, if the input schema is like the one provided.
/// Throws <see cref="SchemaException"/> iff the input schema is not valid for the transformer.
/// Throws <see cref="SchemaException"/> if the input schema is not valid for the transformer.
/// </summary>
ISchema GetOutputSchema(ISchema inputSchema);

Expand Down
27 changes: 21 additions & 6 deletions src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,37 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Calibration;
using System;
using System.Collections.Generic;
using System.Text;

namespace Microsoft.ML.Runtime
{
/// <summary>
/// An interface for all the transformer that can transform data based on the <see cref="IPredictor"/> field.
/// The implemendations of this interface either have no feature column, or have more than one feature column, and cannot implement the
/// <see cref="ISingleFeaturePredictionTransformer{TModel}"/>, which most of the ML.Net tranformer implement.
/// </summary>
/// <typeparam name="TModel">The <see cref="IPredictor"/> used for the data transformation.</typeparam>
public interface IPredictionTransformer<out TModel> : ITransformer
where TModel : IPredictor
{
TModel Model { get; }
}

/// <summary>
/// An ISingleFeaturePredictionTransformer contains the name of the <see cref="FeatureColumn"/>
/// and its type, <see cref="FeatureColumnType"/>. Implementations of this interface, have the ability
/// to score the data of an input <see cref="IDataView"/> through the <see cref="ITransformer.Transform(IDataView)"/>
/// </summary>
/// <typeparam name="TModel">The <see cref="IPredictor"/> used for the data transformation.</typeparam>
public interface ISingleFeaturePredictionTransformer<out TModel> : IPredictionTransformer<TModel>
where TModel : IPredictor
{
/// <summary>The name of the feature column.</summary>
string FeatureColumn { get; }

/// <summary>Holds information about the type of the feature column.</summary>
ColumnType FeatureColumnType { get; }

TModel Model { get; }
}
}
161 changes: 111 additions & 50 deletions src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.IO;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
Expand All @@ -23,54 +22,49 @@

namespace Microsoft.ML.Runtime.Data
{
public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer<TModel>, ICanSaveModel

/// <summary>
/// Base class for transformers with no feature column, or more than one feature columns.
/// </summary>
/// <typeparam name="TModel"></typeparam>
public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer<TModel>
Copy link
Contributor

@TomFinley TomFinley Sep 19, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No documentation? #Pending

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is there? Are you looking at iteration 10?


In reply to: 218905403 [](ancestors = 218905403)

where TModel : class, IPredictor
{
private const string DirModel = "Model";
private const string DirTransSchema = "TrainSchema";
/// <summary>
/// The model.
/// </summary>
public TModel Model { get; }

protected const string DirModel = "Model";
protected const string DirTransSchema = "TrainSchema";
protected readonly IHost Host;
protected readonly ISchemaBindableMapper BindableMapper;
protected readonly ISchema TrainSchema;

public string FeatureColumn { get; }

public ColumnType FeatureColumnType { get; }
protected ISchemaBindableMapper BindableMapper;
protected ISchema TrainSchema;

public TModel Model { get; }

public PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn)
protected PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema)
{
Contracts.CheckValue(host, nameof(host));
Contracts.CheckValueOrNull(featureColumn);

Host = host;
Host.CheckValue(trainSchema, nameof(trainSchema));

Model = model;
FeatureColumn = featureColumn;
if (featureColumn == null)
FeatureColumnType = null;
else if (!trainSchema.TryGetColumnIndex(featureColumn, out int col))
throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn);
else
FeatureColumnType = trainSchema.GetColumnType(col);

TrainSchema = trainSchema;
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
}

internal PredictionTransformerBase(IHost host, ModelLoadContext ctx)
protected PredictionTransformerBase(IHost host, ModelLoadContext ctx)

{
Host = host;

ctx.LoadModel<TModel, SignatureLoadModel>(host, out TModel model, DirModel);
Model = model;

// *** Binary format ***
// model: prediction model.
// stream: empty data view that contains train schema.
// id of string: feature column.

ctx.LoadModel<TModel, SignatureLoadModel>(host, out TModel model, DirModel);
Model = model;

// Clone the stream with the schema into memory.
var ms = new MemoryStream();
ctx.TryLoadBinaryStream(DirTransSchema, reader =>
Expand All @@ -81,19 +75,90 @@ internal PredictionTransformerBase(IHost host, ModelLoadContext ctx)
ms.Position = 0;
var loader = new BinaryLoader(host, new BinaryLoader.Arguments(), ms);
TrainSchema = loader.Schema;
}

/// <summary>
/// Gets the output schema resulting from the <see cref="Transform(IDataView)"/>
/// </summary>
/// <param name="inputSchema">The <see cref="ISchema"/> of the input data.</param>
/// <returns>The resulting <see cref="ISchema"/>.</returns>
public abstract ISchema GetOutputSchema(ISchema inputSchema);

/// <summary>
/// Transforms the input data.
/// </summary>
/// <param name="input">The input data.</param>
/// <returns>The transformed <see cref="IDataView"/></returns>
public abstract IDataView Transform(IDataView input);

protected void SaveModel(ModelSaveContext ctx)
{
// *** Binary format ***
// <base info>
// stream: empty data view that contains train schema.

ctx.SaveModel(Model, DirModel);
ctx.SaveBinaryStream(DirTransSchema, writer =>
{
using (var ch = Host.Start("Saving train schema"))
{
var saver = new BinarySaver(Host, new BinarySaver.Arguments { Silent = true });
DataSaverUtils.SaveDataView(ch, saver, new EmptyDataView(Host, TrainSchema), writer.BaseStream);
}
});
}
}

/// <summary>
/// The base class for all the transformers implementing the <see cref="ISingleFeaturePredictionTransformer{TModel}"/>.
/// Those are all the transformers that work with one feature column.
/// </summary>
/// <typeparam name="TModel">The model used to transform the data.</typeparam>
public abstract class SingleFeaturePredictionTransformerBase<TModel> : PredictionTransformerBase<TModel>, ISingleFeaturePredictionTransformer<TModel>, ICanSaveModel
where TModel : class, IPredictor
{
/// <summary>
/// The name of the feature column used by the prediction transformer.
/// </summary>
public string FeatureColumn { get; }

/// <summary>
/// The type of the prediction transformer
/// </summary>
public ColumnType FeatureColumnType { get; }

public SingleFeaturePredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn)
:base(host, model, trainSchema)
{
FeatureColumn = featureColumn;

FeatureColumn = featureColumn;
if (featureColumn == null)
FeatureColumnType = null;
else if (!trainSchema.TryGetColumnIndex(featureColumn, out int col))
throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn);
else
FeatureColumnType = trainSchema.GetColumnType(col);

BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
}

internal SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx)
:base(host, ctx)
{
FeatureColumn = ctx.LoadStringOrNull();

if (FeatureColumn == null)
FeatureColumnType = null;
else if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col))
throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn);
else
FeatureColumnType = TrainSchema.GetColumnType(col);

BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model);
}

public ISchema GetOutputSchema(ISchema inputSchema)
public override ISchema GetOutputSchema(ISchema inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));

Expand All @@ -108,8 +173,6 @@ public ISchema GetOutputSchema(ISchema inputSchema)
return Transform(new EmptyDataView(Host, inputSchema)).Schema;
}

public abstract IDataView Transform(IDataView input);

public void Save(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
Expand All @@ -119,26 +182,16 @@ public void Save(ModelSaveContext ctx)

protected virtual void SaveCore(ModelSaveContext ctx)
{
// *** Binary format ***
// model: prediction model.
// stream: empty data view that contains train schema.
// id of string: feature column.

ctx.SaveModel(Model, DirModel);
ctx.SaveBinaryStream(DirTransSchema, writer =>
{
using (var ch = Host.Start("Saving train schema"))
{
var saver = new BinarySaver(Host, new BinarySaver.Arguments { Silent = true });
DataSaverUtils.SaveDataView(ch, saver, new EmptyDataView(Host, TrainSchema), writer.BaseStream);
}
});

SaveModel(ctx);
ctx.SaveStringOrNull(FeatureColumn);
}
}

public sealed class BinaryPredictionTransformer<TModel> : PredictionTransformerBase<TModel>
/// <summary>
/// Base class for the <see cref="ISingleFeaturePredictionTransformer{TModel}"/> working on binary classification tasks.
/// </summary>
/// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
public sealed class BinaryPredictionTransformer<TModel> : SingleFeaturePredictionTransformerBase<TModel>
where TModel : class, IPredictorProducing<float>
{
private readonly BinaryClassifierScorer _scorer;
Expand Down Expand Up @@ -207,7 +260,11 @@ private static VersionInfo GetVersionInfo()
}
}

public sealed class MulticlassPredictionTransformer<TModel> : PredictionTransformerBase<TModel>
/// <summary>
/// Base class for the <see cref="ISingleFeaturePredictionTransformer{TModel}"/> working on multi-class classification tasks.
/// </summary>
/// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
public sealed class MulticlassPredictionTransformer<TModel> : SingleFeaturePredictionTransformerBase<TModel>
where TModel : class, IPredictorProducing<VBuffer<float>>
{
private readonly MultiClassClassifierScorer _scorer;
Expand Down Expand Up @@ -268,7 +325,11 @@ private static VersionInfo GetVersionInfo()
}
}

public sealed class RegressionPredictionTransformer<TModel> : PredictionTransformerBase<TModel>
/// <summary>
/// Base class for the <see cref="ISingleFeaturePredictionTransformer{TModel}"/> working on regression tasks.
/// </summary>
/// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
public sealed class RegressionPredictionTransformer<TModel> : SingleFeaturePredictionTransformerBase<TModel>
where TModel : class, IPredictorProducing<float>
{
private readonly GenericScorer _scorer;
Expand Down Expand Up @@ -314,7 +375,7 @@ private static VersionInfo GetVersionInfo()
}
}

public sealed class RankingPredictionTransformer<TModel> : PredictionTransformerBase<TModel>
public sealed class RankingPredictionTransformer<TModel> : SingleFeaturePredictionTransformerBase<TModel>
where TModel : class, IPredictorProducing<float>
{
private readonly GenericScorer _scorer;
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Training/ITrainerEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
namespace Microsoft.ML.Runtime.Training
{
public interface ITrainerEstimator<out TTransformer, out TPredictor>: IEstimator<TTransformer>
where TTransformer: IPredictionTransformer<TPredictor>
where TTransformer: ISingleFeaturePredictionTransformer<TPredictor>
where TPredictor: IPredictor
{
TrainerInfo Info { get; }
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace Microsoft.ML.Runtime.Training
/// It produces a 'prediction transformer'.
/// </summary>
public abstract class TrainerEstimatorBase<TTransformer, TModel> : ITrainerEstimator<TTransformer, TModel>, ITrainer<TModel>
where TTransformer : IPredictionTransformer<TModel>
where TTransformer : ISingleFeaturePredictionTransformer<TModel>
where TModel : IPredictor
{
/// <summary>
Expand Down
50 changes: 50 additions & 0 deletions src/Microsoft.ML.Data/Training/TrainerEstimatorContext.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Training;

namespace Microsoft.ML.Core.Prediction
{
/// <summary>
/// Holds information relevant to trainers. It is passed to the constructor of the<see cref="ITrainerEstimator{IPredictionTransformer, IPredictor}"/>
/// holding additional data needed to fit the estimator. The additional data can be a validation set or an initial model.
/// This holds at least a training set, as well as optioonally a predictor.
/// </summary>
public class TrainerEstimatorContext
{
/// <summary>
/// The validation set. Can be <c>null</c>. Note that passing a non-<c>null</c> validation set into
/// a trainer that does not support validation sets should not be considered an error condition. It
/// should simply be ignored in that case.
/// </summary>
public IDataView ValidationSet { get; }

/// <summary>
/// The initial predictor, for incremental training. Note that if a <see cref="ITrainerEstimator{IPredictionTransformer, IPredictor}"/> implementor
/// does not support incremental training, then it can ignore it similarly to how one would ignore
/// <see cref="ValidationSet"/>. However, if the trainer does support incremental training and there
/// is something wrong with a non-<c>null</c> value of this, then the trainer ought to throw an exception.
/// </summary>
public IPredictor InitialPredictor { get; }

/// <summary>
/// Initializes a new instance of <see cref="TrainerEstimatorContext"/>, given a training set and optional other arguments.
/// </summary>
/// <param name="validationSet">Will set <see cref="ValidationSet"/> to this value if specified</param>
/// <param name="initialPredictor">Will set <see cref="InitialPredictor"/> to this value if specified</param>
public TrainerEstimatorContext(IDataView validationSet = null, IPredictor initialPredictor = null)
{
Contracts.CheckValueOrNull(validationSet);
Contracts.CheckValueOrNull(initialPredictor);

ValidationSet = validationSet;
InitialPredictor = initialPredictor;
}
}
}
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/BoostingFastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
namespace Microsoft.ML.Runtime.FastTree
{
public abstract class BoostingFastTreeTrainerBase<TArgs, TTransformer, TModel> : FastTreeTrainerBase<TArgs, TTransformer, TModel>
where TTransformer : IPredictionTransformer<TModel>
where TTransformer : ISingleFeaturePredictionTransformer<TModel>
where TArgs : BoostedTreeArgs, new()
where TModel : IPredictorProducing<Float>
{
Expand Down
Loading