diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs index 3fd9c042b8..0b6475275c 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs @@ -51,13 +51,13 @@ public static BinaryClassifierEvaluator.CalibratedResult Evaluate( } /// - /// Evaluates scored binary classification data. + /// Evaluates scored binary classification data, if the predictions are not calibrated. /// /// The shape type for the input data. /// The binary classification context. /// The data to evaluate. /// The index delegate for the label column. - /// The index delegate for columns from calibrated prediction of a binary classifier. + /// The index delegate for columns from uncalibrated prediction of a binary classifier. /// Under typical scenarios, this will just be the same tuple of results returned from the trainer. /// The evaluation results for these uncalibrated outputs. public static BinaryClassifierEvaluator.Result Evaluate( @@ -83,5 +83,89 @@ public static BinaryClassifierEvaluator.Result Evaluate( var eval = new BinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments() { }); return eval.Evaluate(data.AsDynamic, labelName, scoreName, predName); } + + /// + /// Evaluates scored multiclass classification data. + /// + /// The shape type for the input data. + /// The value type for the key label. + /// The multiclass classification context. + /// The data to evaluate. + /// The index delegate for the label column. + /// The index delegate for columns from the prediction of a multiclass classifier. + /// Under typical scenarios, this will just be the same tuple of results returned from the trainer. 
+ /// If given a positive value, the will be filled with + /// the top-K accuracy, that is, the accuracy assuming we consider an example with the correct class within + /// the top-K values as being stored "correctly." + /// The evaluation metrics. + public static MultiClassClassifierEvaluator.Result Evaluate( + this MulticlassClassificationContext ctx, + DataView data, + Func> label, + Func score, Key predictedLabel)> pred, + int topK = 0) + { + Contracts.CheckValue(data, nameof(data)); + var env = StaticPipeUtils.GetEnvironment(data); + Contracts.AssertValue(env); + env.CheckValue(label, nameof(label)); + env.CheckValue(pred, nameof(pred)); + env.CheckParam(topK >= 0, nameof(topK), "Must not be negative."); + + var indexer = StaticPipeUtils.GetIndexer(data); + string labelName = indexer.Get(label(indexer.Indices)); + (var scoreCol, var predCol) = pred(indexer.Indices); + Contracts.CheckParam(scoreCol != null, nameof(pred), "Indexing delegate resulted in null score column."); + Contracts.CheckParam(predCol != null, nameof(pred), "Indexing delegate resulted in null predicted label column."); + string scoreName = indexer.Get(scoreCol); + string predName = indexer.Get(predCol); + + var args = new MultiClassClassifierEvaluator.Arguments() { }; + if (topK > 0) + args.OutputTopKAcc = topK; + + var eval = new MultiClassClassifierEvaluator(env, args); + return eval.Evaluate(data.AsDynamic, labelName, scoreName, predName); + } + + private sealed class TrivialRegressionLossFactory : ISupportRegressionLossFactory + { + private readonly IRegressionLoss _loss; + public TrivialRegressionLossFactory(IRegressionLoss loss) => _loss = loss; + public IRegressionLoss CreateComponent(IHostEnvironment env) => _loss; + } + + /// + /// Evaluates scored regression data. + /// + /// The shape type for the input data. + /// The regression context. + /// The data to evaluate. + /// The index delegate for the label column. + /// The index delegate for predicted score column. 
+ /// Potentially custom loss function. If left unspecified defaults to . + /// The evaluation metrics. + public static RegressionEvaluator.Result Evaluate( + this RegressionContext ctx, + DataView data, + Func> label, + Func> score, + IRegressionLoss loss = null) + { + Contracts.CheckValue(data, nameof(data)); + var env = StaticPipeUtils.GetEnvironment(data); + Contracts.AssertValue(env); + env.CheckValue(label, nameof(label)); + env.CheckValue(score, nameof(score)); + + var indexer = StaticPipeUtils.GetIndexer(data); + string labelName = indexer.Get(label(indexer.Indices)); + string scoreName = indexer.Get(score(indexer.Indices)); + + var args = new RegressionEvaluator.Arguments() { }; + if (loss != null) + args.LossFunction = new TrivialRegressionLossFactory(loss); + return new RegressionEvaluator(env, args).Evaluate(data.AsDynamic, labelName, scoreName); + } } } diff --git a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs index 900df4ac53..8a5f921eb3 100644 --- a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs @@ -598,62 +598,37 @@ internal Result(IExceptionContext ectx, IRow overallResult, int topK) } /// - /// Evaluates scored regression data. + /// Evaluates scored multiclass classification data. /// - /// The shape type for the input data. - /// The value type for the key label. - /// The data to evaluate. - /// The index delegate for the label column. - /// The index delegate for columns from prediction of a multi-class classifier. - /// Under typical scenarios, this will just be the same tuple of results returned from the trainer. - /// If given a positive value, the will be filled with - /// the top-K accuracy, that is, the accuracy assuming we consider an example with the correct class within - /// the top-K values as being stored "correctly." + /// The scored data. 
+ /// The name of the label column in . + /// The name of the score column in . + /// The name of the predicted label column in . /// The evaluation results for these outputs. - public static Result Evaluate( - DataView data, - Func> label, - Func score, Key predictedLabel)> pred, - int topK = 0) + public Result Evaluate(IDataView data, string label, string score, string predictedLabel) { - Contracts.CheckValue(data, nameof(data)); - var env = StaticPipeUtils.GetEnvironment(data); - Contracts.AssertValue(env); - env.CheckValue(label, nameof(label)); - env.CheckValue(pred, nameof(pred)); - env.CheckParam(topK >= 0, nameof(topK), "Must not be negative."); - - var indexer = StaticPipeUtils.GetIndexer(data); - string labelName = indexer.Get(label(indexer.Indices)); - (var scoreCol, var predCol) = pred(indexer.Indices); - Contracts.CheckParam(scoreCol != null, nameof(pred), "Indexing delegate resulted in null score column."); - Contracts.CheckParam(predCol != null, nameof(pred), "Indexing delegate resulted in null predicted label column."); - string scoreName = indexer.Get(scoreCol); - string predName = indexer.Get(predCol); - - var args = new Arguments() { }; - if (topK > 0) - args.OutputTopKAcc = topK; - - var eval = new MultiClassClassifierEvaluator(env, args); - - var roles = new RoleMappedData(data.AsDynamic, opt: false, - RoleMappedSchema.ColumnRole.Label.Bind(labelName), - RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreName), - RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.PredictedLabel, predName)); - - var resultDict = eval.Evaluate(roles); - env.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics)); + Host.CheckValue(data, nameof(data)); + Host.CheckNonEmpty(label, nameof(label)); + Host.CheckNonEmpty(score, nameof(score)); + Host.CheckNonEmpty(predictedLabel, nameof(predictedLabel)); + + var roles = new RoleMappedData(data, opt: false, + RoleMappedSchema.ColumnRole.Label.Bind(label), + 
RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, score), + RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.PredictedLabel, predictedLabel)); + + var resultDict = Evaluate(roles); + Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics)); var overall = resultDict[MetricKinds.OverallMetrics]; Result result; using (var cursor = overall.GetRowCursor(i => true)) { var moved = cursor.MoveNext(); - env.Assert(moved); - result = new Result(env, cursor, topK); + Host.Assert(moved); + result = new Result(Host, cursor, _outputTopKAcc ?? 0); moved = cursor.MoveNext(); - env.Assert(!moved); + Host.Assert(!moved); } return result; } diff --git a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs index b20492d6dc..a6b2e4c509 100644 --- a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs @@ -219,65 +219,43 @@ internal Result(IExceptionContext ectx, IRow overallResult) double Fetch(string name) => Fetch(ectx, overallResult, name); L1 = Fetch(RegressionEvaluator.L1); L2 = Fetch(RegressionEvaluator.L2); - Rms= Fetch(RegressionEvaluator.Rms); + Rms = Fetch(RegressionEvaluator.Rms); LossFn = Fetch(RegressionEvaluator.Loss); RSquared = Fetch(RegressionEvaluator.RSquared); } } - private sealed class TrivialLossFactory : ISupportRegressionLossFactory - { - private readonly IRegressionLoss _loss; - public TrivialLossFactory(IRegressionLoss loss) => _loss = loss; - public IRegressionLoss CreateComponent(IHostEnvironment env) => _loss; - } - /// /// Evaluates scored regression data. /// - /// The shape type for the input data. /// The data to evaluate. - /// The index delegate for the label column. - /// The index delegate for the predicted score column. - /// Potentially custom loss function. If left unspecified defaults to . - /// The evaluation results for these outputs. 
- public static Result Evaluate( - DataView data, - Func> label, - Func> score, - IRegressionLoss loss = null) + /// The name of the label column. + /// The name of the predicted score column. + /// The evaluation metrics for these outputs. + public Result Evaluate( + IDataView data, + string label, + string score) { - Contracts.CheckValue(data, nameof(data)); - var env = StaticPipeUtils.GetEnvironment(data); - Contracts.AssertValue(env); - env.CheckValue(label, nameof(label)); - env.CheckValue(score, nameof(score)); - - var indexer = StaticPipeUtils.GetIndexer(data); - string labelName = indexer.Get(label(indexer.Indices)); - string scoreName = indexer.Get(score(indexer.Indices)); - - var args = new Arguments() { }; - if (loss != null) - args.LossFunction = new TrivialLossFactory(loss); - var eval = new RegressionEvaluator(env, args); - - var roles = new RoleMappedData(data.AsDynamic, opt: false, - RoleMappedSchema.ColumnRole.Label.Bind(labelName), - RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreName)); - - var resultDict = eval.Evaluate(roles); - env.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics)); + Host.CheckValue(data, nameof(data)); + Host.CheckNonEmpty(label, nameof(label)); + Host.CheckNonEmpty(score, nameof(score)); + var roles = new RoleMappedData(data, opt: false, + RoleMappedSchema.ColumnRole.Label.Bind(label), + RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, score)); + + var resultDict = Evaluate(roles); + Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics)); var overall = resultDict[MetricKinds.OverallMetrics]; Result result; using (var cursor = overall.GetRowCursor(i => true)) { var moved = cursor.MoveNext(); - env.Assert(moved); - result = new Result(env, cursor); + Host.Assert(moved); + result = new Result(Host, cursor); moved = cursor.MoveNext(); - env.Assert(!moved); + Host.Assert(!moved); } return result; } diff --git 
a/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs b/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs index 3fa8d1ae7a..c196b326dc 100644 --- a/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs +++ b/src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs @@ -19,7 +19,7 @@ namespace Microsoft.ML.Data.StaticPipe.Runtime /// public abstract class TrainerEstimatorReconciler : EstimatorReconciler { - private readonly PipelineColumn[] _inputs; + protected readonly PipelineColumn[] Inputs; private readonly string[] _outputNames; /// @@ -38,7 +38,7 @@ protected TrainerEstimatorReconciler(PipelineColumn[] inputs, string[] outputNam Contracts.CheckValue(inputs, nameof(inputs)); Contracts.CheckValue(outputNames, nameof(outputNames)); - _inputs = inputs; + Inputs = inputs; _outputNames = outputNames; } @@ -71,7 +71,7 @@ public sealed override IEstimator Reconcile(IHostEnvironment env, env.AssertValue(usedNames); // The reconciler should have been called with all the input columns having names. - env.Assert(inputNames.Keys.All(_inputs.Contains) && _inputs.All(inputNames.Keys.Contains)); + env.Assert(inputNames.Keys.All(Inputs.Contains) && Inputs.All(inputNames.Keys.Contains)); // The output name map should contain only outputs as their keys. Yet, it is possible not all // outputs will be required in which case these will both be subsets of those outputs indicated // at construction. @@ -105,7 +105,7 @@ public sealed override IEstimator Reconcile(IHostEnvironment env, } // Map the inputs to the names. - string[] mappedInputNames = _inputs.Select(c => inputNames[c]).ToArray(); + string[] mappedInputNames = Inputs.Select(c => inputNames[c]).ToArray(); // Finally produce the trainer. 
var trainerEst = ReconcileCore(env, mappedInputNames); if (result == null) @@ -172,7 +172,7 @@ public Regression(EstimatorFactory estimatorFactory, Scalar label, Vector { Contracts.CheckValue(estimatorFactory, nameof(estimatorFactory)); _estFact = estimatorFactory; - Contracts.Assert(_inputs.Length == 2 || _inputs.Length == 3); + Contracts.Assert(Inputs.Length == 2 || Inputs.Length == 3); Score = new Impl(this); } @@ -182,13 +182,13 @@ private static PipelineColumn[] MakeInputs(Scalar label, Vector fe protected override IEstimator ReconcileCore(IHostEnvironment env, string[] inputNames) { Contracts.AssertValue(env); - env.Assert(Utils.Size(inputNames) == _inputs.Length); + env.Assert(Utils.Size(inputNames) == Inputs.Length); return _estFact(env, inputNames[0], inputNames[1], inputNames.Length > 2 ? inputNames[2] : null); } private sealed class Impl : Scalar { - public Impl(Regression rec) : base(rec, rec._inputs) { } + public Impl(Regression rec) : base(rec, rec.Inputs) { } } } @@ -231,7 +231,7 @@ public BinaryClassifier(EstimatorFactory estimatorFactory, Scalar label, V { Contracts.CheckValue(estimatorFactory, nameof(estimatorFactory)); _estFact = estimatorFactory; - Contracts.Assert(_inputs.Length == 2 || _inputs.Length == 3); + Contracts.Assert(Inputs.Length == 2 || Inputs.Length == 3); Output = (new Impl(this), new Impl(this), new ImplBool(this)); } @@ -242,18 +242,18 @@ private static PipelineColumn[] MakeInputs(Scalar label, Vector fea protected override IEstimator ReconcileCore(IHostEnvironment env, string[] inputNames) { Contracts.AssertValue(env); - env.Assert(Utils.Size(inputNames) == _inputs.Length); + env.Assert(Utils.Size(inputNames) == Inputs.Length); return _estFact(env, inputNames[0], inputNames[1], inputNames.Length > 2 ? 
inputNames[2] : null); } private sealed class Impl : Scalar { - public Impl(BinaryClassifier rec) : base(rec, rec._inputs) { } + public Impl(BinaryClassifier rec) : base(rec, rec.Inputs) { } } private sealed class ImplBool : Scalar { - public ImplBool(BinaryClassifier rec) : base(rec, rec._inputs) { } + public ImplBool(BinaryClassifier rec) : base(rec, rec.Inputs) { } } } @@ -306,7 +306,7 @@ public BinaryClassifierNoCalibration(EstimatorFactory estimatorFactory, Scalar label, Vector fea protected override IEstimator ReconcileCore(IHostEnvironment env, string[] inputNames) { Contracts.AssertValue(env); - env.Assert(Utils.Size(inputNames) == _inputs.Length); + env.Assert(Utils.Size(inputNames) == Inputs.Length); return _estFact(env, inputNames[0], inputNames[1], inputNames.Length > 2 ? inputNames[2] : null); } private sealed class Impl : Scalar { - public Impl(BinaryClassifierNoCalibration rec) : base(rec, rec._inputs) { } + public Impl(BinaryClassifierNoCalibration rec) : base(rec, rec.Inputs) { } } private sealed class ImplBool : Scalar { - public ImplBool(BinaryClassifierNoCalibration rec) : base(rec, rec._inputs) { } + public ImplBool(BinaryClassifierNoCalibration rec) : base(rec, rec.Inputs) { } } } @@ -378,7 +378,7 @@ public MulticlassClassifier(EstimatorFactory estimatorFactory, Key l { Contracts.CheckValue(estimatorFactory, nameof(estimatorFactory)); _estFact = estimatorFactory; - Contracts.Assert(_inputs.Length == 2 || _inputs.Length == 3); + Contracts.Assert(Inputs.Length == 2 || Inputs.Length == 3); Output = (new ImplScore(this), new ImplLabel(this)); } @@ -388,18 +388,18 @@ private static PipelineColumn[] MakeInputs(Key label, Vector protected override IEstimator ReconcileCore(IHostEnvironment env, string[] inputNames) { Contracts.AssertValue(env); - env.Assert(Utils.Size(inputNames) == _inputs.Length); + env.Assert(Utils.Size(inputNames) == Inputs.Length); return _estFact(env, inputNames[0], inputNames[1], inputNames.Length > 2 ? 
inputNames[2] : null); + } private sealed class ImplLabel : Key { - public ImplLabel(MulticlassClassifier rec) : base(rec, rec._inputs) { } + public ImplLabel(MulticlassClassifier rec) : base(rec, rec.Inputs) { } } private sealed class ImplScore : Vector { - public ImplScore(MulticlassClassifier rec) : base(rec, rec._inputs) { } + public ImplScore(MulticlassClassifier rec) : base(rec, rec.Inputs) { } } } diff --git a/src/Microsoft.ML.Data/Training/TrainContext.cs b/src/Microsoft.ML.Data/Training/TrainContext.cs index ec7aeec648..b15bd1b317 100644 --- a/src/Microsoft.ML.Data/Training/TrainContext.cs +++ b/src/Microsoft.ML.Data/Training/TrainContext.cs @@ -81,7 +81,7 @@ public sealed class BinaryClassificationContext : TrainContextBase /// For trainers for performing binary classification. /// /// - /// Component authors that have written binary classification. + /// Component authors that have written binary classification. /// public BinaryClassificationTrainers Trainers { get; } @@ -141,4 +141,97 @@ public BinaryClassifierEvaluator.Result EvaluateNonCalibrated(IDataView data, st return eval.Evaluate(data, label, score, predictedLabel); } } + + /// + /// The central context for multiclass classification trainers. + /// + public sealed class MulticlassClassificationContext : TrainContextBase + { + /// + /// For trainers for performing multiclass classification. + /// + public MulticlassClassificationTrainers Trainers { get; } + + public MulticlassClassificationContext(IHostEnvironment env) + : base(env, nameof(MulticlassClassificationContext)) + { + Trainers = new MulticlassClassificationTrainers(this); + } + + public sealed class MulticlassClassificationTrainers : ContextInstantiatorBase + { + internal MulticlassClassificationTrainers(MulticlassClassificationContext ctx) + : base(ctx) + { + } + } + + /// + /// Evaluates scored multiclass classification data. + /// + /// The scored data. + /// The name of the label column in . 
+ /// The name of the score column in . + /// The name of the predicted label column in . + /// If given a positive value, the will be filled with + /// the top-K accuracy, that is, the accuracy assuming we consider an example with the correct class within + /// the top-K values as being stored "correctly." + /// The evaluation results for these calibrated outputs. + public MultiClassClassifierEvaluator.Result Evaluate(IDataView data, string label, string score = DefaultColumnNames.Score, + string predictedLabel = DefaultColumnNames.PredictedLabel, int topK = 0) + { + Host.CheckValue(data, nameof(data)); + Host.CheckNonEmpty(label, nameof(label)); + Host.CheckNonEmpty(score, nameof(score)); + Host.CheckNonEmpty(predictedLabel, nameof(predictedLabel)); + + var args = new MultiClassClassifierEvaluator.Arguments() { }; + if (topK > 0) + args.OutputTopKAcc = topK; + var eval = new MultiClassClassifierEvaluator(Host, args); + return eval.Evaluate(data, label, score, predictedLabel); + } + } + + /// + /// The central context for regression trainers. + /// + public sealed class RegressionContext : TrainContextBase + { + /// + /// For trainers for performing regression. + /// + public RegressionTrainers Trainers { get; } + + public RegressionContext(IHostEnvironment env) + : base(env, nameof(RegressionContext)) + { + Trainers = new RegressionTrainers(this); + } + + public sealed class RegressionTrainers : ContextInstantiatorBase + { + internal RegressionTrainers(RegressionContext ctx) + : base(ctx) + { + } + } + + /// + /// Evaluates scored regression data. + /// + /// The scored data. + /// The name of the label column in . + /// The name of the score column in . + /// The evaluation results for these calibrated outputs. 
+ public RegressionEvaluator.Result Evaluate(IDataView data, string label, string score = DefaultColumnNames.Score) + { + Host.CheckValue(data, nameof(data)); + Host.CheckNonEmpty(label, nameof(label)); + Host.CheckNonEmpty(score, nameof(score)); + + var eval = new RegressionEvaluator(Host, new RegressionEvaluator.Arguments() { }); + return eval.Evaluate(data, label, score); + } + } } diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs new file mode 100644 index 0000000000..9d5b89c1c2 --- /dev/null +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs @@ -0,0 +1,123 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data.StaticPipe; +using Microsoft.ML.Data.StaticPipe.Runtime; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.FactorizationMachine; +using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Runtime.Training; +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Microsoft.ML.Trainers +{ + /// + /// Extension methods and utilities for instantiating FFM trainer estimators inside statically typed pipelines. + /// + public static class FactorizationMachineStatic + { + /// + /// Predict a target using a field-aware factorization machine. + /// + /// The binary classifier context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// Initial learning rate. + /// Number of training iterations. + /// Latent space dimensions. + /// A delegate to set more settings. 
+ /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the model that was trained. Note that this action cannot change the result in any way; it is only a way for the caller to + /// be informed about what was learnt. + /// The predicted output. + public static (Scalar score, Scalar predictedLabel) FieldAwareFactorizationMachine(this BinaryClassificationContext.BinaryClassificationTrainers ctx, + Scalar label, Vector[] features, + float learningRate = 0.1f, + int numIterations = 5, + int numLatentDimensions = 20, + Action advancedSettings = null, + Action onFit = null) + { + Contracts.CheckValue(label, nameof(label)); + Contracts.CheckNonEmpty(features, nameof(features)); + + Contracts.CheckParam(learningRate > 0, nameof(learningRate), "Must be positive"); + Contracts.CheckParam(numIterations > 0, nameof(numIterations), "Must be positive"); + Contracts.CheckParam(numLatentDimensions > 0, nameof(numLatentDimensions), "Must be positive"); + Contracts.CheckValueOrNull(advancedSettings); + Contracts.CheckValueOrNull(onFit); + + var rec = new CustomReconciler((env, labelCol, featureCols) => + { + var trainer = new FieldAwareFactorizationMachineTrainer(env, labelCol, featureCols, advancedSettings: + args => + { + args.LearningRate = learningRate; + args.Iters = numIterations; + args.LatentDim = numLatentDimensions; + advancedSettings?.Invoke(args); + }); + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + else + return trainer; + }, label, features); + return rec.Output; + } + + private sealed class CustomReconciler : TrainerEstimatorReconciler + { + private static readonly string[] _fixedOutputNames = new[] { DefaultColumnNames.Score, DefaultColumnNames.PredictedLabel }; + private readonly Func> _factory; + + /// + /// The general output for binary classifiers. 
+ /// + public (Scalar score, Scalar predictedLabel) Output { get; } + + /// + /// The output columns. + /// + protected override IEnumerable Outputs { get; } + + public CustomReconciler(Func> factory, Scalar label, Vector[] features) + : base(MakeInputs(Contracts.CheckRef(label, nameof(label)), Contracts.CheckRef(features, nameof(features))), _fixedOutputNames) + { + Contracts.AssertValue(factory); + _factory = factory; + + Output = (new Impl(this), new ImplBool(this)); + Outputs = new PipelineColumn[] { Output.score, Output.predictedLabel }; + } + + private static PipelineColumn[] MakeInputs(Scalar label, Vector[] features) + => new PipelineColumn[] { label }.Concat(features).ToArray(); + + protected override IEstimator ReconcileCore(IHostEnvironment env, string[] inputNames) + { + Contracts.AssertValue(env); + env.Assert(Utils.Size(inputNames) == Inputs.Length); + + // First input is label, rest are features. + return _factory(env, inputNames[0], inputNames.Skip(1).ToArray()); + } + + private sealed class Impl : Scalar + { + public Impl(CustomReconciler rec) : base(rec, rec.Inputs) { } + } + + private sealed class ImplBool : Scalar + { + public ImplBool(CustomReconciler rec) : base(rec, rec.Inputs) { } + } + } + } +} diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs index cd84c72347..0a92451e78 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs @@ -5,11 +5,13 @@ using System; using Microsoft.ML.Data.StaticPipe; using Microsoft.ML.Data.StaticPipe.Runtime; +using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Calibration; +using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.Training; -namespace Microsoft.ML.Runtime.Learners +namespace Microsoft.ML.Trainers { /// /// Extension methods and utilities for instantiating SDCA trainer estimators inside 
statically typed pipelines. @@ -19,6 +21,7 @@ public static class SdcaStatic /// /// Predict a target using a linear regression model trained with the SDCA trainer. /// + /// The regression context trainer object. /// The label, or dependent variable. /// The features, or independent variables. /// The optional example weights. @@ -32,7 +35,8 @@ public static class SdcaStatic /// the linear model that was trained. Note that this action cannot change the result in any way; it is only a way for the caller to /// be informed about what was learnt. /// The predicted output. - public static Scalar PredictSdcaRegression(this Scalar label, Vector features, Scalar weights = null, + public static Scalar Sdca(this RegressionContext.RegressionTrainers ctx, + Scalar label, Vector features, Scalar weights = null, float? l2Const = null, float? l1Threshold = null, int? maxIterations = null, @@ -205,6 +209,7 @@ public static (Scalar score, Scalar predictedLabel) Sdca( /// /// Predict a target using a linear multiclass classification model trained with the SDCA trainer. /// + /// The multiclass classification context trainer object. /// The label, or dependent variable. /// The features, or independent variables. /// /// The custom loss. @@ -219,7 +224,9 @@ public static (Scalar score, Scalar predictedLabel) Sdca( /// result in any way; it is only a way for the caller to be informed about what was learnt. /// The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label. public static (Vector score, Key predictedLabel) - PredictSdcaClassification(this Key label, Vector features, + Sdca(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, + Key label, + Vector features, ISupportSdcaClassificationLoss loss = null, Scalar weights = null, float? 
l2Const = null, diff --git a/test/Microsoft.ML.StaticPipelineTesting/Microsoft.ML.StaticPipelineTesting.csproj b/test/Microsoft.ML.StaticPipelineTesting/Microsoft.ML.StaticPipelineTesting.csproj index 9c28ddc5cb..c0c714f00b 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Microsoft.ML.StaticPipelineTesting.csproj +++ b/test/Microsoft.ML.StaticPipelineTesting/Microsoft.ML.StaticPipelineTesting.csproj @@ -7,6 +7,7 @@ + false diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs index 3769aff953..a1c2e04463 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs @@ -2,12 +2,14 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data.StaticPipe; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.FactorizationMachine; using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.Runtime.Training; +using Microsoft.ML.Trainers; using System; using Xunit; using Xunit.Abstractions; @@ -24,9 +26,11 @@ public Training(ITestOutputHelper output) : base(output) public void SdcaRegression() { var env = new ConsoleEnvironment(seed: 0); - var dataPath = GetDataPath("generated_regression_dataset.csv"); + var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename); var dataSource = new MultiFileSource(dataPath); + var ctx = new RegressionContext(env); + var reader = TextLoader.CreateReader(env, c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)), separator: ';', hasHeader: true); @@ -34,7 +38,7 @@ public void SdcaRegression() LinearRegressionPredictor pred = null; var est = reader.MakeNewEstimator() - .Append(r => (r.label, score: r.label.PredictSdcaRegression(r.features, maxIterations: 2, onFit: 
p => pred = p))); + .Append(r => (r.label, score: ctx.Trainers.Sdca(r.label, r.features, maxIterations: 2, onFit: p => pred = p))); var pipe = reader.Append(est); @@ -46,7 +50,7 @@ public void SdcaRegression() var data = model.Read(dataSource); - var metrics = RegressionEvaluator.Evaluate(data, r => r.label, r => r.score, new PoissonLoss()); + var metrics = ctx.Evaluate(data, r => r.label, r => r.score, new PoissonLoss()); // Run a sanity check against a few of the metrics. Assert.InRange(metrics.L1, 0, double.PositiveInfinity); Assert.InRange(metrics.L2, 0, double.PositiveInfinity); @@ -64,16 +68,17 @@ public void SdcaRegression() public void SdcaRegressionNameCollision() { var env = new ConsoleEnvironment(seed: 0); - var dataPath = GetDataPath("generated_regression_dataset.csv"); + var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename); var dataSource = new MultiFileSource(dataPath); - + var ctx = new RegressionContext(env); + // Here we introduce another column called "Score" to collide with the name of the default output. Heh heh heh... 
var reader = TextLoader.CreateReader(env, c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10), Score: c.LoadText(2)), separator: ';', hasHeader: true); var est = reader.MakeNewEstimator() - .Append(r => (r.label, r.Score, score: r.label.PredictSdcaRegression(r.features, maxIterations: 2))); + .Append(r => (r.label, r.Score, score: ctx.Trainers.Sdca(r.label, r.features, maxIterations: 2))); var pipe = reader.Append(est); @@ -93,7 +98,7 @@ public void SdcaRegressionNameCollision() public void SdcaBinaryClassification() { var env = new ConsoleEnvironment(seed: 0); - var dataPath = GetDataPath("breast-cancer.txt"); + var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename); var dataSource = new MultiFileSource(dataPath); var ctx = new BinaryClassificationContext(env); @@ -138,7 +143,7 @@ public void SdcaBinaryClassification() public void SdcaBinaryClassificationNoClaibration() { var env = new ConsoleEnvironment(seed: 0); - var dataPath = GetDataPath("breast-cancer.txt"); + var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename); var dataSource = new MultiFileSource(dataPath); var ctx = new BinaryClassificationContext(env); @@ -177,13 +182,46 @@ public void SdcaBinaryClassificationNoClaibration() Console.WriteLine($"{schema.GetColumnName(c)}, {schema.GetColumnType(c)}"); } + [Fact] + public void FfmBinaryClassification() + { + var env = new ConsoleEnvironment(seed: 0); + var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename); + var dataSource = new MultiFileSource(dataPath); + var ctx = new BinaryClassificationContext(env); + + var reader = TextLoader.CreateReader(env, + c => (label: c.LoadBool(0), features1: c.LoadFloat(1, 4), features2: c.LoadFloat(5, 9))); + + FieldAwareFactorizationMachinePredictor pred = null; + + // With a custom loss function we no longer get calibrated predictions. 
+ var est = reader.MakeNewEstimator() + .Append(r => (r.label, preds: ctx.Trainers.FieldAwareFactorizationMachine(r.label, new[] { r.features1, r.features2 }, onFit: p => pred = p))); + + var pipe = reader.Append(est); + + Assert.Null(pred); + var model = pipe.Fit(dataSource); + Assert.NotNull(pred); + + var data = model.Read(dataSource); + + var metrics = ctx.Evaluate(data, r => r.label, r => r.preds); + // Run a sanity check against a few of the metrics. + Assert.InRange(metrics.Accuracy, 0, 1); + Assert.InRange(metrics.Auc, 0, 1); + Assert.InRange(metrics.Auprc, 0, 1); + } + [Fact] public void SdcaMulticlass() { var env = new ConsoleEnvironment(seed: 0); - var dataPath = GetDataPath("iris.txt"); + var dataPath = GetDataPath(TestDatasets.iris.trainFilename); var dataSource = new MultiFileSource(dataPath); + var ctx = new MulticlassClassificationContext(env); var reader = TextLoader.CreateReader(env, c => (label: c.LoadText(0), features: c.LoadFloat(1, 4))); @@ -194,7 +232,8 @@ public void SdcaMulticlass() // With a custom loss function we no longer get calibrated predictions. 
var est = reader.MakeNewEstimator() .Append(r => (label: r.label.ToKey(), r.features)) - .Append(r => (r.label, preds: r.label.PredictSdcaClassification( + .Append(r => (r.label, preds: ctx.Trainers.Sdca( + r.label, r.features, maxIterations: 2, loss: loss, onFit: p => pred = p))); @@ -216,6 +255,10 @@ public void SdcaMulticlass() var schema = data.AsDynamic.Schema; for (int c = 0; c < schema.ColumnCount; ++c) Console.WriteLine($"{schema.GetColumnName(c)}, {schema.GetColumnType(c)}"); + + var metrics = ctx.Evaluate(data, r => r.label, r => r.preds, 2); + Assert.True(metrics.LogLoss > 0); + Assert.True(metrics.TopKAccuracy > 0); } } } diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj index 62b713b46d..1ac7e80877 100644 --- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj +++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj @@ -1,4 +1,4 @@ - + @@ -21,18 +21,18 @@ - + - + - +