From 7f4e341c5bd645ae237e482656f10f5523f9264c Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Tue, 19 Mar 2019 10:35:36 -0700 Subject: [PATCH 01/18] Fix bug in TextLoader --- .../DataLoadSave/Text/TextLoader.cs | 5 ----- test/Microsoft.ML.Tests/TextLoaderTests.cs | 13 +++++++++++++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs index 9caf93c626..01c70b06da 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs @@ -1291,11 +1291,6 @@ private static bool TryParseSchema(IHost host, IMultiStreamSource files, if (loader == null || string.IsNullOrWhiteSpace(loader.Name)) goto LDone; - // Make sure the loader binds to us. - var info = host.ComponentCatalog.GetLoadableClassInfo(loader.Name); - if (info.Type != typeof(ILegacyDataLoader) || info.ArgType != typeof(Options)) - goto LDone; - var optionsNew = new Options(); // Set the fields of optionsNew to the arguments parsed from the file. 
if (!CmdParser.ParseArguments(host, loader.GetSettingsString(), optionsNew, typeof(Options), msg => ch.Error(msg))) diff --git a/test/Microsoft.ML.Tests/TextLoaderTests.cs b/test/Microsoft.ML.Tests/TextLoaderTests.cs index 8ef41ab3a3..7b96a8ab2e 100644 --- a/test/Microsoft.ML.Tests/TextLoaderTests.cs +++ b/test/Microsoft.ML.Tests/TextLoaderTests.cs @@ -598,6 +598,19 @@ public void ThrowsExceptionWithPropertyName() catch (NullReferenceException) { }; } + [Fact] + public void ParseSchemaFromTextFile() + { + var mlContext = new MLContext(seed: 1); + var fileName = GetDataPath(TestDatasets.adult.trainFilename); + var loader = mlContext.Data.CreateTextLoader(new TextLoader.Options(), new MultiFileSource(fileName)); + var data = loader.Load(new MultiFileSource(fileName)); + Assert.NotNull(data.Schema.GetColumnOrNull("Label")); + Assert.NotNull(data.Schema.GetColumnOrNull("Workclass")); + Assert.NotNull(data.Schema.GetColumnOrNull("Categories")); + Assert.NotNull(data.Schema.GetColumnOrNull("NumericFeatures")); + } + public class QuoteInput { [LoadColumn(0)] From 1a89468c6dff6a4d3d42e15b9d893eedd674a7af Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Tue, 19 Mar 2019 11:11:49 -0700 Subject: [PATCH 02/18] Clean FeatureContributionCalculation and PermutationFeatureImportance (#2966) --- ...FeatureContributionCalculationTransform.cs | 4 +- .../PFIRegressionExample.cs | 2 +- .../PfiBinaryClassificationExample.cs | 2 +- .../Model/ModelOperationsCatalog.cs | 18 ---- .../Transforms/ExplainabilityCatalog.cs | 59 +++++++++--- ...atureContributionCalculationTransformer.cs | 70 +++++++------- .../PermutationFeatureImportanceExtensions.cs | 92 +++++++++---------- .../Explainability.cs | 8 +- .../FeatureContributionTests.cs | 57 +++++++++--- 9 files changed, 171 insertions(+), 141 deletions(-) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureContributionCalculationTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureContributionCalculationTransform.cs 
index e1ab038926..1101de6d51 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureContributionCalculationTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureContributionCalculationTransform.cs @@ -35,12 +35,12 @@ public static void Example() // Create a Feature Contribution Calculator // Calculate the feature contributions for all features given trained model parameters // And don't normalize the contribution scores - var featureContributionCalculator = mlContext.Model.Explainability.FeatureContributionCalculation(model.Model, model.FeatureColumnName, numPositiveContributions: 11, normalize: false); + var featureContributionCalculator = mlContext.Transforms.CalculateFeatureContribution(model, numberOfPositiveContributions: 11, normalize: false); var outputData = featureContributionCalculator.Fit(scoredData).Transform(scoredData); // FeatureContributionCalculatingEstimator can be use as an intermediary step in a pipeline. // The features retained by FeatureContributionCalculatingEstimator will be in the FeatureContribution column. 
- var pipeline = mlContext.Model.Explainability.FeatureContributionCalculation(model.Model, model.FeatureColumnName, numPositiveContributions: 11) + var pipeline = mlContext.Transforms.CalculateFeatureContribution(model, numberOfPositiveContributions: 11) .Append(mlContext.Regression.Trainers.Ols(featureColumnName: "FeatureContributions")); var outData = featureContributionCalculator.Fit(scoredData).Transform(scoredData); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIRegressionExample.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIRegressionExample.cs index e14c9e2d08..4afa964850 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIRegressionExample.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIRegressionExample.cs @@ -31,7 +31,7 @@ public static void Example() // Compute the permutation metrics using the properly normalized data. var transformedData = model.Transform(data); var permutationMetrics = mlContext.Regression.PermutationFeatureImportance( - linearPredictor, transformedData, label: labelName, features: "Features", permutationCount: 3); + linearPredictor, transformedData, labelColumnName: labelName, permutationCount: 3); // Now let's look at which features are most important to the model overall // Get the feature indices sorted by their impact on R-Squared diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs index aea3dfb24c..a74c659de2 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs @@ -35,7 +35,7 @@ public static void Example() // Compute the permutation metrics 
using the properly normalized data. var transformedData = model.Transform(data); var permutationMetrics = mlContext.BinaryClassification.PermutationFeatureImportance( - linearPredictor, transformedData, label: labelName, features: "Features", permutationCount: 3); + linearPredictor, transformedData, labelColumnName: labelName, permutationCount: 3); // Now let's look at which features are most important to the model overall. // Get the feature indices sorted by their impact on AreaUnderRocCurve. diff --git a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs index b25d9bc9b1..dc4cb551d0 100644 --- a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs @@ -22,14 +22,10 @@ public sealed class ModelOperationsCatalog : IInternalCatalog IHostEnvironment IInternalCatalog.Environment => _env; private readonly IHostEnvironment _env; - public ExplainabilityTransforms Explainability { get; } - internal ModelOperationsCatalog(IHostEnvironment env) { Contracts.AssertValue(env); _env = env; - - Explainability = new ExplainabilityTransforms(this); } /// @@ -228,20 +224,6 @@ public ITransformer LoadWithDataLoader(Stream stream, out IDataLoader(); } - /// - /// The catalog of model explainability operations. - /// - public sealed class ExplainabilityTransforms : IInternalCatalog - { - IHostEnvironment IInternalCatalog.Environment => _env; - private readonly IHostEnvironment _env; - - internal ExplainabilityTransforms(ModelOperationsCatalog owner) - { - _env = owner._env; - } - } - /// /// Create a prediction engine for one-time prediction. 
/// diff --git a/src/Microsoft.ML.Data/Transforms/ExplainabilityCatalog.cs b/src/Microsoft.ML.Data/Transforms/ExplainabilityCatalog.cs index c0df4a566c..083d0f6380 100644 --- a/src/Microsoft.ML.Data/Transforms/ExplainabilityCatalog.cs +++ b/src/Microsoft.ML.Data/Transforms/ExplainabilityCatalog.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Calibrators; using Microsoft.ML.Data; using Microsoft.ML.Trainers; using Microsoft.ML.Transforms; @@ -17,19 +18,53 @@ public static class ExplainabilityCatalog /// Note that this functionality is not supported by all the models. See for a list of the suported models. /// /// The model explainability operations catalog. - /// Trained model parameters that support Feature Contribution Calculation and which will be used for scoring. - /// The name of the feature column that will be used as input. - /// The number of positive contributions to report, sorted from highest magnitude to lowest magnitude. - /// Note that if there are fewer features with positive contributions than , the rest will be returned as zeros. - /// The number of negative contributions to report, sorted from highest magnitude to lowest magnitude. - /// Note that if there are fewer features with negative contributions than , the rest will be returned as zeros. + /// A that supports Feature Contribution Calculation, + /// and which will also be used for scoring. + /// The number of positive contributions to report, sorted from highest magnitude to lowest magnitude. + /// Note that if there are fewer features with positive contributions than , the rest will be returned as zeros. + /// The number of negative contributions to report, sorted from highest magnitude to lowest magnitude. + /// Note that if there are fewer features with negative contributions than , the rest will be returned as zeros. 
/// Whether the feature contributions should be normalized to the [-1, 1] interval. - public static FeatureContributionCalculatingEstimator FeatureContributionCalculation(this ModelOperationsCatalog.ExplainabilityTransforms catalog, - ICalculateFeatureContribution modelParameters, - string featureColumnName = DefaultColumnNames.Features, - int numPositiveContributions = FeatureContributionDefaults.NumPositiveContributions, - int numNegativeContributions = FeatureContributionDefaults.NumNegativeContributions, + /// + /// + /// + /// + /// + public static FeatureContributionCalculatingEstimator CalculateFeatureContribution(this TransformsCatalog catalog, + ISingleFeaturePredictionTransformer predictionTransformer, + int numberOfPositiveContributions = FeatureContributionDefaults.NumberOfPositiveContributions, + int numberOfNegativeContributions = FeatureContributionDefaults.NumberOfNegativeContributions, bool normalize = FeatureContributionDefaults.Normalize) - => new FeatureContributionCalculatingEstimator(CatalogUtils.GetEnvironment(catalog), modelParameters, featureColumnName, numPositiveContributions, numNegativeContributions, normalize); + => new FeatureContributionCalculatingEstimator(CatalogUtils.GetEnvironment(catalog), predictionTransformer.Model, numberOfPositiveContributions, numberOfNegativeContributions, predictionTransformer.FeatureColumnName, normalize); + + /// + /// Feature Contribution Calculation computes model-specific contribution scores for each feature. + /// Note that this functionality is not supported by all the models. See for a list of the suported models. + /// + /// The model explainability operations catalog. + /// A that supports Feature Contribution Calculation, + /// and which will also be used for scoring. + /// The number of positive contributions to report, sorted from highest magnitude to lowest magnitude. + /// Note that if there are fewer features with positive contributions than , the rest will be returned as zeros. 
+ /// The number of negative contributions to report, sorted from highest magnitude to lowest magnitude. + /// Note that if there are fewer features with negative contributions than , the rest will be returned as zeros. + /// Whether the feature contributions should be normalized to the [-1, 1] interval. + /// + /// + /// + /// + /// + public static FeatureContributionCalculatingEstimator CalculateFeatureContribution(this TransformsCatalog catalog, + ISingleFeaturePredictionTransformer> predictionTransformer, + int numberOfPositiveContributions = FeatureContributionDefaults.NumberOfPositiveContributions, + int numberOfNegativeContributions = FeatureContributionDefaults.NumberOfNegativeContributions, + bool normalize = FeatureContributionDefaults.Normalize) + where TModelParameters : class, ICalculateFeatureContribution + where TCalibrator : class, ICalibrator + => new FeatureContributionCalculatingEstimator(CatalogUtils.GetEnvironment(catalog), predictionTransformer.Model.SubModel, numberOfPositiveContributions, numberOfNegativeContributions, predictionTransformer.FeatureColumnName, normalize); } } diff --git a/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs b/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs index 120fe41097..2f17b896ea 100644 --- a/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs @@ -67,13 +67,6 @@ namespace Microsoft.ML.Transforms /// See the sample below for an example of how to compute feature importance using the FeatureContributionCalculatingTransformer. 
/// /// - /// - /// - /// - /// - /// public sealed class FeatureContributionCalculatingTransformer : OneToOneTransformerBase { internal sealed class Options : TransformInputBase @@ -85,10 +78,10 @@ internal sealed class Options : TransformInputBase public string FeatureColumn = DefaultColumnNames.Features; [Argument(ArgumentType.AtMostOnce, HelpText = "Number of top contributions", SortOrder = 3)] - public int Top = FeatureContributionCalculatingEstimator.Defaults.NumPositiveContributions; + public int Top = FeatureContributionCalculatingEstimator.Defaults.NumberOfPositiveContributions; [Argument(ArgumentType.AtMostOnce, HelpText = "Number of bottom contributions", SortOrder = 4)] - public int Bottom = FeatureContributionCalculatingEstimator.Defaults.NumNegativeContributions; + public int Bottom = FeatureContributionCalculatingEstimator.Defaults.NumberOfNegativeContributions; [Argument(ArgumentType.AtMostOnce, HelpText = "Whether or not output of Features contribution should be normalized", ShortName = "norm", SortOrder = 5)] public bool Normalize = FeatureContributionCalculatingEstimator.Defaults.Normalize; @@ -122,24 +115,24 @@ private static VersionInfo GetVersionInfo() /// /// The environment to use. /// Trained model parameters that support Feature Contribution Calculation and which will be used for scoring. - /// The name of the feature column that will be used as input. - /// The number of positive contributions to report, sorted from highest magnitude to lowest magnitude. - /// Note that if there are fewer features with positive contributions than , the rest will be returned as zeros. - /// The number of negative contributions to report, sorted from highest magnitude to lowest magnitude. - /// Note that if there are fewer features with negative contributions than , the rest will be returned as zeros. + /// The name of the feature column that will be used as input. 
+ /// The number of positive contributions to report, sorted from highest magnitude to lowest magnitude. + /// Note that if there are fewer features with positive contributions than , the rest will be returned as zeros. + /// The number of negative contributions to report, sorted from highest magnitude to lowest magnitude. + /// Note that if there are fewer features with negative contributions than , the rest will be returned as zeros. /// Whether the feature contributions should be normalized to the [-1, 1] interval. internal FeatureContributionCalculatingTransformer(IHostEnvironment env, ICalculateFeatureContribution modelParameters, - string featureColumn = DefaultColumnNames.Features, - int numPositiveContributions = FeatureContributionCalculatingEstimator.Defaults.NumPositiveContributions, - int numNegativeContributions = FeatureContributionCalculatingEstimator.Defaults.NumNegativeContributions, + string featureColumnName = DefaultColumnNames.Features, + int numberOfPositiveContributions = FeatureContributionCalculatingEstimator.Defaults.NumberOfPositiveContributions, + int numberOfNegativeContributions = FeatureContributionCalculatingEstimator.Defaults.NumberOfNegativeContributions, bool normalize = FeatureContributionCalculatingEstimator.Defaults.Normalize) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(FeatureContributionCalculatingTransformer)), new[] { (name: DefaultColumnNames.FeatureContributions, source: featureColumn) }) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(FeatureContributionCalculatingTransformer)), new[] { (name: DefaultColumnNames.FeatureContributions, source: featureColumnName) }) { Host.CheckValue(modelParameters, nameof(modelParameters)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - if (numPositiveContributions < 0) + Host.CheckNonEmpty(featureColumnName, nameof(featureColumnName)); + if (numberOfPositiveContributions < 0) throw Host.Except($"Number of top contribution must be non 
negative"); - if (numNegativeContributions < 0) + if (numberOfNegativeContributions < 0) throw Host.Except($"Number of bottom contribution must be non negative"); // If a predictor implements ICalculateFeatureContribution, it also implements the internal interface IFeatureContributionMapper. @@ -147,8 +140,8 @@ internal FeatureContributionCalculatingTransformer(IHostEnvironment env, ICalcul _predictor = modelParameters as IFeatureContributionMapper; Host.AssertValue(_predictor); - Top = numPositiveContributions; - Bottom = numNegativeContributions; + Top = numberOfPositiveContributions; + Bottom = numberOfNegativeContributions; Normalize = normalize; } @@ -283,8 +276,8 @@ public sealed class FeatureContributionCalculatingEstimator : TrivialEstimator for a list of the suported models. /// /// The environment to use. - /// Trained model parameters that support Feature Contribution Calculation and which will be used for scoring. - /// The name of the feature column that will be used as input. - /// The number of positive contributions to report, sorted from highest magnitude to lowest magnitude. - /// Note that if there are fewer features with positive contributions than , the rest will be returned as zeros. - /// The number of negative contributions to report, sorted from highest magnitude to lowest magnitude. - /// Note that if there are fewer features with negative contributions than , the rest will be returned as zeros. + /// A that supports Feature Contribution Calculation, + /// and which will also be used for scoring. + /// The number of positive contributions to report, sorted from highest magnitude to lowest magnitude. + /// Note that if there are fewer features with positive contributions than , the rest will be returned as zeros. + /// The number of negative contributions to report, sorted from highest magnitude to lowest magnitude. + /// Note that if there are fewer features with negative contributions than , the rest will be returned as zeros. 
+ /// TODO /// Whether the feature contributions should be normalized to the [-1, 1] interval. - internal FeatureContributionCalculatingEstimator(IHostEnvironment env, ICalculateFeatureContribution modelParameters, - string featureColumn = DefaultColumnNames.Features, - int numPositiveContributions = Defaults.NumPositiveContributions, - int numNegativeContributions = Defaults.NumNegativeContributions, + internal FeatureContributionCalculatingEstimator(IHostEnvironment env, ICalculateFeatureContribution model, + int numberOfPositiveContributions = Defaults.NumberOfPositiveContributions, + int numberOfNegativeContributions = Defaults.NumberOfNegativeContributions, + string featureColumnName = DefaultColumnNames.Features, bool normalize = Defaults.Normalize) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(FeatureContributionCalculatingTransformer)), - new FeatureContributionCalculatingTransformer(env, modelParameters, featureColumn, numPositiveContributions, numNegativeContributions, normalize)) + new FeatureContributionCalculatingTransformer(env, model, featureColumnName, numberOfPositiveContributions, numberOfNegativeContributions, normalize)) { - _featureColumn = featureColumn; - _predictor = modelParameters; + _featureColumn = featureColumnName; + _predictor = model; } /// diff --git a/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs b/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs index c66a7ed0de..e356769ac5 100644 --- a/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs +++ b/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs @@ -46,36 +46,34 @@ public static class PermutationFeatureImportanceExtensions /// /// /// The regression catalog. - /// The model to evaluate. + /// The model on which to evaluate feature importance. /// The evaluation data set. - /// Label column name. - /// Feature column name. + /// Label column name. 
/// Use features weight to pre-filter features. - /// Limit the number of examples to evaluate on. null means examples (up to ~ 2 bln) from input will be used. + /// Limit the number of examples to evaluate on. means up to ~2 bln examples from will be used. /// The number of permutations to perform. /// Array of per-feature 'contributions' to the score. public static ImmutableArray PermutationFeatureImportance( this RegressionCatalog catalog, - IPredictionTransformer model, + ISingleFeaturePredictionTransformer predictionTransformer, IDataView data, - string label = DefaultColumnNames.Label, - string features = DefaultColumnNames.Features, + string labelColumnName = DefaultColumnNames.Label, bool useFeatureWeightFilter = false, - int? topExamples = null, + int? numberOfExamplesToUse = null, int permutationCount = 1) where TModel : class { return PermutationFeatureImportance.GetImportanceMetricsMatrix( catalog.GetEnvironment(), - model, + predictionTransformer, data, () => new RegressionMetricsStatistics(), - idv => catalog.Evaluate(idv, label), + idv => catalog.Evaluate(idv, labelColumnName), RegressionDelta, - features, + predictionTransformer.FeatureColumnName, permutationCount, useFeatureWeightFilter, - topExamples); + numberOfExamplesToUse); } private static RegressionMetrics RegressionDelta( @@ -124,36 +122,34 @@ private static RegressionMetrics RegressionDelta( /// /// /// The binary classification catalog. - /// The model to evaluate. + /// The model on which to evaluate feature importance. /// The evaluation data set. - /// Label column name. - /// Feature column name. + /// Label column name. /// Use features weight to pre-filter features. - /// Limit the number of examples to evaluate on. null means examples (up to ~ 2 bln) from input will be used. + /// Limit the number of examples to evaluate on. means up to ~2 bln examples from will be used. /// The number of permutations to perform. /// Array of per-feature 'contributions' to the score. 
public static ImmutableArray PermutationFeatureImportance( this BinaryClassificationCatalog catalog, - IPredictionTransformer model, + ISingleFeaturePredictionTransformer predictionTransformer, IDataView data, - string label = DefaultColumnNames.Label, - string features = DefaultColumnNames.Features, + string labelColumnName = DefaultColumnNames.Label, bool useFeatureWeightFilter = false, - int? topExamples = null, + int? numberOfExamplesToUse = null, int permutationCount = 1) where TModel : class { return PermutationFeatureImportance.GetImportanceMetricsMatrix( catalog.GetEnvironment(), - model, + predictionTransformer, data, () => new BinaryClassificationMetricsStatistics(), - idv => catalog.Evaluate(idv, label), + idv => catalog.Evaluate(idv, labelColumnName), BinaryClassifierDelta, - features, + predictionTransformer.FeatureColumnName, permutationCount, useFeatureWeightFilter, - topExamples); + numberOfExamplesToUse); } private static BinaryClassificationMetrics BinaryClassifierDelta( @@ -199,36 +195,34 @@ private static BinaryClassificationMetrics BinaryClassifierDelta( /// /// /// The clustering catalog. - /// The model to evaluate. + /// The model on which to evaluate feature importance. /// The evaluation data set. - /// Label column name. - /// Feature column name. + /// Label column name. /// Use features weight to pre-filter features. - /// Limit the number of examples to evaluate on. null means examples (up to ~ 2 bln) from input will be used. + /// Limit the number of examples to evaluate on. means up to ~2 bln examples from will be used. /// The number of permutations to perform. /// Array of per-feature 'contributions' to the score. 
public static ImmutableArray PermutationFeatureImportance( this MulticlassClassificationCatalog catalog, - IPredictionTransformer model, + ISingleFeaturePredictionTransformer predictionTransformer, IDataView data, - string label = DefaultColumnNames.Label, - string features = DefaultColumnNames.Features, + string labelColumnName = DefaultColumnNames.Label, bool useFeatureWeightFilter = false, - int? topExamples = null, + int? numberOfExamplesToUse = null, int permutationCount = 1) where TModel : class { return PermutationFeatureImportance.GetImportanceMetricsMatrix( catalog.GetEnvironment(), - model, + predictionTransformer, data, () => new MulticlassClassificationMetricsStatistics(), - idv => catalog.Evaluate(idv, label), + idv => catalog.Evaluate(idv, labelColumnName), MulticlassClassificationDelta, - features, + predictionTransformer.FeatureColumnName, permutationCount, useFeatureWeightFilter, - topExamples); + numberOfExamplesToUse); } private static MulticlassClassificationMetrics MulticlassClassificationDelta( @@ -279,38 +273,36 @@ private static MulticlassClassificationMetrics MulticlassClassificationDelta( /// /// /// The clustering catalog. - /// The model to evaluate. + /// The model on which to evaluate feature importance. /// The evaluation data set. - /// Label column name. - /// GroupId column name - /// Feature column name. + /// Label column name. + /// GroupId column name /// Use features weight to pre-filter features. - /// Limit the number of examples to evaluate on. null means examples (up to ~ 2 bln) from input will be used. + /// Limit the number of examples to evaluate on. means up to ~2 bln examples from will be used. /// The number of permutations to perform. /// Array of per-feature 'contributions' to the score. 
public static ImmutableArray PermutationFeatureImportance( this RankingCatalog catalog, - IPredictionTransformer model, + ISingleFeaturePredictionTransformer predictionTransformer, IDataView data, - string label = DefaultColumnNames.Label, - string groupId = DefaultColumnNames.GroupId, - string features = DefaultColumnNames.Features, + string labelColumnName = DefaultColumnNames.Label, + string rowGroupColumnName = DefaultColumnNames.GroupId, bool useFeatureWeightFilter = false, - int? topExamples = null, + int? numberOfExamplesToUse = null, int permutationCount = 1) where TModel : class { return PermutationFeatureImportance.GetImportanceMetricsMatrix( catalog.GetEnvironment(), - model, + predictionTransformer, data, () => new RankingMetricsStatistics(), - idv => catalog.Evaluate(idv, label, groupId), + idv => catalog.Evaluate(idv, labelColumnName, rowGroupColumnName), RankingDelta, - features, + predictionTransformer.FeatureColumnName, permutationCount, useFeatureWeightFilter, - topExamples); + numberOfExamplesToUse); } private static RankingMetrics RankingDelta( diff --git a/test/Microsoft.ML.Functional.Tests/Explainability.cs b/test/Microsoft.ML.Functional.Tests/Explainability.cs index 7652233052..11729c3959 100644 --- a/test/Microsoft.ML.Functional.Tests/Explainability.cs +++ b/test/Microsoft.ML.Functional.Tests/Explainability.cs @@ -152,7 +152,7 @@ public void LocalFeatureImportanceForLinearModel() // Create a Feature Contribution Calculator. var predictor = model.LastTransformer; - var featureContributions = mlContext.Model.Explainability.FeatureContributionCalculation(predictor.Model, predictor.FeatureColumnName, normalize: false); + var featureContributions = mlContext.Transforms.CalculateFeatureContribution(predictor, normalize: false); // Compute the contributions var outputData = featureContributions.Fit(scoredData).Transform(scoredData); @@ -189,7 +189,7 @@ public void LocalFeatureImportanceForFastTreeModel() // Create a Feature Contribution Calculator. 
var predictor = model.LastTransformer; - var featureContributions = mlContext.Model.Explainability.FeatureContributionCalculation(predictor.Model, predictor.FeatureColumnName, normalize: false); + var featureContributions = mlContext.Transforms.CalculateFeatureContribution(predictor, normalize: false); // Compute the contributions var outputData = featureContributions.Fit(scoredData).Transform(scoredData); @@ -226,7 +226,7 @@ public void LocalFeatureImportanceForFastForestModel() // Create a Feature Contribution Calculator. var predictor = model.LastTransformer; - var featureContributions = mlContext.Model.Explainability.FeatureContributionCalculation(predictor.Model, predictor.FeatureColumnName, normalize: false); + var featureContributions = mlContext.Transforms.CalculateFeatureContribution(predictor, normalize: false); // Compute the contributions var outputData = featureContributions.Fit(scoredData).Transform(scoredData); @@ -264,7 +264,7 @@ public void LocalFeatureImportanceForGamModel() // Create a Feature Contribution Calculator. 
var predictor = model.LastTransformer; - var featureContributions = mlContext.Model.Explainability.FeatureContributionCalculation(predictor.Model, predictor.FeatureColumnName, normalize: false); + var featureContributions = mlContext.Transforms.CalculateFeatureContribution(predictor, normalize: false); // Compute the contributions var outputData = featureContributions.Fit(scoredData).Transform(scoredData); diff --git a/test/Microsoft.ML.Tests/FeatureContributionTests.cs b/test/Microsoft.ML.Tests/FeatureContributionTests.cs index a294bb8639..63aabbcd95 100644 --- a/test/Microsoft.ML.Tests/FeatureContributionTests.cs +++ b/test/Microsoft.ML.Tests/FeatureContributionTests.cs @@ -4,6 +4,7 @@ using System; using System.IO; +using Microsoft.ML.Calibrators; using Microsoft.ML.Data; using Microsoft.ML.Data.IO; using Microsoft.ML.Internal.Utilities; @@ -29,11 +30,11 @@ public void FeatureContributionEstimatorWorkout() var data = GetSparseDataset(); var model = ML.Regression.Trainers.Ols().Fit(data); - var estPipe = new FeatureContributionCalculatingEstimator(ML, model.Model, model.FeatureColumnName) - .Append(new FeatureContributionCalculatingEstimator(ML, model.Model, model.FeatureColumnName, normalize: false)) - .Append(new FeatureContributionCalculatingEstimator(ML, model.Model, model.FeatureColumnName, numPositiveContributions: 0)) - .Append(new FeatureContributionCalculatingEstimator(ML, model.Model, model.FeatureColumnName, numNegativeContributions: 0)) - .Append(new FeatureContributionCalculatingEstimator(ML, model.Model, model.FeatureColumnName, numPositiveContributions: 0, numNegativeContributions: 0)); + var estPipe = ML.Transforms.CalculateFeatureContribution(model) + .Append(ML.Transforms.CalculateFeatureContribution(model, normalize: false)) + .Append(ML.Transforms.CalculateFeatureContribution(model, numberOfPositiveContributions: 0)) + .Append(ML.Transforms.CalculateFeatureContribution(model, numberOfNegativeContributions: 0)) + 
.Append(ML.Transforms.CalculateFeatureContribution(model, numberOfPositiveContributions: 0, numberOfNegativeContributions: 0)); TestEstimatorCore(estPipe, data); Done(); @@ -182,7 +183,7 @@ public void TestGAMBinary() } private void TestFeatureContribution( - ITrainerEstimator, IPredictor> trainer, + ITrainerEstimator, ICalculateFeatureContribution> trainer, IDataView data, string testFile, int precision = 6) @@ -190,28 +191,54 @@ private void TestFeatureContribution( // Train the model. var model = trainer.Fit(data); - // Extract the predictor, check that it supports feature contribution. - var predictor = model.Model as ICalculateFeatureContribution; - Assert.NotNull(predictor); + // Calculate feature contributions. + var est = ML.Transforms.CalculateFeatureContribution(model, numberOfPositiveContributions: 3, numberOfNegativeContributions: 0) + .Append(ML.Transforms.CalculateFeatureContribution(model, numberOfPositiveContributions: 0, numberOfNegativeContributions: 3)) + .Append(ML.Transforms.CalculateFeatureContribution(model, numberOfPositiveContributions: 1, numberOfNegativeContributions: 1)) + .Append(ML.Transforms.CalculateFeatureContribution(model, numberOfPositiveContributions: 1, numberOfNegativeContributions: 1, normalize: false)); + + TestEstimatorCore(est, data); + + // Verify output. + CheckOutput(est, data, testFile, precision); + Done(); + } + + private void TestFeatureContribution( + ITrainerEstimator>, CalibratedModelParametersBase> trainer, + IDataView data, + string testFile, + int precision = 6) + where TModelParameters : class, ICalculateFeatureContribution + where TCalibrator : class, ICalibrator + { + // Train the model. + var model = trainer.Fit(data); // Calculate feature contributions. 
- var est = new FeatureContributionCalculatingEstimator(ML, predictor, "Features", numPositiveContributions: 3, numNegativeContributions: 0) - .Append(new FeatureContributionCalculatingEstimator(ML, predictor, "Features", numPositiveContributions: 0, numNegativeContributions: 3)) - .Append(new FeatureContributionCalculatingEstimator(ML, predictor, "Features", numPositiveContributions: 1, numNegativeContributions: 1)) - .Append(new FeatureContributionCalculatingEstimator(ML, predictor, "Features", numPositiveContributions: 1, numNegativeContributions: 1, normalize: false)); + var est = ML.Transforms.CalculateFeatureContribution(model, numberOfPositiveContributions: 3, numberOfNegativeContributions: 0) + .Append(ML.Transforms.CalculateFeatureContribution(model, numberOfPositiveContributions: 0, numberOfNegativeContributions: 3)) + .Append(ML.Transforms.CalculateFeatureContribution(model, numberOfPositiveContributions: 1, numberOfNegativeContributions: 1)) + .Append(ML.Transforms.CalculateFeatureContribution(model, numberOfPositiveContributions: 1, numberOfNegativeContributions: 1, normalize: false)); TestEstimatorCore(est, data); + // Verify output. 
+ CheckOutput(est, data, testFile, precision); + Done(); + } + + private void CheckOutput(IEstimator estimator, IDataView data, string testFile, int precision = 6) + { var outputPath = GetOutputPath("FeatureContribution", testFile + ".tsv"); using (var ch = Env.Start("save")) { var saver = new TextSaver(ML, new TextSaver.Arguments { Silent = true, OutputHeader = false }); - var savedData = ML.Data.TakeRows(est.Fit(data).Transform(data), 4); + var savedData = ML.Data.TakeRows(estimator.Fit(data).Transform(data), 4); using (var fs = File.Create(outputPath)) DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); } CheckEquality("FeatureContribution", testFile + ".tsv", digitsOfPrecision: precision); - Done(); } /// From aea88dcc72df85b9fd1e4fe5cfe00d6ac94d3dee Mon Sep 17 00:00:00 2001 From: Scott Inglis Date: Tue, 19 Mar 2019 12:19:19 -0700 Subject: [PATCH 03/18] Updating LightGBM Arguments (#2948) * Breaking down the LightGBM options class into separate option classes for each LightGBM trainer. * Refactored the Boost Option classes * Hides the interface for the Booster Parameter Factory (IBoosterParameterFactory). 
Fixes #2559 Fixes #2618 --- .../LightGbmWithOptions.cs | 5 +- .../LightGbmWithOptions.cs | 5 +- .../Trainers/Ranking/LightGbmWithOptions.cs | 5 +- .../Regression/LightGbmWithOptions.cs | 5 +- .../LightGbmStaticExtensions.cs | 33 +- .../LightGbmArguments.cs | 874 ++++++------------ .../LightGbmBinaryTrainer.cs | 104 ++- src/Microsoft.ML.LightGbm/LightGbmCatalog.cs | 19 +- .../LightGbmMulticlassTrainer.cs | 99 +- .../LightGbmRankingTrainer.cs | 96 +- .../LightGbmRegressionTrainer.cs | 71 +- .../LightGbmTrainerBase.cs | 318 ++++++- .../WrappedLightGbmInterface.cs | 27 + .../Common/EntryPoints/core_ep-list.tsv | 8 +- .../Common/EntryPoints/core_manifest.json | 534 ++++------- ...MReg-CV-generatedRegressionDataset-out.txt | 4 +- ...ainTest-generatedRegressionDataset-out.txt | 2 +- ...-CV-generatedRegressionDataset.MAE-out.txt | 6 +- ...e-CV-generatedRegressionDataset.MAE-rp.txt | 4 +- ...est-generatedRegressionDataset.MAE-out.txt | 4 +- ...Test-generatedRegressionDataset.MAE-rp.txt | 4 +- ...CV-generatedRegressionDataset.RMSE-out.txt | 2 +- ...-CV-generatedRegressionDataset.RMSE-rp.txt | 4 +- ...st-generatedRegressionDataset.RMSE-out.txt | 2 +- ...est-generatedRegressionDataset.RMSE-rp.txt | 4 +- .../TestPredictors.cs | 2 +- test/Microsoft.ML.TestFramework/Learners.cs | 4 +- .../TensorflowTests.cs | 2 +- .../TrainerEstimators/TreeEstimators.cs | 23 +- 29 files changed, 1159 insertions(+), 1111 deletions(-) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/LightGbmWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/LightGbmWithOptions.cs index 815e4694b4..41f85c327c 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/LightGbmWithOptions.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/LightGbmWithOptions.cs @@ -1,5 +1,4 @@ -using Microsoft.ML.Trainers.LightGbm; -using static Microsoft.ML.Trainers.LightGbm.Options; +using 
Microsoft.ML.Trainers.LightGbm; namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification { @@ -19,7 +18,7 @@ public static void Example() // Create the pipeline with LightGbm Estimator using advanced options. var pipeline = mlContext.BinaryClassification.Trainers.LightGbm( - new Options + new LightGbmBinaryTrainer.Options { Booster = new GossBooster.Options { diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbmWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbmWithOptions.cs index 5146bf66d8..2b24423e5e 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbmWithOptions.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbmWithOptions.cs @@ -3,7 +3,6 @@ using Microsoft.ML.Data; using Microsoft.ML.SamplesUtils; using Microsoft.ML.Trainers.LightGbm; -using static Microsoft.ML.Trainers.LightGbm.Options; namespace Microsoft.ML.Samples.Dynamic.Trainers.MulticlassClassification { @@ -33,11 +32,11 @@ public static void Example() // - Convert the string labels into key types. // - Apply LightGbm multiclass trainer with advanced options. 
var pipeline = mlContext.Transforms.Conversion.MapValueToKey("LabelIndex", "Label") - .Append(mlContext.MulticlassClassification.Trainers.LightGbm(new Options + .Append(mlContext.MulticlassClassification.Trainers.LightGbm(new LightGbmMulticlassTrainer.Options { LabelColumnName = "LabelIndex", FeatureColumnName = "Features", - Booster = new DartBooster.Options + Booster = new DartBooster.Options() { TreeDropFraction = 0.15, XgboostDartMode = false diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Ranking/LightGbmWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Ranking/LightGbmWithOptions.cs index bdc966fc92..1cb039bb18 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Ranking/LightGbmWithOptions.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Ranking/LightGbmWithOptions.cs @@ -1,5 +1,4 @@ using Microsoft.ML.Trainers.LightGbm; -using static Microsoft.ML.Trainers.LightGbm.Options; namespace Microsoft.ML.Samples.Dynamic.Trainers.Ranking { @@ -21,13 +20,13 @@ public static void Example() // Create the Estimator pipeline. For simplicity, we will train a small tree with 4 leaves and 2 boosting iterations. 
var pipeline = mlContext.Ranking.Trainers.LightGbm( - new Options + new LightGbmRankingTrainer.Options { NumberOfLeaves = 4, MinimumExampleCountPerGroup = 10, LearningRate = 0.1, NumberOfIterations = 2, - Booster = new TreeBooster.Options + Booster = new GradientBooster.Options { FeatureFraction = 0.9 } diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/LightGbmWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/LightGbmWithOptions.cs index 4d1eb39a6a..f6d3eeb1f9 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/LightGbmWithOptions.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/LightGbmWithOptions.cs @@ -2,7 +2,6 @@ using System.Linq; using Microsoft.ML.Data; using Microsoft.ML.Trainers.LightGbm; -using static Microsoft.ML.Trainers.LightGbm.Options; namespace Microsoft.ML.Samples.Dynamic.Trainers.Regression { @@ -36,13 +35,13 @@ public static void Example() .Where(name => name != labelName) // Drop the Label .ToArray(); var pipeline = mlContext.Transforms.Concatenate("Features", featureNames) - .Append(mlContext.Regression.Trainers.LightGbm(new Options + .Append(mlContext.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options { LabelColumnName = labelName, NumberOfLeaves = 4, MinimumExampleCountPerLeaf = 6, LearningRate = 0.001, - Booster = new GossBooster.Options + Booster = new GossBooster.Options() { TopRate = 0.3, OtherRate = 0.2 diff --git a/src/Microsoft.ML.LightGbm.StaticPipe/LightGbmStaticExtensions.cs b/src/Microsoft.ML.LightGbm.StaticPipe/LightGbmStaticExtensions.cs index 129ab2f593..0fb8bf312b 100644 --- a/src/Microsoft.ML.LightGbm.StaticPipe/LightGbmStaticExtensions.cs +++ b/src/Microsoft.ML.LightGbm.StaticPipe/LightGbmStaticExtensions.cs @@ -42,7 +42,7 @@ public static Scalar LightGbm(this RegressionCatalog.RegressionTrainers c int? numberOfLeaves = null, int? minimumExampleCountPerLeaf = null, double? 
learningRate = null, - int numberOfIterations = Options.Defaults.NumberOfIterations, + int numberOfIterations = Defaults.NumberOfIterations, Action onFit = null) { CheckUserValues(label, features, weights, numberOfLeaves, minimumExampleCountPerLeaf, learningRate, numberOfIterations, onFit); @@ -76,10 +76,11 @@ public static Scalar LightGbm(this RegressionCatalog.RegressionTrainers c /// The Score output column indicating the predicted value. public static Scalar LightGbm(this RegressionCatalog.RegressionTrainers catalog, Scalar label, Vector features, Scalar weights, - Options options, + LightGbmRegressionTrainer.Options options, Action onFit = null) { - CheckUserValues(label, features, weights, options, onFit); + Contracts.CheckValue(options, nameof(options)); + CheckUserValues(label, features, weights, onFit); var rec = new TrainerEstimatorReconciler.Regression( (env, labelName, featuresName, weightsName) => @@ -128,7 +129,7 @@ public static (Scalar score, Scalar probability, Scalar pred int? numberOfLeaves = null, int? minimumExampleCountPerLeaf = null, double? learningRate = null, - int numberOfIterations = Options.Defaults.NumberOfIterations, + int numberOfIterations = Defaults.NumberOfIterations, Action> onFit = null) { CheckUserValues(label, features, weights, numberOfLeaves, minimumExampleCountPerLeaf, learningRate, numberOfIterations, onFit); @@ -165,10 +166,11 @@ public static (Scalar score, Scalar probability, Scalar pred /// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label. 
public static (Scalar score, Scalar probability, Scalar predictedLabel) LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, Scalar label, Vector features, Scalar weights, - Options options, + LightGbmBinaryTrainer.Options options, Action> onFit = null) { - CheckUserValues(label, features, weights, options, onFit); + Contracts.CheckValue(options, nameof(options)); + CheckUserValues(label, features, weights, onFit); var rec = new TrainerEstimatorReconciler.BinaryClassifier( (env, labelName, featuresName, weightsName) => @@ -215,7 +217,7 @@ public static Scalar LightGbm(this RankingCatalog.RankingTrainers c int? numberOfLeaves = null, int? minimumExampleCountPerLeaf = null, double? learningRate = null, - int numberOfIterations = Options.Defaults.NumberOfIterations, + int numberOfIterations = Defaults.NumberOfIterations, Action onFit = null) { CheckUserValues(label, features, weights, numberOfLeaves, minimumExampleCountPerLeaf, learningRate, numberOfIterations, onFit); @@ -253,10 +255,11 @@ public static Scalar LightGbm(this RankingCatalog.RankingTrainers c /// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label. public static Scalar LightGbm(this RankingCatalog.RankingTrainers catalog, Scalar label, Vector features, Key groupId, Scalar weights, - Options options, + LightGbmRankingTrainer.Options options, Action onFit = null) { - CheckUserValues(label, features, weights, options, onFit); + Contracts.CheckValue(options, nameof(options)); + CheckUserValues(label, features, weights, onFit); Contracts.CheckValue(groupId, nameof(groupId)); var rec = new TrainerEstimatorReconciler.Ranker( @@ -309,7 +312,7 @@ public static (Vector score, Key predictedLabel) int? numberOfLeaves = null, int? minimumExampleCountPerLeaf = null, double? 
learningRate = null, - int numberOfIterations = Options.Defaults.NumberOfIterations, + int numberOfIterations = Defaults.NumberOfIterations, Action onFit = null) { CheckUserValues(label, features, weights, numberOfLeaves, minimumExampleCountPerLeaf, learningRate, numberOfIterations, onFit); @@ -347,10 +350,11 @@ public static (Vector score, Key predictedLabel) Key label, Vector features, Scalar weights, - Options options, + LightGbmMulticlassTrainer.Options options, Action onFit = null) { - CheckUserValues(label, features, weights, options, onFit); + Contracts.CheckValue(options, nameof(options)); + CheckUserValues(label, features, weights, onFit); var rec = new TrainerEstimatorReconciler.MulticlassClassificationReconciler( (env, labelName, featuresName, weightsName) => @@ -386,14 +390,11 @@ private static void CheckUserValues(PipelineColumn label, Vector features Contracts.CheckValueOrNull(onFit); } - private static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, - Options options, - Delegate onFit) + private static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, Delegate onFit) { Contracts.CheckValue(label, nameof(label)); Contracts.CheckValue(features, nameof(features)); Contracts.CheckValueOrNull(weights); - Contracts.CheckValue(options, nameof(options)); Contracts.CheckValueOrNull(onFit); } } diff --git a/src/Microsoft.ML.LightGbm/LightGbmArguments.cs b/src/Microsoft.ML.LightGbm/LightGbmArguments.cs index cb8a18b384..0ac5a63572 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmArguments.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmArguments.cs @@ -1,10 +1,5 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. 
- -using System.Collections.Generic; +using System.Collections.Generic; using System.Reflection; -using System.Text; using Microsoft.ML; using Microsoft.ML.CommandLine; using Microsoft.ML.EntryPoints; @@ -12,667 +7,344 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Trainers.LightGbm; -[assembly: LoadableClass(typeof(Options.TreeBooster), typeof(Options.TreeBooster.Options), - typeof(SignatureLightGBMBooster), Options.TreeBooster.FriendlyName, Options.TreeBooster.Name)] -[assembly: LoadableClass(typeof(Options.DartBooster), typeof(Options.DartBooster.Options), - typeof(SignatureLightGBMBooster), Options.DartBooster.FriendlyName, Options.DartBooster.Name)] -[assembly: LoadableClass(typeof(Options.GossBooster), typeof(Options.GossBooster.Options), - typeof(SignatureLightGBMBooster), Options.GossBooster.FriendlyName, Options.GossBooster.Name)] +[assembly: LoadableClass(typeof(GradientBooster), typeof(GradientBooster.Options), + typeof(SignatureLightGBMBooster), GradientBooster.FriendlyName, GradientBooster.Name)] +[assembly: LoadableClass(typeof(DartBooster), typeof(DartBooster.Options), + typeof(SignatureLightGBMBooster), DartBooster.FriendlyName, DartBooster.Name)] +[assembly: LoadableClass(typeof(GossBooster), typeof(GossBooster.Options), + typeof(SignatureLightGBMBooster), GossBooster.FriendlyName, GossBooster.Name)] -[assembly: EntryPointModule(typeof(Options.TreeBooster.Options))] -[assembly: EntryPointModule(typeof(Options.DartBooster.Options))] -[assembly: EntryPointModule(typeof(Options.GossBooster.Options))] +[assembly: EntryPointModule(typeof(GradientBooster.Options))] +[assembly: EntryPointModule(typeof(DartBooster.Options))] +[assembly: EntryPointModule(typeof(GossBooster.Options))] namespace Microsoft.ML.Trainers.LightGbm { internal delegate void SignatureLightGBMBooster(); [TlcModule.ComponentKind("BoosterParameterFunction")] - public interface ISupportBoosterParameterFactory : IComponentFactory + internal interface IBoosterParameterFactory : 
IComponentFactory { + new BoosterParameterBase CreateComponent(IHostEnvironment env); } - public interface IBoosterParameter + public abstract class BoosterParameterBase { - void UpdateParameters(Dictionary res); - } - - /// - /// Options for LightGBM trainer. - /// - /// - /// LightGBM is an external library that's integrated with ML.NET. For detailed information about the parameters - /// please see https://github.com/Microsoft/LightGBM/blob/master/docs/Parameters.rst. - /// - public sealed class Options : TrainerInputBaseWithGroupId - { - public abstract class BoosterParameter : IBoosterParameter - where TOptions : class, new() + private protected static Dictionary NameMapping = new Dictionary() { - private protected TOptions BoosterParameterOptions { get; } - - private protected BoosterParameter(TOptions options) - { - BoosterParameterOptions = options; - } - - /// - /// Update the parameters by specific Booster, will update parameters into "res" directly. - /// - internal virtual void UpdateParameters(Dictionary res) - { - FieldInfo[] fields = BoosterParameterOptions.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); - foreach (var field in fields) - { - var attribute = field.GetCustomAttribute(false); - - if (attribute == null) - continue; - - res[GetOptionName(field.Name)] = field.GetValue(BoosterParameterOptions); - } - } - - void IBoosterParameter.UpdateParameters(Dictionary res) => UpdateParameters(res); - } - - private static string GetOptionName(string name) - { - if (_nameMapping.ContainsKey(name)) - return _nameMapping[name]; - - // Otherwise convert the name to the light gbm argument - StringBuilder strBuf = new StringBuilder(); - bool first = true; - foreach (char c in name) - { - if (char.IsUpper(c)) - { - if (first) - first = false; - else - strBuf.Append('_'); - strBuf.Append(char.ToLower(c)); - } - else - strBuf.Append(c); - } - return strBuf.ToString(); - } - - // Static override name map that maps 
friendly names to lightGBM arguments. - // If an argument is not here, then its name is identical to a lightGBM argument - // and does not require a mapping, for example, Subsample. - private static Dictionary _nameMapping = new Dictionary() - { - {nameof(TreeBooster.Options.MinimumSplitGain), "min_split_gain" }, - {nameof(TreeBooster.Options.MaximumTreeDepth), "max_depth"}, - {nameof(TreeBooster.Options.MinimumChildWeight), "min_child_weight"}, - {nameof(TreeBooster.Options.SubsampleFraction), "subsample"}, - {nameof(TreeBooster.Options.SubsampleFrequency), "subsample_freq"}, - {nameof(TreeBooster.Options.L1Regularization), "reg_alpha"}, - {nameof(TreeBooster.Options.L2Regularization), "reg_lambda"}, - {nameof(TreeBooster.Options.WeightOfPositiveExamples), "scale_pos_weight"}, - {nameof(DartBooster.Options.TreeDropFraction), "drop_rate" }, - {nameof(DartBooster.Options.MaximumNumberOfDroppedTreesPerRound),"max_drop" }, - {nameof(DartBooster.Options.SkipDropFraction), "skip_drop" }, - {nameof(MinimumExampleCountPerLeaf), "min_data_per_leaf"}, - {nameof(NumberOfLeaves), "num_leaves"}, - {nameof(MaximumBinCountPerFeature), "max_bin" }, - {nameof(CustomGains), "label_gain" }, - {nameof(MinimumExampleCountPerGroup), "min_data_per_group" }, - {nameof(MaximumCategoricalSplitPointCount), "max_cat_threshold" }, - {nameof(CategoricalSmoothing), "cat_smooth" }, - {nameof(L2CategoricalRegularization), "cat_l2" } + {nameof(OptionsBase.MinimumSplitGain), "min_split_gain" }, + {nameof(OptionsBase.MaximumTreeDepth), "max_depth"}, + {nameof(OptionsBase.MinimumChildWeight), "min_child_weight"}, + {nameof(OptionsBase.SubsampleFraction), "subsample"}, + {nameof(OptionsBase.SubsampleFrequency), "subsample_freq"}, + {nameof(OptionsBase.L1Regularization), "reg_alpha"}, + {nameof(OptionsBase.L2Regularization), "reg_lambda"}, }; - - [BestFriend] - internal static class Defaults + public BoosterParameterBase(OptionsBase options) { - public const int NumberOfIterations = 100; + 
Contracts.CheckUserArg(options.MinimumSplitGain >= 0, nameof(OptionsBase.MinimumSplitGain), "must be >= 0."); + Contracts.CheckUserArg(options.MinimumChildWeight >= 0, nameof(OptionsBase.MinimumChildWeight), "must be >= 0."); + Contracts.CheckUserArg(options.SubsampleFraction > 0 && options.SubsampleFraction <= 1, nameof(OptionsBase.SubsampleFraction), "must be in (0,1]."); + Contracts.CheckUserArg(options.FeatureFraction > 0 && options.FeatureFraction <= 1, nameof(OptionsBase.FeatureFraction), "must be in (0,1]."); + Contracts.CheckUserArg(options.L2Regularization >= 0, nameof(OptionsBase.L2Regularization), "must be >= 0."); + Contracts.CheckUserArg(options.L1Regularization >= 0, nameof(OptionsBase.L1Regularization), "must be >= 0."); + BoosterOptions = options; } - /// - /// Gradient boosting decision tree. - /// - /// - /// For details, please see gradient tree boosting. - /// - public sealed class TreeBooster : BoosterParameter + public abstract class OptionsBase : IBoosterParameterFactory { - internal const string Name = "gbdt"; - internal const string FriendlyName = "Tree Booster"; + internal BoosterParameterBase GetBooster() { return null; } /// - /// The options for , used for setting . + /// The minimum loss reduction required to make a further partition on a leaf node of the tree. /// - [TlcModule.Component(Name = Name, FriendlyName = FriendlyName, Desc = "Traditional Gradient Boosting Decision Tree.")] - public class Options : ISupportBoosterParameterFactory - { - /// - /// Whether training data is unbalanced. Used by . - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Use for binary classification when training data is not balanced.", ShortName = "us")] - public bool UnbalancedSets = false; - - /// - /// The minimum loss reduction required to make a further partition on a leaf node of the tree. - /// - /// - /// Larger values make the algorithm more conservative. 
- /// - [Argument(ArgumentType.AtMostOnce, - HelpText = "Minimum loss reduction required to make a further partition on a leaf node of the tree. the larger, " + - "the more conservative the algorithm will be.")] - [TlcModule.Range(Min = 0.0)] - public double MinimumSplitGain = 0; - - /// - /// The maximum depth of a tree. - /// - /// - /// 0 means no limit. - /// - [Argument(ArgumentType.AtMostOnce, - HelpText = "Maximum depth of a tree. 0 means no limit. However, tree still grows by best-first.")] - [TlcModule.Range(Min = 0, Max = int.MaxValue)] - public int MaximumTreeDepth = 0; - - /// - /// The minimum sum of instance weight needed to form a new node. - /// - /// - /// If the tree partition step results in a leaf node with the sum of instance weight less than , - /// the building process will give up further partitioning. In linear regression mode, this simply corresponds to minimum number - /// of instances needed to be in each node. The larger, the more conservative the algorithm will be. - /// - [Argument(ArgumentType.AtMostOnce, - HelpText = "Minimum sum of instance weight(hessian) needed in a child. If the tree partition step results in a leaf " + - "node with the sum of instance weight less than min_child_weight, then the building process will give up further partitioning. In linear regression mode, " + - "this simply corresponds to minimum number of instances needed to be in each node. The larger, the more conservative the algorithm will be.")] - [TlcModule.Range(Min = 0.0)] - public double MinimumChildWeight = 0.1; - - /// - /// The frequency of performing subsampling (bagging). - /// - /// - /// 0 means disable bagging. N means perform bagging at every N iterations. - /// To enable bagging, should also be set to a value less than 1.0. - /// - [Argument(ArgumentType.AtMostOnce, - HelpText = "Subsample frequency for bagging. 0 means no subsample. 
" - + "Specifies the frequency at which the bagging occurs, where if this is set to N, the subsampling will happen at every N iterations." + - "This must be set with Subsample as this specifies the amount to subsample.")] - [TlcModule.Range(Min = 0, Max = int.MaxValue)] - public int SubsampleFrequency = 0; - - /// - /// The fraction of training data used for creating trees. - /// - /// - /// Setting it to 0.5 means that LightGBM randomly picks half of the data points to grow trees. - /// This can be used to speed up training and to reduce over-fitting. Valid range is (0,1]. - /// - [Argument(ArgumentType.AtMostOnce, - HelpText = "Subsample ratio of the training instance. Setting it to 0.5 means that LightGBM randomly collected " + - "half of the data instances to grow trees and this will prevent overfitting. Range: (0,1].")] - [TlcModule.Range(Inf = 0.0, Max = 1.0)] - public double SubsampleFraction = 1; - - /// - /// The fraction of features used when creating trees. - /// - /// - /// If is smaller than 1.0, LightGBM will randomly select fraction of features to train each tree. - /// For example, if you set it to 0.8, LightGBM will select 80% of features before training each tree. - /// This can be used to speed up training and to reduce over-fitting. Valid range is (0,1]. - /// - [Argument(ArgumentType.AtMostOnce, - HelpText = "Subsample ratio of columns when constructing each tree. Range: (0,1].", - ShortName = "ff")] - [TlcModule.Range(Inf = 0.0, Max = 1.0)] - public double FeatureFraction = 1; - - /// - /// The L2 regularization term on weights. - /// - /// - /// Increasing this value could help reduce over-fitting. 
- /// - [Argument(ArgumentType.AtMostOnce, - HelpText = "L2 regularization term on weights, increasing this value will make model more conservative.", - ShortName = "l2")] - [TlcModule.Range(Min = 0.0)] - [TGUI(Label = "Lambda(L2)", SuggestedSweeps = "0,0.5,1")] - [TlcModule.SweepableDiscreteParam("RegLambda", new object[] { 0f, 0.5f, 1f })] - public double L2Regularization = 0.01; - - /// - /// The L1 regularization term on weights. - /// - /// - /// Increasing this value could help reduce over-fitting. - /// - [Argument(ArgumentType.AtMostOnce, - HelpText = "L1 regularization term on weights, increase this value will make model more conservative.", - ShortName = "l1")] - [TlcModule.Range(Min = 0.0)] - [TGUI(Label = "Alpha(L1)", SuggestedSweeps = "0,0.5,1")] - [TlcModule.SweepableDiscreteParam("RegAlpha", new object[] { 0f, 0.5f, 1f })] - public double L1Regularization = 0; - - /// - /// Controls the balance of positive and negative weights in . - /// - /// - /// This is useful for training on unbalanced data. A typical value to consider is sum(negative cases) / sum(positive cases). - /// - [Argument(ArgumentType.AtMostOnce, - HelpText = "Control the balance of positive and negative weights, useful for unbalanced classes." + - " A typical value to consider: sum(negative cases) / sum(positive cases).", - ShortName = "ScalePosWeight")] - public double WeightOfPositiveExamples = 1; - - internal virtual IBoosterParameter CreateComponent(IHostEnvironment env) => new TreeBooster(this); - - IBoosterParameter IComponentFactory.CreateComponent(IHostEnvironment env) => CreateComponent(env); - } + /// + /// Larger values make the algorithm more conservative. + /// + [Argument(ArgumentType.AtMostOnce, + HelpText = "Minimum loss reduction required to make a further partition on a leaf node of the tree. 
the larger, " + + "the more conservative the algorithm will be.")] + [TlcModule.Range(Min = 0.0)] + public double MinimumSplitGain = 0; - internal TreeBooster(Options options) - : base(options) - { - Contracts.CheckUserArg(BoosterParameterOptions.MinimumSplitGain >= 0, nameof(BoosterParameterOptions.MinimumSplitGain), "must be >= 0."); - Contracts.CheckUserArg(BoosterParameterOptions.MinimumChildWeight >= 0, nameof(BoosterParameterOptions.MinimumChildWeight), "must be >= 0."); - Contracts.CheckUserArg(BoosterParameterOptions.SubsampleFraction > 0 && BoosterParameterOptions.SubsampleFraction <= 1, nameof(BoosterParameterOptions.SubsampleFraction), "must be in (0,1]."); - Contracts.CheckUserArg(BoosterParameterOptions.FeatureFraction > 0 && BoosterParameterOptions.FeatureFraction <= 1, nameof(BoosterParameterOptions.FeatureFraction), "must be in (0,1]."); - Contracts.CheckUserArg(BoosterParameterOptions.L2Regularization >= 0, nameof(BoosterParameterOptions.L2Regularization), "must be >= 0."); - Contracts.CheckUserArg(BoosterParameterOptions.L1Regularization >= 0, nameof(BoosterParameterOptions.L1Regularization), "must be >= 0."); - Contracts.CheckUserArg(BoosterParameterOptions.WeightOfPositiveExamples > 0, nameof(BoosterParameterOptions.WeightOfPositiveExamples), "must be >= 0."); - } + /// + /// The maximum depth of a tree. + /// + /// + /// 0 means no limit. + /// + [Argument(ArgumentType.AtMostOnce, + HelpText = "Maximum depth of a tree. 0 means no limit. However, tree still grows by best-first.")] + [TlcModule.Range(Min = 0, Max = int.MaxValue)] + public int MaximumTreeDepth = 0; - internal override void UpdateParameters(Dictionary res) - { - base.UpdateParameters(res); - res["boosting_type"] = Name; - } - } + /// + /// The minimum sum of instance weight needed to form a new node. + /// + /// + /// If the tree partition step results in a leaf node with the sum of instance weight less than , + /// the building process will give up further partitioning. 
In linear regression mode, this simply corresponds to minimum number + /// of instances needed to be in each node. The larger, the more conservative the algorithm will be. + /// + [Argument(ArgumentType.AtMostOnce, + HelpText = "Minimum sum of instance weight(hessian) needed in a child. If the tree partition step results in a leaf " + + "node with the sum of instance weight less than min_child_weight, then the building process will give up further partitioning. In linear regression mode, " + + "this simply corresponds to minimum number of instances needed to be in each node. The larger, the more conservative the algorithm will be.")] + [TlcModule.Range(Min = 0.0)] + public double MinimumChildWeight = 0.1; - /// - /// DART booster (Dropouts meet Multiple Additive Regression Trees) - /// - /// - /// For details, please see here. - /// - public sealed class DartBooster : BoosterParameter - { - internal const string Name = "dart"; - internal const string FriendlyName = "Tree Dropout Tree Booster"; + /// + /// The frequency of performing subsampling (bagging). + /// + /// + /// 0 means disable bagging. N means perform bagging at every N iterations. + /// To enable bagging, should also be set to a value less than 1.0. + /// + [Argument(ArgumentType.AtMostOnce, + HelpText = "Subsample frequency for bagging. 0 means no subsample. " + + "Specifies the frequency at which the bagging occurs, where if this is set to N, the subsampling will happen at every N iterations." + + "This must be set with Subsample as this specifies the amount to subsample.")] + [TlcModule.Range(Min = 0, Max = int.MaxValue)] + public int SubsampleFrequency = 0; /// - /// The options for , used for setting . + /// The fraction of training data used for creating trees. /// - [TlcModule.Component(Name = Name, FriendlyName = FriendlyName, Desc = "Dropouts meet Multiple Additive Regression Trees. 
See https://arxiv.org/abs/1505.01866")] - public sealed class Options : TreeBooster.Options - { - /// - /// The dropout rate, i.e. the fraction of previous trees to drop during the dropout. - /// - /// - /// Valid range is [0,1]. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "The drop ratio for trees. Range:[0,1].")] - [TlcModule.Range(Inf = 0.0, Max = 1.0)] - public double TreeDropFraction = 0.1; - - /// - /// The maximum number of dropped trees in a boosting round. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of dropped trees in a boosting round.")] - [TlcModule.Range(Inf = 0, Max = int.MaxValue)] - public int MaximumNumberOfDroppedTreesPerRound = 1; - - /// - /// The probability of skipping the dropout procedure during a boosting iteration. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Probability for not dropping in a boosting round.")] - [TlcModule.Range(Inf = 0.0, Max = 1.0)] - public double SkipDropFraction = 0.5; - - /// - /// Whether to enable xgboost dart mode. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "True will enable xgboost dart mode.")] - public bool XgboostDartMode = false; - - /// - /// Whether to enable uniform drop. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "True will enable uniform drop.")] - public bool UniformDrop = false; - - internal override IBoosterParameter CreateComponent(IHostEnvironment env) => new DartBooster(this); - } + /// + /// Setting it to 0.5 means that LightGBM randomly picks half of the data points to grow trees. + /// This can be used to speed up training and to reduce over-fitting. Valid range is (0,1]. + /// + [Argument(ArgumentType.AtMostOnce, + HelpText = "Subsample ratio of the training instance. Setting it to 0.5 means that LightGBM randomly collected " + + "half of the data instances to grow trees and this will prevent overfitting. 
Range: (0,1].")] + [TlcModule.Range(Inf = 0.0, Max = 1.0)] + public double SubsampleFraction = 1; - internal DartBooster(Options options) - : base(options) - { - Contracts.CheckUserArg(BoosterParameterOptions.TreeDropFraction > 0 && BoosterParameterOptions.TreeDropFraction < 1, nameof(BoosterParameterOptions.TreeDropFraction), "must be in (0,1)."); - Contracts.CheckUserArg(BoosterParameterOptions.SkipDropFraction >= 0 && BoosterParameterOptions.SkipDropFraction < 1, nameof(BoosterParameterOptions.SkipDropFraction), "must be in [0,1)."); - } + /// + /// The fraction of features used when creating trees. + /// + /// + /// If is smaller than 1.0, LightGBM will randomly select fraction of features to train each tree. + /// For example, if you set it to 0.8, LightGBM will select 80% of features before training each tree. + /// This can be used to speed up training and to reduce over-fitting. Valid range is (0,1]. + /// + [Argument(ArgumentType.AtMostOnce, + HelpText = "Subsample ratio of columns when constructing each tree. Range: (0,1].", + ShortName = "ff")] + [TlcModule.Range(Inf = 0.0, Max = 1.0)] + public double FeatureFraction = 1; - internal override void UpdateParameters(Dictionary res) - { - base.UpdateParameters(res); - res["boosting_type"] = Name; - } + /// + /// The L2 regularization term on weights. + /// + /// + /// Increasing this value could help reduce over-fitting. + /// + [Argument(ArgumentType.AtMostOnce, + HelpText = "L2 regularization term on weights, increasing this value will make model more conservative.", + ShortName = "l2")] + [TlcModule.Range(Min = 0.0)] + [TGUI(Label = "Lambda(L2)", SuggestedSweeps = "0,0.5,1")] + [TlcModule.SweepableDiscreteParam("RegLambda", new object[] { 0f, 0.5f, 1f })] + public double L2Regularization = 0.01; + + /// + /// The L1 regularization term on weights. + /// + /// + /// Increasing this value could help reduce over-fitting. 
+ /// + [Argument(ArgumentType.AtMostOnce, + HelpText = "L1 regularization term on weights, increase this value will make model more conservative.", + ShortName = "l1")] + [TlcModule.Range(Min = 0.0)] + [TGUI(Label = "Alpha(L1)", SuggestedSweeps = "0,0.5,1")] + [TlcModule.SweepableDiscreteParam("RegAlpha", new object[] { 0f, 0.5f, 1f })] + public double L1Regularization = 0; + + BoosterParameterBase IComponentFactory.CreateComponent(IHostEnvironment env) + => BuildOptions(); + + BoosterParameterBase IBoosterParameterFactory.CreateComponent(IHostEnvironment env) + => BuildOptions(); + + internal abstract BoosterParameterBase BuildOptions(); } - /// - /// Gradient-based One-Side Sampling booster. - /// - /// - /// For details, please see here. - /// - public sealed class GossBooster : BoosterParameter + internal void UpdateParameters(Dictionary res) { - internal const string Name = "goss"; - internal const string FriendlyName = "Gradient-based One-Size Sampling"; - - [TlcModule.Component(Name = Name, FriendlyName = FriendlyName, Desc = "Gradient-based One-Side Sampling.")] - public sealed class Options : TreeBooster.Options + FieldInfo[] fields = BoosterOptions.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + foreach (var field in fields) { - /// - /// The retain ratio of large gradient data. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Retain ratio for large gradient instances.")] - [TlcModule.Range(Inf = 0.0, Max = 1.0)] - public double TopRate = 0.2; - - /// - /// The retain ratio of small gradient data. 
- /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Retain ratio for small gradient instances.")] - [TlcModule.Range(Inf = 0.0, Max = 1.0)] - public double OtherRate = 0.1; - - internal override IBoosterParameter CreateComponent(IHostEnvironment env) => new GossBooster(this); - } + var attribute = field.GetCustomAttribute(false); - internal GossBooster(Options options) - : base(options) - { - Contracts.CheckUserArg(BoosterParameterOptions.TopRate > 0 && BoosterParameterOptions.TopRate < 1, nameof(BoosterParameterOptions.TopRate), "must be in (0,1)."); - Contracts.CheckUserArg(BoosterParameterOptions.OtherRate >= 0 && BoosterParameterOptions.OtherRate < 1, nameof(BoosterParameterOptions.TopRate), "must be in [0,1)."); - Contracts.Check(BoosterParameterOptions.TopRate + BoosterParameterOptions.OtherRate <= 1, "Sum of topRate and otherRate cannot be larger than 1."); - } + if (attribute == null) + continue; - internal override void UpdateParameters(Dictionary res) - { - base.UpdateParameters(res); - res["boosting_type"] = Name; + var name = NameMapping.ContainsKey(field.Name) ? NameMapping[field.Name] : LightGbmInterfaceUtils.GetOptionName(field.Name); + res[name] = field.GetValue(BoosterOptions); } } /// - /// The evaluation metrics that are available for . - /// - public enum EvalMetricType - { - DefaultMetric, - Rmse, - Mae, - Logloss, - Error, - Merror, - Mlogloss, - Auc, - Ndcg, - Map - }; - - /// - /// The number of boosting iterations. A new tree is created in each iteration, so this is equivalent to the number of trees. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations.", SortOrder = 1, ShortName = "iter")] - [TGUI(Label = "Number of boosting iterations", SuggestedSweeps = "10,20,50,100,150,200")] - [TlcModule.SweepableDiscreteParam("NumBoostRound", new object[] { 10, 20, 50, 100, 150, 200 })] - public int NumberOfIterations = Defaults.NumberOfIterations; - - /// - /// The shrinkage rate for trees, used to prevent over-fitting. 
- /// - /// - /// Valid range is (0,1]. - /// - [Argument(ArgumentType.AtMostOnce, - HelpText = "Shrinkage rate for trees, used to prevent over-fitting. Range: (0,1].", - SortOrder = 2, ShortName = "lr", NullName = "")] - [TGUI(Label = "Learning Rate", SuggestedSweeps = "0.025-0.4;log")] - [TlcModule.SweepableFloatParamAttribute("LearningRate", 0.025f, 0.4f, isLogScale: true)] - public double? LearningRate; - - /// - /// The maximum number of leaves in one tree. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum leaves for trees.", - SortOrder = 2, ShortName = "nl", NullName = "")] - [TGUI(Description = "The maximum number of leaves per tree", SuggestedSweeps = "2-128;log;inc:4")] - [TlcModule.SweepableLongParamAttribute("NumLeaves", 2, 128, isLogScale: true, stepSize: 4)] - public int? NumberOfLeaves; - - /// - /// The minimal number of data points required to form a new tree leaf. + /// Create for supporting legacy infra built upon . /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Minimum number of instances needed in a child.", - SortOrder = 2, ShortName = "mil", NullName = "")] - [TGUI(Label = "Min Documents In Leaves", SuggestedSweeps = "1,10,20,50 ")] - [TlcModule.SweepableDiscreteParamAttribute("MinDataPerLeaf", new object[] { 1, 10, 20, 50 })] - public int? MinimumExampleCountPerLeaf; + internal abstract IBoosterParameterFactory BuildFactory(); + internal abstract string BoosterName { get; } - /// - /// The maximum number of bins that feature values will be bucketed in. - /// - /// - /// The small number of bins may reduce training accuracy but may increase general power (deal with over-fitting). - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of bucket bin for features.", ShortName = "mb")] - public int MaximumBinCountPerFeature = 255; + private protected OptionsBase BoosterOptions; + } - /// - /// Determines which booster to use. - /// - /// - /// Available boosters are , , and . 
- /// - [Argument(ArgumentType.Multiple, HelpText = "Which booster to use, can be gbtree, gblinear or dart. gbtree and dart use tree based model while gblinear uses linear function.", SortOrder = 3)] - public ISupportBoosterParameterFactory Booster = new TreeBooster.Options(); + /// + /// Gradient boosting decision tree. + /// + /// + /// For details, please see gradient tree boosting. + /// + public sealed class GradientBooster : BoosterParameterBase + { + internal const string Name = "gbdt"; + internal const string FriendlyName = "Tree Booster"; /// - /// Determines whether to output progress status during training and evaluation. + /// The options for , used for setting . /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Verbose", ShortName = "v")] - public bool Verbose = false; + [TlcModule.Component(Name = Name, FriendlyName = FriendlyName, Desc = "Traditional Gradient Boosting Decision Tree.")] + public sealed class Options : OptionsBase + { + internal override BoosterParameterBase BuildOptions() => new GradientBooster(this); + } - /// - /// Controls the logging level in LighGBM. - /// - /// - /// means only output Fatal errors. means output Fatal, Warning, and Info level messages. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Printing running messages.")] - public bool Silent = true; + internal GradientBooster(Options options) + : base(options) + { + } - /// - /// Determines the number of threads used to run LightGBM. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Number of parallel threads used to run LightGBM.", ShortName = "nt")] - public int? NumberOfThreads; + internal override IBoosterParameterFactory BuildFactory() => BoosterOptions; - /// - /// Determines what evaluation metric to use. 
- /// - [Argument(ArgumentType.AtMostOnce, - HelpText = "Evaluation metrics.", - ShortName = "em")] - public EvalMetricType EvaluationMetric = EvalMetricType.DefaultMetric; + internal override string BoosterName => Name; + } - /// - /// Whether to use softmax loss. Used only by . - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Use softmax loss for the multi classification.")] - [TlcModule.SweepableDiscreteParam("UseSoftmax", new object[] { true, false })] - public bool? UseSoftmax; + /// + /// DART booster (Dropouts meet Multiple Additive Regression Trees) + /// + /// + /// For details, please see here. + /// + public sealed class DartBooster : BoosterParameterBase + { + internal const string Name = "dart"; + internal const string FriendlyName = "Tree Dropout Tree Booster"; /// - /// Determines the number of rounds, after which training will stop if validation metric doesn't improve. + /// The options for , used for setting . /// - /// - /// 0 means disable early stopping. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Rounds of early stopping, 0 will disable it.", - ShortName = "es")] - public int EarlyStoppingRound = 0; + [TlcModule.Component(Name = Name, FriendlyName = FriendlyName, Desc = "Dropouts meet Multiple Additive Regresion Trees. See https://arxiv.org/abs/1505.01866")] + public sealed class Options : OptionsBase + { + static Options() + { + // Add additional name mappings + NameMapping.Add(nameof(TreeDropFraction), "drop_rate"); + NameMapping.Add(nameof(MaximumNumberOfDroppedTreesPerRound), "max_drop"); + NameMapping.Add(nameof(SkipDropFraction), "skip_drop"); + } - /// - /// Comma-separated list of gains associated with each relevance label. Used only by . 
- /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Comma separated list of gains associated to each relevance label.", ShortName = "gains")] - [TGUI(Label = "Ranking Label Gain")] - public string CustomGains = "0,3,7,15,31,63,127,255,511,1023,2047,4095"; + /// + /// The dropout rate, i.e. the fraction of previous trees to drop during the dropout. + /// + /// + /// Valid range is [0,1]. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "The drop ratio for trees. Range:(0,1).")] + [TlcModule.Range(Inf = 0.0, Max = 1.0)] + public double TreeDropFraction = 0.1; - /// - /// Parameter for the sigmoid function. Used only by , , and . - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Parameter for the sigmoid function. Used only in " + nameof(LightGbmBinaryTrainer) + ", " + nameof(LightGbmMulticlassTrainer) + - " and in " + nameof(LightGbmRankingTrainer) + ".", ShortName = "sigmoid")] - [TGUI(Label = "Sigmoid", SuggestedSweeps = "0.5,1")] - public double Sigmoid = 0.5; + /// + /// The maximum number of dropped trees in a boosting round. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of dropped trees in a boosting round.")] + [TlcModule.Range(Inf = 0, Max = int.MaxValue)] + public int MaximumNumberOfDroppedTreesPerRound = 1; - /// - /// Number of data points per batch, when loading data. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Number of entries in a batch when loading data.", Hide = true)] - public int BatchSize = 1 << 20; + /// + /// The probability of skipping the dropout procedure during a boosting iteration. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Probability for not dropping in a boosting round.")] + [TlcModule.Range(Inf = 0.0, Max = 1.0)] + public double SkipDropFraction = 0.5; - /// - /// Whether to enable categorical split or not. 
- /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Enable categorical split or not.", ShortName = "cat")] - [TlcModule.SweepableDiscreteParam("UseCat", new object[] { true, false })] - public bool? UseCategoricalSplit; + /// + /// Whether to enable xgboost dart mode. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "True will enable xgboost dart mode.")] + public bool XgboostDartMode = false; - /// - /// Whether to enable special handling of missing value or not. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Enable special handling of missing value or not.")] - [TlcModule.SweepableDiscreteParam("UseMissing", new object[] { true, false })] - public bool HandleMissingValue = false; + /// + /// Whether to enable uniform drop. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "True will enable uniform drop.")] + public bool UniformDrop = false; - /// - /// The minimum number of data points per categorical group. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Minimum number of instances per categorical group.", ShortName = "mdpg")] - [TlcModule.Range(Inf = 0, Max = int.MaxValue)] - [TlcModule.SweepableDiscreteParam("MinDataPerGroup", new object[] { 10, 50, 100, 200 })] - public int MinimumExampleCountPerGroup = 100; + internal override BoosterParameterBase BuildOptions() => new DartBooster(this); + } - /// - /// When the number of categories of one feature is smaller than or equal to , - /// one-vs-other split algorithm will be used. 
- /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Max number of categorical thresholds.", ShortName = "maxcat")] - [TlcModule.Range(Inf = 0, Max = int.MaxValue)] - [TlcModule.SweepableDiscreteParam("MaxCatThreshold", new object[] { 8, 16, 32, 64 })] - public int MaximumCategoricalSplitPointCount = 32; + internal DartBooster(Options options) + :base(options) + { + Contracts.CheckUserArg(options.TreeDropFraction > 0 && options.TreeDropFraction < 1, nameof(options.TreeDropFraction), "must be in (0,1)."); + Contracts.CheckUserArg(options.SkipDropFraction >= 0 && options.SkipDropFraction < 1, nameof(options.SkipDropFraction), "must be in [0,1)."); + BoosterOptions = options; + } - /// - /// Laplace smooth term in categorical feature split. - /// This can reduce the effect of noises in categorical features, especially for categories with few data. - /// - /// - /// Constraints: >= 0.0 - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Laplace smooth term in categorical feature split. Avoid the bias of small categories.")] - [TlcModule.Range(Min = 0.0)] - [TlcModule.SweepableDiscreteParam("CatSmooth", new object[] { 1, 10, 20 })] - public double CategoricalSmoothing = 10; + internal override IBoosterParameterFactory BuildFactory() => BoosterOptions; + internal override string BoosterName => Name; + } - /// - /// L2 regularization for categorical split. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization for categorical split.")] - [TlcModule.Range(Min = 0.0)] - [TlcModule.SweepableDiscreteParam("CatL2", new object[] { 0.1, 0.5, 1, 5, 10 })] - public double L2CategoricalRegularization = 10; + /// + /// Gradient-based One-Side Sampling booster. + /// + /// + /// For details, please see here. + /// + public sealed class GossBooster : BoosterParameterBase + { + internal const string Name = "goss"; + internal const string FriendlyName = "Gradient-based One-Size Sampling"; /// - /// The random seed for LightGBM to use. 
+ /// The options for , used for setting . /// - /// - /// If not specified, will generate a random seed to be used. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Sets the random seed for LightGBM to use.")] - public int? Seed; - - [Argument(ArgumentType.Multiple, HelpText = "Parallel LightGBM Learning Algorithm", ShortName = "parag")] - internal ISupportParallel ParallelTrainer = new SingleTrainerFactory(); - - internal Dictionary ToDictionary(IHost host) + [TlcModule.Component(Name = Name, FriendlyName = FriendlyName, Desc = "Gradient-based One-Side Sampling.")] + public sealed class Options : OptionsBase { - Contracts.CheckValue(host, nameof(host)); - Contracts.CheckUserArg(MaximumBinCountPerFeature > 0, nameof(MaximumBinCountPerFeature), "must be > 0."); - Contracts.CheckUserArg(Sigmoid > 0, nameof(Sigmoid), "must be > 0."); - Dictionary res = new Dictionary(); - - var boosterParams = Booster.CreateComponent(host); - boosterParams.UpdateParameters(res); - - res[GetOptionName(nameof(MaximumBinCountPerFeature))] = MaximumBinCountPerFeature; + /// + /// The retain ratio of large gradient data. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Retain ratio for large gradient instances.")] + [TlcModule.Range(Inf = 0.0, Max = 1.0)] + public double TopRate = 0.2; - res["verbose"] = Silent ? "-1" : "1"; - if (NumberOfThreads.HasValue) - res["nthread"] = NumberOfThreads.Value; + /// + /// The retain ratio of small gradient data. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Retain ratio for small gradient instances.")] + [TlcModule.Range(Inf = 0.0, Max = 1.0)] + public double OtherRate = 0.1; - res["seed"] = (Seed.HasValue) ? 
Seed : host.Rand.Next(); + internal override BoosterParameterBase BuildOptions() => new GossBooster(this); + } - string metric = null; - switch (EvaluationMetric) - { - case EvalMetricType.DefaultMetric: - break; - case EvalMetricType.Mae: - metric = "l1"; - break; - case EvalMetricType.Logloss: - metric = "binary_logloss"; - break; - case EvalMetricType.Error: - metric = "binary_error"; - break; - case EvalMetricType.Merror: - metric = "multi_error"; - break; - case EvalMetricType.Mlogloss: - metric = "multi_logloss"; - break; - case EvalMetricType.Rmse: - case EvalMetricType.Auc: - case EvalMetricType.Ndcg: - case EvalMetricType.Map: - metric = EvaluationMetric.ToString().ToLower(); - break; - } - if (!string.IsNullOrEmpty(metric)) - res[GetOptionName(nameof(metric))] = metric; - res[GetOptionName(nameof(Sigmoid))] = Sigmoid; - res[GetOptionName(nameof(CustomGains))] = CustomGains; - res[GetOptionName(nameof(HandleMissingValue))] = HandleMissingValue; - res[GetOptionName(nameof(MinimumExampleCountPerGroup))] = MinimumExampleCountPerGroup; - res[GetOptionName(nameof(MaximumCategoricalSplitPointCount))] = MaximumCategoricalSplitPointCount; - res[GetOptionName(nameof(CategoricalSmoothing))] = CategoricalSmoothing; - res[GetOptionName(nameof(L2CategoricalRegularization))] = L2CategoricalRegularization; - return res; + internal GossBooster(Options options) + :base(options) + { + Contracts.CheckUserArg(options.TopRate > 0 && options.TopRate < 1, nameof(Options.TopRate), "must be in (0,1)."); + Contracts.CheckUserArg(options.OtherRate >= 0 && options.OtherRate < 1, nameof(Options.OtherRate), "must be in [0,1)."); + Contracts.Check(options.TopRate + options.OtherRate <= 1, "Sum of topRate and otherRate cannot be larger than 1."); + BoosterOptions = options; } + + internal override IBoosterParameterFactory BuildFactory() => BoosterOptions; + internal override string BoosterName => Name; } -} +} \ No newline at end of file diff --git 
a/src/Microsoft.ML.LightGbm/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmBinaryTrainer.cs index 73f47c097d..5796df75de 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmBinaryTrainer.cs @@ -2,15 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Collections.Generic; using Microsoft.ML; using Microsoft.ML.Calibrators; +using Microsoft.ML.CommandLine; using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Runtime; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Trainers.LightGbm; -[assembly: LoadableClass(LightGbmBinaryTrainer.Summary, typeof(LightGbmBinaryTrainer), typeof(Options), +[assembly: LoadableClass(LightGbmBinaryTrainer.Summary, typeof(LightGbmBinaryTrainer), typeof(LightGbmBinaryTrainer.Options), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer) }, LightGbmBinaryTrainer.UserName, LightGbmBinaryTrainer.LoadNameValue, LightGbmBinaryTrainer.ShortName, DocName = "trainer/LightGBM.md")] @@ -82,7 +85,7 @@ private static IPredictorProducing Create(IHostEnvironment env, ModelLoad /// The for training a boosted decision tree binary classification model using LightGBM. /// /// - public sealed class LightGbmBinaryTrainer : LightGbmTrainerBase>, CalibratedModelParametersBase> { @@ -93,9 +96,79 @@ public sealed class LightGbmBinaryTrainer : LightGbmTrainerBase PredictionKind.BinaryClassification; + public sealed class Options : OptionsBase + { + + public enum EvaluateMetricType + { + None, + Default, + Logloss, + Error, + AreaUnderCurve, + }; + + /// + /// Whether training data is unbalanced. 
+ /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Use for binary classification when training data is not balanced.", ShortName = "us")] + public bool UnbalancedSets = false; + + /// + /// Controls the balance of positive and negative weights in . + /// + /// + /// This is useful for training on unbalanced data. A typical value to consider is sum(negative cases) / sum(positive cases). + /// + [Argument(ArgumentType.AtMostOnce, + HelpText = "Control the balance of positive and negative weights, useful for unbalanced classes." + + " A typical value to consider: sum(negative cases) / sum(positive cases).", + ShortName = "ScalePosWeight")] + public double WeightOfPositiveExamples = 1; + + /// + /// Parameter for the sigmoid function. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Parameter for the sigmoid function.", ShortName = "sigmoid")] + [TGUI(Label = "Sigmoid", SuggestedSweeps = "0.5,1")] + public double Sigmoid = 0.5; + + /// + /// Determines what evaluation metric to use. 
+ /// + [Argument(ArgumentType.AtMostOnce, + HelpText = "Evaluation metrics.", + ShortName = "em")] + public EvaluateMetricType EvaluationMetric = EvaluateMetricType.Logloss; + + static Options() + { + NameMapping.Add(nameof(EvaluateMetricType), "metric"); + NameMapping.Add(nameof(EvaluateMetricType.None), ""); + NameMapping.Add(nameof(EvaluateMetricType.Logloss), "binary_logloss"); + NameMapping.Add(nameof(EvaluateMetricType.Error), "binary_error"); + NameMapping.Add(nameof(EvaluateMetricType.AreaUnderCurve), "auc"); + NameMapping.Add(nameof(WeightOfPositiveExamples), "scale_pos_weight"); + } + + internal override Dictionary ToDictionary(IHost host) + { + var res = base.ToDictionary(host); + res[GetOptionName(nameof(UnbalancedSets))] = UnbalancedSets; + res[GetOptionName(nameof(WeightOfPositiveExamples))] = WeightOfPositiveExamples; + res[GetOptionName(nameof(Sigmoid))] = Sigmoid; + if (EvaluationMetric != EvaluateMetricType.Default) + res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString()); + + return res; + } + } + internal LightGbmBinaryTrainer(IHostEnvironment env, Options options) : base(env, LoadNameValue, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumnName)) { + Contracts.CheckUserArg(options.Sigmoid > 0, nameof(Options.Sigmoid), "must be > 0."); + Contracts.CheckUserArg(options.WeightOfPositiveExamples > 0, nameof(Options.WeightOfPositiveExamples), "must be > 0."); } /// @@ -116,15 +189,25 @@ internal LightGbmBinaryTrainer(IHostEnvironment env, int? numberOfLeaves = null, int? minimumExampleCountPerLeaf = null, double? 
learningRate = null, - int numberOfIterations = Trainers.LightGbm.Options.Defaults.NumberOfIterations) - : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumnName), featureColumnName, exampleWeightColumnName, null, numberOfLeaves, minimumExampleCountPerLeaf, learningRate, numberOfIterations) + int numberOfIterations = Defaults.NumberOfIterations) + : this(env, + new Options() + { + LabelColumnName = labelColumnName, + FeatureColumnName = featureColumnName, + ExampleWeightColumnName = exampleWeightColumnName, + NumberOfLeaves = numberOfLeaves, + MinimumExampleCountPerLeaf = minimumExampleCountPerLeaf, + LearningRate = learningRate, + NumberOfIterations = numberOfIterations + }) { } private protected override CalibratedModelParametersBase CreatePredictor() { Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete"); - var innerArgs = LightGbmInterfaceUtils.JoinParameters(Options); + var innerArgs = LightGbmInterfaceUtils.JoinParameters(base.GbmOptions); var pred = new LightGbmBinaryModelParameters(Host, TrainedEnsemble, FeatureCount, innerArgs); var cali = new PlattCalibrator(Host, -0.5, 0); return new FeatureWeightsCalibratedModelParameters(Host, pred, cali); @@ -143,12 +226,7 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data) } private protected override void CheckAndUpdateParametersBeforeTraining(IChannel ch, RoleMappedData data, float[] labels, int[] groups) - { - Options["objective"] = "binary"; - // Add default metric. 
- if (!Options.ContainsKey("metric")) - Options["metric"] = "binary_logloss"; - } + => GbmOptions["objective"] = "binary"; private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) { @@ -182,14 +260,14 @@ internal static partial class LightGbm Desc = LightGbmBinaryTrainer.Summary, UserName = LightGbmBinaryTrainer.UserName, ShortName = LightGbmBinaryTrainer.ShortName)] - public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Options input) + public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, LightGbmBinaryTrainer.Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("TrainLightGBM"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return TrainerEntryPointsUtils.Train(host, input, + return TrainerEntryPointsUtils.Train(host, input, () => new LightGbmBinaryTrainer(host, input), getLabel: () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName), getWeight: () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName)); diff --git a/src/Microsoft.ML.LightGbm/LightGbmCatalog.cs b/src/Microsoft.ML.LightGbm/LightGbmCatalog.cs index e740ee3836..40d5532cfb 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmCatalog.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmCatalog.cs @@ -38,7 +38,7 @@ public static LightGbmRegressionTrainer LightGbm(this RegressionCatalog.Regressi int? numberOfLeaves = null, int? minimumExampleCountPerLeaf = null, double? 
learningRate = null, - int numberOfIterations = Options.Defaults.NumberOfIterations) + int numberOfIterations = Defaults.NumberOfIterations) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); @@ -58,7 +58,7 @@ public static LightGbmRegressionTrainer LightGbm(this RegressionCatalog.Regressi /// /// public static LightGbmRegressionTrainer LightGbm(this RegressionCatalog.RegressionTrainers catalog, - Options options) + LightGbmRegressionTrainer.Options options) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); @@ -90,7 +90,7 @@ public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.Bi int? numberOfLeaves = null, int? minimumExampleCountPerLeaf = null, double? learningRate = null, - int numberOfIterations = Options.Defaults.NumberOfIterations) + int numberOfIterations = Defaults.NumberOfIterations) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); @@ -110,7 +110,7 @@ public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.Bi /// /// public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, - Options options) + LightGbmBinaryTrainer.Options options) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); @@ -144,11 +144,12 @@ public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainer int? numberOfLeaves = null, int? minimumExampleCountPerLeaf = null, double? 
learningRate = null, - int numberOfIterations = Options.Defaults.NumberOfIterations) + int numberOfIterations = Defaults.NumberOfIterations) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); - return new LightGbmRankingTrainer(env, labelColumnName, featureColumnName, rowGroupColumnName, exampleWeightColumnName, numberOfLeaves, minimumExampleCountPerLeaf, learningRate, numberOfIterations); + return new LightGbmRankingTrainer(env, labelColumnName, featureColumnName, rowGroupColumnName, exampleWeightColumnName, + numberOfLeaves, minimumExampleCountPerLeaf, learningRate, numberOfIterations); } /// @@ -164,7 +165,7 @@ public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainer /// /// public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainers catalog, - Options options) + LightGbmRankingTrainer.Options options) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); @@ -196,7 +197,7 @@ public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCa int? numberOfLeaves = null, int? minimumExampleCountPerLeaf = null, double? 
learningRate = null, - int numberOfIterations = Options.Defaults.NumberOfIterations) + int numberOfIterations = Defaults.NumberOfIterations) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); @@ -216,7 +217,7 @@ public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCa /// /// public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, - Options options) + LightGbmMulticlassTrainer.Options options) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); diff --git a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs index d4126d65d9..386a6112d7 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs @@ -3,16 +3,19 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; using System.Linq; using Microsoft.ML; using Microsoft.ML.Calibrators; +using Microsoft.ML.CommandLine; using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Runtime; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Trainers.LightGbm; -[assembly: LoadableClass(LightGbmMulticlassTrainer.Summary, typeof(LightGbmMulticlassTrainer), typeof(Options), +[assembly: LoadableClass(LightGbmMulticlassTrainer.Summary, typeof(LightGbmMulticlassTrainer), typeof(LightGbmMulticlassTrainer.Options), new[] { typeof(SignatureMulticlassClassifierTrainer), typeof(SignatureTrainer) }, "LightGBM Multi-class Classifier", LightGbmMulticlassTrainer.LoadNameValue, LightGbmMulticlassTrainer.ShortName, DocName = "trainer/LightGBM.md")] @@ -22,7 +25,10 @@ namespace Microsoft.ML.Trainers.LightGbm /// The for training a boosted decision tree multi-class classification model using LightGBM. 
/// /// - public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase, MulticlassPredictionTransformer, OneVersusAllModelParameters> + public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase, + MulticlassPredictionTransformer, + OneVersusAllModelParameters> { internal const string Summary = "LightGBM Multi Class Classifier"; internal const string LoadNameValue = "LightGBMMulticlass"; @@ -34,9 +40,61 @@ public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase PredictionKind.MulticlassClassification; + public sealed class Options : OptionsBase + { + public enum EvaluateMetricType + { + None, + Default, + Error, + LogLoss, + } + + /// + /// Whether to use softmax loss. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Use softmax loss for the multi classification.")] + [TlcModule.SweepableDiscreteParam("UseSoftmax", new object[] { true, false })] + public bool? UseSoftmax; + + /// + /// Parameter for the sigmoid function. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Parameter for the sigmoid function.", ShortName = "sigmoid")] + [TGUI(Label = "Sigmoid", SuggestedSweeps = "0.5,1")] + public double Sigmoid = 0.5; + + /// + /// Determines what evaluation metric to use. 
+ /// + [Argument(ArgumentType.AtMostOnce, + HelpText = "Evaluation metrics.", + ShortName = "em")] + public EvaluateMetricType EvaluationMetric = EvaluateMetricType.Error; + + static Options() + { + NameMapping.Add(nameof(EvaluateMetricType), "metric"); + NameMapping.Add(nameof(EvaluateMetricType.Error), "multi_error"); + NameMapping.Add(nameof(EvaluateMetricType.LogLoss), "multi_logloss"); + } + + internal override Dictionary ToDictionary(IHost host) + { + var res = base.ToDictionary(host); + + res[GetOptionName(nameof(Sigmoid))] = Sigmoid; + if(EvaluationMetric != EvaluateMetricType.Default) + res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString()); + + return res; + } + } + internal LightGbmMulticlassTrainer(IHostEnvironment env, Options options) : base(env, LoadNameValue, options, TrainerUtils.MakeU4ScalarColumn(options.LabelColumnName)) { + Contracts.CheckUserArg(options.Sigmoid > 0, nameof(Options.Sigmoid), "must be > 0."); _numClass = -1; } @@ -58,10 +116,19 @@ internal LightGbmMulticlassTrainer(IHostEnvironment env, int? numberOfLeaves = null, int? minimumExampleCountPerLeaf = null, double? 
learningRate = null, - int numberOfIterations = Trainers.LightGbm.Options.Defaults.NumberOfIterations) - : base(env, LoadNameValue, TrainerUtils.MakeU4ScalarColumn(labelColumnName), featureColumnName, exampleWeightColumnName, null, numberOfLeaves, minimumExampleCountPerLeaf, learningRate, numberOfIterations) + int numberOfIterations = Defaults.NumberOfIterations) + : this(env, + new Options() + { + LabelColumnName = labelColumnName, + FeatureColumnName = featureColumnName, + ExampleWeightColumnName = exampleWeightColumnName, + NumberOfLeaves = numberOfLeaves, + MinimumExampleCountPerLeaf = minimumExampleCountPerLeaf, + LearningRate = learningRate, + NumberOfIterations = numberOfIterations + }) { - _numClass = -1; } private InternalTreeEnsemble GetBinaryEnsemble(int classID) @@ -88,7 +155,7 @@ private protected override OneVersusAllModelParameters CreatePredictor() Host.Assert(_numClass > 1, "Must know the number of classes before creating a predictor."); Host.Assert(TrainedEnsemble.NumTrees % _numClass == 0, "Number of trees should be a multiple of number of classes."); - var innerArgs = LightGbmInterfaceUtils.JoinParameters(Options); + var innerArgs = LightGbmInterfaceUtils.JoinParameters(GbmOptions); IPredictorProducing[] predictors = new IPredictorProducing[_tlcNumClass]; for (int i = 0; i < _tlcNumClass; ++i) { @@ -163,13 +230,13 @@ private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData dat private protected override void GetDefaultParameters(IChannel ch, int numRow, bool hasCategorical, int totalCats, bool hiddenMsg = false) { base.GetDefaultParameters(ch, numRow, hasCategorical, totalCats, true); - int numberOfLeaves = (int)Options["num_leaves"]; + int numberOfLeaves = (int)GbmOptions["num_leaves"]; int minimumExampleCountPerLeaf = LightGbmTrainerOptions.MinimumExampleCountPerLeaf ?? 
DefaultMinDataPerLeaf(numRow, numberOfLeaves, _numClass); - Options["min_data_per_leaf"] = minimumExampleCountPerLeaf; + GbmOptions["min_data_per_leaf"] = minimumExampleCountPerLeaf; if (!hiddenMsg) { if (!LightGbmTrainerOptions.LearningRate.HasValue) - ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.LearningRate) + " = " + Options["learning_rate"]); + ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.LearningRate) + " = " + GbmOptions["learning_rate"]); if (!LightGbmTrainerOptions.NumberOfLeaves.HasValue) ch.Info("Auto-tuning parameters: " + nameof(LightGbmTrainerOptions.NumberOfLeaves) + " = " + numberOfLeaves); if (!LightGbmTrainerOptions.MinimumExampleCountPerLeaf.HasValue) @@ -182,7 +249,7 @@ private protected override void CheckAndUpdateParametersBeforeTraining(IChannel Host.AssertValue(ch); ch.Assert(PredictionKind == PredictionKind.MulticlassClassification); ch.Assert(_numClass > 1); - Options["num_class"] = _numClass; + GbmOptions["num_class"] = _numClass; bool useSoftmax = false; if (LightGbmTrainerOptions.UseSoftmax.HasValue) @@ -196,13 +263,9 @@ private protected override void CheckAndUpdateParametersBeforeTraining(IChannel } if (useSoftmax) - Options["objective"] = "multiclass"; + GbmOptions["objective"] = "multiclass"; else - Options["objective"] = "multiclassova"; - - // Add default metric. 
- if (!Options.ContainsKey("metric")) - Options["metric"] = "multi_error"; + GbmOptions["objective"] = "multiclassova"; } private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) @@ -240,14 +303,14 @@ internal static partial class LightGbm Desc = "Train a LightGBM multi class model.", UserName = LightGbmMulticlassTrainer.Summary, ShortName = LightGbmMulticlassTrainer.ShortName)] - public static CommonOutputs.MulticlassClassificationOutput TrainMulticlass(IHostEnvironment env, Options input) + public static CommonOutputs.MulticlassClassificationOutput TrainMulticlass(IHostEnvironment env, LightGbmMulticlassTrainer.Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("TrainLightGBM"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return TrainerEntryPointsUtils.Train(host, input, + return TrainerEntryPointsUtils.Train(host, input, () => new LightGbmMulticlassTrainer(host, input), getLabel: () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName), getWeight: () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName)); diff --git a/src/Microsoft.ML.LightGbm/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmRankingTrainer.cs index 6c6cc75dd5..0c8f183740 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmRankingTrainer.cs @@ -3,14 +3,17 @@ // See the LICENSE file in the project root for more information. 
using System; +using System.Collections.Generic; using Microsoft.ML; +using Microsoft.ML.CommandLine; using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Runtime; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Trainers.LightGbm; -[assembly: LoadableClass(LightGbmRankingTrainer.UserName, typeof(LightGbmRankingTrainer), typeof(Options), +[assembly: LoadableClass(LightGbmRankingTrainer.UserName, typeof(LightGbmRankingTrainer), typeof(LightGbmRankingTrainer.Options), new[] { typeof(SignatureRankerTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer) }, "LightGBM Ranking", LightGbmRankingTrainer.LoadNameValue, LightGbmRankingTrainer.ShortName, DocName = "trainer/LightGBM.md")] @@ -45,7 +48,6 @@ private static VersionInfo GetVersionInfo() private protected override uint VerDefaultValueSerialized => 0x00010004; private protected override uint VerCategoricalSplitSerialized => 0x00010005; private protected override PredictionKind PredictionKind => PredictionKind.Ranking; - internal LightGbmRankingModelParameters(IHostEnvironment env, InternalTreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { @@ -72,7 +74,10 @@ private static LightGbmRankingModelParameters Create(IHostEnvironment env, Model /// The for training a boosted decision tree ranking model using LightGBM. 
/// /// - public sealed class LightGbmRankingTrainer : LightGbmTrainerBase, LightGbmRankingModelParameters> + public sealed class LightGbmRankingTrainer : LightGbmTrainerBase, + LightGbmRankingModelParameters> { internal const string UserName = "LightGBM Ranking"; internal const string LoadNameValue = "LightGBMRanking"; @@ -80,9 +85,63 @@ public sealed class LightGbmRankingTrainer : LightGbmTrainerBase PredictionKind.Ranking; + public sealed class Options : OptionsBase + { + public enum EvaluateMetricType + { + None, + Default, + MeanAveragedPrecision, + NormalizedDiscountedCumulativeGain + }; + + /// + /// Comma-separated list of gains associated with each relevance label. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "An array of gains associated to each relevance label.", ShortName = "gains")] + [TGUI(Label = "Ranking Label Gain")] + public int[] CustomGains = { 0,3,7,15,31,63,127,255,511,1023,2047,4095 }; + + /// + /// Parameter for the sigmoid function. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Parameter for the sigmoid function.", ShortName = "sigmoid")] + [TGUI(Label = "Sigmoid", SuggestedSweeps = "0.5,1")] + public double Sigmoid = 0.5; + + /// + /// Determines what evaluation metric to use. 
+ /// + [Argument(ArgumentType.AtMostOnce, + HelpText = "Evaluation metrics.", + ShortName = "em")] + public EvaluateMetricType EvaluationMetric = EvaluateMetricType.NormalizedDiscountedCumulativeGain; + + static Options() + { + NameMapping.Add(nameof(CustomGains), "label_gain"); + NameMapping.Add(nameof(EvaluateMetricType), "metric"); + NameMapping.Add(nameof(EvaluateMetricType.None), ""); + NameMapping.Add(nameof(EvaluateMetricType.MeanAveragedPrecision), "map"); + NameMapping.Add(nameof(EvaluateMetricType.NormalizedDiscountedCumulativeGain), "ndcg"); + } + + internal override Dictionary ToDictionary(IHost host) + { + var res = base.ToDictionary(host); + res[GetOptionName(nameof(Sigmoid))] = Sigmoid; + res[GetOptionName(nameof(CustomGains))] = string.Join(",",CustomGains); + if(EvaluationMetric != EvaluateMetricType.Default) + res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString()); + + return res; + } + } + internal LightGbmRankingTrainer(IHostEnvironment env, Options options) : base(env, LoadNameValue, options, TrainerUtils.MakeR4ScalarColumn(options.LabelColumnName)) { + Contracts.CheckUserArg(options.Sigmoid > 0, nameof(Options.Sigmoid), "must be > 0."); } /// @@ -105,10 +164,19 @@ internal LightGbmRankingTrainer(IHostEnvironment env, int? numberOfLeaves = null, int? minimumExampleCountPerLeaf = null, double? 
learningRate = null, - int numberOfIterations = Trainers.LightGbm.Options.Defaults.NumberOfIterations) - : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarColumn(labelColumnName), - featureColumnName, weightsColumnName, rowGroupdColumnName, numberOfLeaves, - minimumExampleCountPerLeaf, learningRate, numberOfIterations) + int numberOfIterations = Defaults.NumberOfIterations) + : this(env, + new Options() + { + LabelColumnName = labelColumnName, + FeatureColumnName = featureColumnName, + ExampleWeightColumnName = weightsColumnName, + RowGroupColumnName = rowGroupdColumnName, + NumberOfLeaves = numberOfLeaves, + MinimumExampleCountPerLeaf = minimumExampleCountPerLeaf, + LearningRate = learningRate, + NumberOfIterations = numberOfIterations + }) { Host.CheckNonEmpty(rowGroupdColumnName, nameof(rowGroupdColumnName)); } @@ -152,20 +220,18 @@ private protected override void CheckLabelCompatible(SchemaShape.Column labelCol private protected override LightGbmRankingModelParameters CreatePredictor() { Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete"); - var innerArgs = LightGbmInterfaceUtils.JoinParameters(Options); + var innerArgs = LightGbmInterfaceUtils.JoinParameters(GbmOptions); return new LightGbmRankingModelParameters(Host, TrainedEnsemble, FeatureCount, innerArgs); } private protected override void CheckAndUpdateParametersBeforeTraining(IChannel ch, RoleMappedData data, float[] labels, int[] groups) { Host.AssertValue(ch); - Options["objective"] = "lambdarank"; + GbmOptions["objective"] = "lambdarank"; ch.CheckValue(groups, nameof(groups)); - // Add default metric. - if (!Options.ContainsKey("metric")) - Options["metric"] = "ndcg"; + // Only output one ndcg score. 
- Options["eval_at"] = "5"; + GbmOptions["eval_at"] = "5"; } private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) @@ -196,14 +262,14 @@ internal static partial class LightGbm Desc = "Train a LightGBM ranking model.", UserName = LightGbmRankingTrainer.UserName, ShortName = LightGbmRankingTrainer.ShortName)] - public static CommonOutputs.RankingOutput TrainRanking(IHostEnvironment env, Options input) + public static CommonOutputs.RankingOutput TrainRanking(IHostEnvironment env, LightGbmRankingTrainer.Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("TrainLightGBM"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return TrainerEntryPointsUtils.Train(host, input, + return TrainerEntryPointsUtils.Train(host, input, () => new LightGbmRankingTrainer(host, input), getLabel: () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName), getWeight: () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName), diff --git a/src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs index c9064f6cba..3a3843c106 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs @@ -2,14 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using System.Collections.Generic; using Microsoft.ML; +using Microsoft.ML.CommandLine; using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; using Microsoft.ML.Runtime; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Trainers.LightGbm; -[assembly: LoadableClass(LightGbmRegressionTrainer.Summary, typeof(LightGbmRegressionTrainer), typeof(Options), +[assembly: LoadableClass(LightGbmRegressionTrainer.Summary, typeof(LightGbmRegressionTrainer), typeof(LightGbmRegressionTrainer.Options), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer) }, LightGbmRegressionTrainer.UserNameValue, LightGbmRegressionTrainer.LoadNameValue, LightGbmRegressionTrainer.ShortName, DocName = "trainer/LightGBM.md")] @@ -74,7 +76,10 @@ private static LightGbmRegressionModelParameters Create(IHostEnvironment env, Mo /// The for training a boosted decision tree regression model using LightGBM. /// /// - public sealed class LightGbmRegressionTrainer : LightGbmTrainerBase, LightGbmRegressionModelParameters> + public sealed class LightGbmRegressionTrainer : LightGbmTrainerBase, + LightGbmRegressionModelParameters> { internal const string Summary = "LightGBM Regression"; internal const string LoadNameValue = "LightGBMRegression"; @@ -83,6 +88,44 @@ public sealed class LightGbmRegressionTrainer : LightGbmTrainerBase PredictionKind.Regression; + public sealed class Options : OptionsBase + { + public enum EvaluateMetricType + { + None, + Default, + MeanAbsoluteError, + RootMeanSquaredError, + MeanSquaredError + }; + + /// + /// Determines what evaluation metric to use. 
+ /// + [Argument(ArgumentType.AtMostOnce, + HelpText = "Evaluation metrics.", + ShortName = "em")] + public EvaluateMetricType EvaluationMetric = EvaluateMetricType.RootMeanSquaredError; + + static Options() + { + NameMapping.Add(nameof(EvaluateMetricType), "metric"); + NameMapping.Add(nameof(EvaluateMetricType.None), ""); + NameMapping.Add(nameof(EvaluateMetricType.MeanAbsoluteError), "mae"); + NameMapping.Add(nameof(EvaluateMetricType.RootMeanSquaredError), "rmse"); + NameMapping.Add(nameof(EvaluateMetricType.MeanSquaredError), "mse"); + } + + internal override Dictionary ToDictionary(IHost host) + { + var res = base.ToDictionary(host); + if (EvaluationMetric != EvaluateMetricType.Default) + res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString()); + + return res; + } + } + /// /// Initializes a new instance of /// @@ -101,8 +144,17 @@ internal LightGbmRegressionTrainer(IHostEnvironment env, int? numberOfLeaves = null, int? minimumExampleCountPerLeaf = null, double? 
learningRate = null, - int numberOfIterations = Trainers.LightGbm.Options.Defaults.NumberOfIterations) - : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarColumn(labelColumnName), featureColumnName, exampleWeightColumnName, null, numberOfLeaves, minimumExampleCountPerLeaf, learningRate, numberOfIterations) + int numberOfIterations = Defaults.NumberOfIterations) + : this(env, new Options() + { + LabelColumnName = labelColumnName, + FeatureColumnName = featureColumnName, + ExampleWeightColumnName = exampleWeightColumnName, + NumberOfLeaves = numberOfLeaves, + MinimumExampleCountPerLeaf = minimumExampleCountPerLeaf, + LearningRate = learningRate, + NumberOfIterations = numberOfIterations + }) { } @@ -115,7 +167,7 @@ private protected override LightGbmRegressionModelParameters CreatePredictor() { Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete"); - var innerArgs = LightGbmInterfaceUtils.JoinParameters(Options); + var innerArgs = LightGbmInterfaceUtils.JoinParameters(GbmOptions); return new LightGbmRegressionModelParameters(Host, TrainedEnsemble, FeatureCount, innerArgs); } @@ -133,10 +185,7 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data) private protected override void CheckAndUpdateParametersBeforeTraining(IChannel ch, RoleMappedData data, float[] labels, int[] groups) { - Options["objective"] = "regression"; - // Add default metric. 
- if (!Options.ContainsKey("metric")) - Options["metric"] = "l2"; + GbmOptions["objective"] = "regression"; } private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) @@ -167,14 +216,14 @@ internal static partial class LightGbm Desc = LightGbmRegressionTrainer.Summary, UserName = LightGbmRegressionTrainer.UserNameValue, ShortName = LightGbmRegressionTrainer.ShortName)] - public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, Options input) + public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, LightGbmRegressionTrainer.Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("TrainLightGBM"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return TrainerEntryPointsUtils.Train(host, input, + return TrainerEntryPointsUtils.Train(host, input, () => new LightGbmRegressionTrainer(host, input), getLabel: () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName), getWeight: () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName)); diff --git a/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs index e1992f1b63..bdaca7a1b7 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs @@ -4,7 +4,11 @@ using System; using System.Collections.Generic; +using System.Text; +using Microsoft.ML.CommandLine; using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; using Microsoft.ML.Trainers; @@ -12,6 +16,12 @@ namespace Microsoft.ML.Trainers.LightGbm { + [BestFriend] + internal static class Defaults + { + public const int NumberOfIterations = 100; + } + /// /// Lock for LightGBM trainer. 
/// @@ -26,10 +36,239 @@ internal static class LightGbmShared /// /// Base class for all training with LightGBM. /// - public abstract class LightGbmTrainerBase : TrainerEstimatorBaseWithGroupId + public abstract class LightGbmTrainerBase : TrainerEstimatorBaseWithGroupId where TTransformer : ISingleFeaturePredictionTransformer where TModel : class + where TOptions : LightGbmTrainerBase.OptionsBase, new() { + public class OptionsBase : TrainerInputBaseWithGroupId + { + // Static override name map that maps friendly names to lightGBM arguments. + // If an argument is not here, then its name is identicaltto a lightGBM argument + // and does not require a mapping, for example, Subsample. + private protected static Dictionary NameMapping = new Dictionary() + { + {nameof(MinimumExampleCountPerLeaf), "min_data_per_leaf"}, + {nameof(NumberOfLeaves), "num_leaves"}, + {nameof(MaximumBinCountPerFeature), "max_bin" }, + {nameof(MinimumExampleCountPerGroup), "min_data_per_group" }, + {nameof(MaximumCategoricalSplitPointCount), "max_cat_threshold" }, + {nameof(CategoricalSmoothing), "cat_smooth" }, + {nameof(L2CategoricalRegularization), "cat_l2" }, + {nameof(HandleMissingValue), "use_missing" } + }; + + private protected string GetOptionName(string name) + { + if (NameMapping.ContainsKey(name)) + return NameMapping[name]; + return LightGbmInterfaceUtils.GetOptionName(name); + } + + private protected OptionsBase() { } + + /// + /// The number of boosting iterations. A new tree is created in each iteration, so this is equivalent to the number of trees. 
+ /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations.", SortOrder = 1, ShortName = "iter")] + [TGUI(Label = "Number of boosting iterations", SuggestedSweeps = "10,20,50,100,150,200")] + [TlcModule.SweepableDiscreteParam("NumBoostRound", new object[] { 10, 20, 50, 100, 150, 200 })] + public int NumberOfIterations = Defaults.NumberOfIterations; + + /// + /// The shrinkage rate for trees, used to prevent over-fitting. + /// + /// + /// Valid range is (0,1]. + /// + [Argument(ArgumentType.AtMostOnce, + HelpText = "Shrinkage rate for trees, used to prevent over-fitting. Range: (0,1].", + SortOrder = 2, ShortName = "lr", NullName = "")] + [TGUI(Label = "Learning Rate", SuggestedSweeps = "0.025-0.4;log")] + [TlcModule.SweepableFloatParamAttribute("LearningRate", 0.025f, 0.4f, isLogScale: true)] + public double? LearningRate; + + /// + /// The maximum number of leaves in one tree. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum leaves for trees.", + SortOrder = 2, ShortName = "nl", NullName = "")] + [TGUI(Description = "The maximum number of leaves per tree", SuggestedSweeps = "2-128;log;inc:4")] + [TlcModule.SweepableLongParamAttribute("NumLeaves", 2, 128, isLogScale: true, stepSize: 4)] + public int? NumberOfLeaves; + + /// + /// The minimal number of data points required to form a new tree leaf. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Minimum number of instances needed in a child.", + SortOrder = 2, ShortName = "mil", NullName = "")] + [TGUI(Label = "Min Documents In Leaves", SuggestedSweeps = "1,10,20,50 ")] + [TlcModule.SweepableDiscreteParamAttribute("MinDataPerLeaf", new object[] { 1, 10, 20, 50 })] + public int? MinimumExampleCountPerLeaf; + + /// + /// The maximum number of bins that feature values will be bucketed in. + /// + /// + /// The small number of bins may reduce training accuracy but may increase general power (deal with over-fitting). 
+ /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of bucket bin for features.", ShortName = "mb")] + public int MaximumBinCountPerFeature = 255; + + /// + /// Determines which booster to use. + /// + /// + /// Available boosters are , , and . + /// + [Argument(ArgumentType.Multiple, + HelpText = "Which booster to use, can be gbtree, gblinear or dart. gbtree and dart use tree based model while gblinear uses linear function.", + Name="Booster", + SortOrder = 3)] + internal IBoosterParameterFactory BoosterFactory = new GradientBooster.Options(); + + /// + /// Determines whether to output progress status during training and evaluation. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Verbose", ShortName = "v")] + public bool Verbose = false; + + /// + /// Controls the logging level in LighGBM. + /// + /// + /// means only output Fatal errors. means output Fatal, Warning, and Info level messages. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Printing running messages.")] + public bool Silent = true; + + /// + /// Determines the number of threads used to run LightGBM. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Number of parallel threads used to run LightGBM.", ShortName = "nt")] + public int? NumberOfThreads; + + /// + /// Determines the number of rounds, after which training will stop if validation metric doesn't improve. + /// + /// + /// 0 means disable early stopping. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Rounds of early stopping, 0 will disable it.", + ShortName = "es")] + public int EarlyStoppingRound = 0; + + /// + /// Number of data points per batch, when loading data. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Number of entries in a batch when loading data.", Hide = true)] + public int BatchSize = 1 << 20; + + /// + /// Whether to enable categorical split or not. 
+ /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Enable categorical split or not.", ShortName = "cat")] + [TlcModule.SweepableDiscreteParam("UseCat", new object[] { true, false })] + public bool? UseCategoricalSplit; + + /// + /// Whether to enable special handling of missing value or not. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Enable special handling of missing value or not.")] + [TlcModule.SweepableDiscreteParam("UseMissing", new object[] { true, false })] + public bool HandleMissingValue = true; + + /// + /// The minimum number of data points per categorical group. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Minimum number of instances per categorical group.", ShortName = "mdpg")] + [TlcModule.Range(Inf = 0, Max = int.MaxValue)] + [TlcModule.SweepableDiscreteParam("MinDataPerGroup", new object[] { 10, 50, 100, 200 })] + public int MinimumExampleCountPerGroup = 100; + + /// + /// When the number of categories of one feature is smaller than or equal to , + /// one-vs-other split algorithm will be used. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Max number of categorical thresholds.", ShortName = "maxcat")] + [TlcModule.Range(Inf = 0, Max = int.MaxValue)] + [TlcModule.SweepableDiscreteParam("MaxCatThreshold", new object[] { 8, 16, 32, 64 })] + public int MaximumCategoricalSplitPointCount = 32; + + /// + /// Laplace smooth term in categorical feature split. + /// This can reduce the effect of noises in categorical features, especially for categories with few data. + /// + /// + /// Constraints: >= 0.0 + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Lapalace smooth term in categorical feature spilt. Avoid the bias of small categories.")] + [TlcModule.Range(Min = 0.0)] + [TlcModule.SweepableDiscreteParam("CatSmooth", new object[] { 1, 10, 20 })] + public double CategoricalSmoothing = 10; + + /// + /// L2 regularization for categorical split. 
+ /// + [Argument(ArgumentType.AtMostOnce, HelpText = "L2 Regularization for categorical split.")] + [TlcModule.Range(Min = 0.0)] + [TlcModule.SweepableDiscreteParam("CatL2", new object[] { 0.1, 0.5, 1, 5, 10 })] + public double L2CategoricalRegularization = 10; + + /// + /// The random seed for LightGBM to use. + /// + /// + /// If not specified, will generate a random seed to be used. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Sets the random seed for LightGBM to use.")] + public int? Seed; + + [Argument(ArgumentType.Multiple, HelpText = "Parallel LightGBM Learning Algorithm", ShortName = "parag")] + internal ISupportParallel ParallelTrainer = new SingleTrainerFactory(); + + private BoosterParameterBase.OptionsBase _boosterParameter; + + /// + /// Booster parameter to use + /// + public BoosterParameterBase.OptionsBase Booster + { + get => _boosterParameter; + + set + { + _boosterParameter = value; + BoosterFactory = _boosterParameter; + } + } + + internal virtual Dictionary ToDictionary(IHost host) + { + Contracts.CheckValue(host, nameof(host)); + Dictionary res = new Dictionary(); + + var boosterParams = BoosterFactory.CreateComponent(host); + boosterParams.UpdateParameters(res); + res["boosting_type"] = boosterParams.BoosterName; + + res["verbose"] = Silent ? "-1" : "1"; + if (NumberOfThreads.HasValue) + res["nthread"] = NumberOfThreads.Value; + + res["seed"] = (Seed.HasValue) ? 
Seed : host.Rand.Next(); + + res[GetOptionName(nameof(MaximumBinCountPerFeature))] = MaximumBinCountPerFeature; + res[GetOptionName(nameof(HandleMissingValue))] = HandleMissingValue; + res[GetOptionName(nameof(MinimumExampleCountPerGroup))] = MinimumExampleCountPerGroup; + res[GetOptionName(nameof(MaximumCategoricalSplitPointCount))] = MaximumCategoricalSplitPointCount; + res[GetOptionName(nameof(CategoricalSmoothing))] = CategoricalSmoothing; + res[GetOptionName(nameof(L2CategoricalRegularization))] = L2CategoricalRegularization; + + return res; + } + } + private sealed class CategoricalMetaData { public int NumCol; @@ -40,14 +279,16 @@ private sealed class CategoricalMetaData public bool[] IsCategoricalFeature; } - private protected readonly Options LightGbmTrainerOptions; + // Contains the passed in options when the API is called + private protected readonly TOptions LightGbmTrainerOptions; /// /// Stores argumments as objects to convert them to invariant string type in the end so that /// the code is culture agnostic. When retrieving key value from this dictionary as string /// please convert to string invariant by string.Format(CultureInfo.InvariantCulture, "{0}", Option[key]). /// - private protected Dictionary Options; + private protected Dictionary GbmOptions; + private protected IParallel ParallelTraining; // Store _featureCount and _trainedEnsemble to construct predictor. @@ -67,29 +308,32 @@ private protected LightGbmTrainerBase(IHostEnvironment env, int? minimumExampleCountPerLeaf, double? 
learningRate, int numberOfIterations) - : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumnName), - labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(exampleWeightColumnName), TrainerUtils.MakeU4ScalarColumn(rowGroupColumnName)) + : this(env, name, new TOptions() + { + NumberOfLeaves = numberOfLeaves, + MinimumExampleCountPerLeaf = minimumExampleCountPerLeaf, + LearningRate = learningRate, + NumberOfIterations = numberOfIterations, + LabelColumnName = labelColumn.Name, + FeatureColumnName = featureColumnName, + ExampleWeightColumnName = exampleWeightColumnName, + RowGroupColumnName = rowGroupColumnName + }, + labelColumn) { - LightGbmTrainerOptions = new Options(); - - LightGbmTrainerOptions.NumberOfLeaves = numberOfLeaves; - LightGbmTrainerOptions.MinimumExampleCountPerLeaf = minimumExampleCountPerLeaf; - LightGbmTrainerOptions.LearningRate = learningRate; - LightGbmTrainerOptions.NumberOfIterations = numberOfIterations; - - LightGbmTrainerOptions.LabelColumnName = labelColumn.Name; - LightGbmTrainerOptions.FeatureColumnName = featureColumnName; - LightGbmTrainerOptions.ExampleWeightColumnName = exampleWeightColumnName; - LightGbmTrainerOptions.RowGroupColumnName = rowGroupColumnName; - - InitParallelTraining(); } - private protected LightGbmTrainerBase(IHostEnvironment env, string name, Options options, SchemaShape.Column label) + private protected LightGbmTrainerBase(IHostEnvironment env, string name, TOptions options, SchemaShape.Column label) : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(options.FeatureColumnName), label, TrainerUtils.MakeR4ScalarWeightColumn(options.ExampleWeightColumnName), TrainerUtils.MakeU4ScalarColumn(options.RowGroupColumnName)) { Host.CheckValue(options, nameof(options)); + Contracts.CheckUserArg(options.NumberOfIterations >= 0, nameof(options.NumberOfIterations), "must be >= 0."); + Contracts.CheckUserArg(options.MaximumBinCountPerFeature 
> 0, nameof(options.MaximumBinCountPerFeature), "must be > 0."); + Contracts.CheckUserArg(options.MinimumExampleCountPerGroup > 0, nameof(options.MinimumExampleCountPerGroup), "must be > 0."); + Contracts.CheckUserArg(options.MaximumCategoricalSplitPointCount > 0, nameof(options.MaximumCategoricalSplitPointCount), "must be > 0."); + Contracts.CheckUserArg(options.CategoricalSmoothing >= 0, nameof(options.CategoricalSmoothing), "must be >= 0."); + Contracts.CheckUserArg(options.L2CategoricalRegularization >= 0.0, nameof(options.L2CategoricalRegularization), "must be >= 0."); LightGbmTrainerOptions = options; InitParallelTraining(); @@ -130,17 +374,17 @@ private protected override TModel TrainModelCore(TrainContext context) private void InitParallelTraining() { - Options = LightGbmTrainerOptions.ToDictionary(Host); + GbmOptions = LightGbmTrainerOptions.ToDictionary(Host); ParallelTraining = LightGbmTrainerOptions.ParallelTrainer != null ? LightGbmTrainerOptions.ParallelTrainer.CreateComponent(Host) : new SingleTrainer(); if (ParallelTraining.ParallelType() != "serial" && ParallelTraining.NumMachines() > 1) { - Options["tree_learner"] = ParallelTraining.ParallelType(); + GbmOptions["tree_learner"] = ParallelTraining.ParallelType(); var otherParams = ParallelTraining.AdditionalParams(); if (otherParams != null) { foreach (var pair in otherParams) - Options[pair.Key] = pair.Value; + GbmOptions[pair.Key] = pair.Value; } Contracts.CheckValue(ParallelTraining.GetReduceScatterFunction(), nameof(ParallelTraining.GetReduceScatterFunction)); @@ -166,14 +410,14 @@ private protected virtual void CheckDataValid(IChannel ch, RoleMappedData data) ch.CheckParam(data.Schema.Label.HasValue, nameof(data), "Need a label column"); } - private protected virtual void GetDefaultParameters(IChannel ch, int numRow, bool hasCategarical, int totalCats, bool hiddenMsg = false) + private protected virtual void GetDefaultParameters(IChannel ch, int numRow, bool hasCategorical, int totalCats, bool 
hiddenMsg = false) { - double learningRate = LightGbmTrainerOptions.LearningRate ?? DefaultLearningRate(numRow, hasCategarical, totalCats); - int numberOfLeaves = LightGbmTrainerOptions.NumberOfLeaves ?? DefaultNumLeaves(numRow, hasCategarical, totalCats); + double learningRate = LightGbmTrainerOptions.LearningRate ?? DefaultLearningRate(numRow, hasCategorical, totalCats); + int numberOfLeaves = LightGbmTrainerOptions.NumberOfLeaves ?? DefaultNumLeaves(numRow, hasCategorical, totalCats); int minimumExampleCountPerLeaf = LightGbmTrainerOptions.MinimumExampleCountPerLeaf ?? DefaultMinDataPerLeaf(numRow, numberOfLeaves, 1); - Options["learning_rate"] = learningRate; - Options["num_leaves"] = numberOfLeaves; - Options["min_data_per_leaf"] = minimumExampleCountPerLeaf; + GbmOptions["learning_rate"] = learningRate; + GbmOptions["num_leaves"] = numberOfLeaves; + GbmOptions["min_data_per_leaf"] = minimumExampleCountPerLeaf; if (!hiddenMsg) { if (!LightGbmTrainerOptions.LearningRate.HasValue) @@ -186,7 +430,7 @@ private protected virtual void GetDefaultParameters(IChannel ch, int numRow, boo } [BestFriend] - internal Dictionary GetGbmParameters() => Options; + internal Dictionary GetGbmParameters() => GbmOptions; private FloatLabelCursor.Factory CreateCursorFactory(RoleMappedData data) { @@ -297,7 +541,7 @@ private CategoricalMetaData GetCategoricalMetaData(IChannel ch, RoleMappedData t { var catIndices = ConstructCategoricalFeatureMetaData(categoricalFeatures, rawNumCol, ref catMetaData); // Set categorical features - Options["categorical_feature"] = string.Join(",", catIndices); + GbmOptions["categorical_feature"] = string.Join(",", catIndices); } return catMetaData; } @@ -316,9 +560,10 @@ private Dataset LoadTrainingData(IChannel ch, RoleMappedData trainData, out Cate catMetaData = GetCategoricalMetaData(ch, trainData, numRow); GetDefaultParameters(ch, numRow, catMetaData.CategoricalBoudaries != null, catMetaData.TotalCats); - Dataset dtrain; - string param = 
LightGbmInterfaceUtils.JoinParameters(Options); + CheckAndUpdateParametersBeforeTraining(ch, trainData, labels, groups); + string param = LightGbmInterfaceUtils.JoinParameters(GbmOptions); + Dataset dtrain; // To reduce peak memory usage, only enable one sampling task at any given time. lock (LightGbmShared.SampleLock) { @@ -329,8 +574,6 @@ private Dataset LoadTrainingData(IChannel ch, RoleMappedData trainData, out Cate // Push rows into dataset. LoadDataset(ch, factory, dtrain, numRow, LightGbmTrainerOptions.BatchSize, catMetaData); - // Some checks. - CheckAndUpdateParametersBeforeTraining(ch, trainData, labels, groups); return dtrain; } @@ -362,15 +605,16 @@ private void TrainCore(IChannel ch, IProgressChannel pch, Dataset dtrain, Catego Host.AssertValue(pch); Host.AssertValue(dtrain); Host.AssertValueOrNull(dvalid); + // For multi class, the number of labels is required. - ch.Assert(((ITrainer)this).PredictionKind != PredictionKind.MulticlassClassification || Options.ContainsKey("num_class"), + ch.Assert(((ITrainer)this).PredictionKind != PredictionKind.MulticlassClassification || GbmOptions.ContainsKey("num_class"), "LightGBM requires the number of classes to be specified in the parameters."); // Only enable one trainer to run at one time. 
lock (LightGbmShared.LockForMultiThreadingInside) { - ch.Info("LightGBM objective={0}", Options["objective"]); - using (Booster bst = WrappedLightGbmTraining.Train(ch, pch, Options, dtrain, + ch.Info("LightGBM objective={0}", GbmOptions["objective"]); + using (Booster bst = WrappedLightGbmTraining.Train(ch, pch, GbmOptions, dtrain, dvalid: dvalid, numIteration: LightGbmTrainerOptions.NumberOfIterations, verboseEval: LightGbmTrainerOptions.Verbose, earlyStoppingRound: LightGbmTrainerOptions.EarlyStoppingRound)) { diff --git a/src/Microsoft.ML.LightGbm/WrappedLightGbmInterface.cs b/src/Microsoft.ML.LightGbm/WrappedLightGbmInterface.cs index db72a689b6..aea7bd4162 100644 --- a/src/Microsoft.ML.LightGbm/WrappedLightGbmInterface.cs +++ b/src/Microsoft.ML.LightGbm/WrappedLightGbmInterface.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Globalization; using System.Runtime.InteropServices; +using System.Text; using Microsoft.ML.Runtime; namespace Microsoft.ML.Trainers.LightGbm @@ -225,6 +226,32 @@ public static string JoinParameters(Dictionary parameters) return string.Join(" ", res); } + /// + /// Helper function used for generating the LightGbm argument name. + /// When given a name, this will convert the name to lower-case with underscores. + /// The underscore will be placed when an upper-case letter is encountered. + /// + public static string GetOptionName(string name) + { + // Otherwise convert the name to the light gbm argument + StringBuilder strBuf = new StringBuilder(); + bool first = true; + foreach (char c in name) + { + if (char.IsUpper(c)) + { + if (first) + first = false; + else + strBuf.Append('_'); + strBuf.Append(char.ToLower(c)); + } + else + strBuf.Append(c); + } + return strBuf.ToString(); + } + /// /// Convert the pointer of c string to c# string. 
/// diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index af4afd7cfc..a23c9b29ae 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -53,10 +53,10 @@ Trainers.FieldAwareFactorizationMachineBinaryClassifier Train a field-aware fact Trainers.GeneralizedAdditiveModelBinaryClassifier Trains a gradient boosted stump per feature, on all features simultaneously, to fit target values using least-squares. It mantains no interactions between features. Microsoft.ML.Trainers.FastTree.Gam TrainBinary Microsoft.ML.Trainers.FastTree.GamBinaryTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput Trainers.GeneralizedAdditiveModelRegressor Trains a gradient boosted stump per feature, on all features simultaneously, to fit target values using least-squares. It mantains no interactions between features. Microsoft.ML.Trainers.FastTree.Gam TrainRegression Microsoft.ML.Trainers.FastTree.GamRegressionTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput Trainers.KMeansPlusPlusClusterer K-means is a popular clustering algorithm. With K-means, the data is clustered into a specified number of clusters in order to minimize the within-cluster sum of squares. K-means++ improves upon K-means by using a better method for choosing the initial cluster centers. Microsoft.ML.Trainers.KMeansTrainer TrainKMeans Microsoft.ML.Trainers.KMeansTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+ClusteringOutput -Trainers.LightGbmBinaryClassifier Train a LightGBM binary classification model. Microsoft.ML.Trainers.LightGbm.LightGbm TrainBinary Microsoft.ML.Trainers.LightGbm.Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput -Trainers.LightGbmClassifier Train a LightGBM multi class model. 
Microsoft.ML.Trainers.LightGbm.LightGbm TrainMulticlass Microsoft.ML.Trainers.LightGbm.Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput -Trainers.LightGbmRanker Train a LightGBM ranking model. Microsoft.ML.Trainers.LightGbm.LightGbm TrainRanking Microsoft.ML.Trainers.LightGbm.Options Microsoft.ML.EntryPoints.CommonOutputs+RankingOutput -Trainers.LightGbmRegressor LightGBM Regression Microsoft.ML.Trainers.LightGbm.LightGbm TrainRegression Microsoft.ML.Trainers.LightGbm.Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput +Trainers.LightGbmBinaryClassifier Train a LightGBM binary classification model. Microsoft.ML.Trainers.LightGbm.LightGbm TrainBinary Microsoft.ML.Trainers.LightGbm.LightGbmBinaryTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput +Trainers.LightGbmClassifier Train a LightGBM multi class model. Microsoft.ML.Trainers.LightGbm.LightGbm TrainMulticlass Microsoft.ML.Trainers.LightGbm.LightGbmMulticlassTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput +Trainers.LightGbmRanker Train a LightGBM ranking model. Microsoft.ML.Trainers.LightGbm.LightGbm TrainRanking Microsoft.ML.Trainers.LightGbm.LightGbmRankingTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RankingOutput +Trainers.LightGbmRegressor LightGBM Regression Microsoft.ML.Trainers.LightGbm.LightGbm TrainRegression Microsoft.ML.Trainers.LightGbm.LightGbmRegressionTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput Trainers.LinearSvmBinaryClassifier Train a linear SVM. Microsoft.ML.Trainers.LinearSvmTrainer TrainLinearSvm Microsoft.ML.Trainers.LinearSvmTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput Trainers.LogisticRegressionBinaryClassifier Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can be used as a classification algorithm. 
The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function. Microsoft.ML.Trainers.LogisticRegressionBinaryTrainer TrainBinary Microsoft.ML.Trainers.LogisticRegressionBinaryTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput Trainers.LogisticRegressionClassifier Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can be used as a classification algorithm. The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function. Microsoft.ML.Trainers.LogisticRegressionBinaryTrainer TrainMulticlass Microsoft.ML.Trainers.LogisticRegressionMulticlassClassificationTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index f0e18bf050..c9251c758f 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -11243,23 +11243,11 @@ "Default": "Auto" }, { - "Name": "MaximumBinCountPerFeature", - "Type": "Int", - "Desc": "Maximum number of bucket bin for features.", - "Aliases": [ - "mb" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": 255 - }, - { - "Name": "Verbose", + "Name": "UnbalancedSets", "Type": "Bool", - "Desc": "Verbose", + "Desc": "Use for binary classification when training data is not balanced.", "Aliases": [ - "v" + "us" ], "Required": false, "SortOrder": 150.0, @@ -11267,41 +11255,39 @@ "Default": false }, { - "Name": "Silent", - "Type": "Bool", - "Desc": "Printing running messages.", + "Name": "WeightOfPositiveExamples", + "Type": "Float", + "Desc": "Control the balance of positive and negative weights, useful for unbalanced classes. 
A typical value to consider: sum(negative cases) / sum(positive cases).", + "Aliases": [ + "ScalePosWeight" + ], "Required": false, "SortOrder": 150.0, "IsNullable": false, - "Default": true + "Default": 1.0 }, { - "Name": "NumberOfThreads", - "Type": "Int", - "Desc": "Number of parallel threads used to run LightGBM.", + "Name": "Sigmoid", + "Type": "Float", + "Desc": "Parameter for the sigmoid function.", "Aliases": [ - "nt" + "sigmoid" ], "Required": false, "SortOrder": 150.0, - "IsNullable": true, - "Default": null + "IsNullable": false, + "Default": 0.5 }, { "Name": "EvaluationMetric", "Type": { "Kind": "Enum", "Values": [ - "DefaultMetric", - "Rmse", - "Mae", + "None", + "Default", "Logloss", "Error", - "Merror", - "Mlogloss", - "Auc", - "Ndcg", - "Map" + "AreaUnderCurve" ] }, "Desc": "Evaluation metrics.", @@ -11311,59 +11297,64 @@ "Required": false, "SortOrder": 150.0, "IsNullable": false, - "Default": "DefaultMetric" + "Default": "Logloss" }, { - "Name": "UseSoftmax", - "Type": "Bool", - "Desc": "Use softmax loss for the multi classification.", + "Name": "MaximumBinCountPerFeature", + "Type": "Int", + "Desc": "Maximum number of bucket bin for features.", + "Aliases": [ + "mb" + ], "Required": false, "SortOrder": 150.0, - "IsNullable": true, - "Default": null, - "SweepRange": { - "RangeType": "Discrete", - "Values": [ - true, - false - ] - } + "IsNullable": false, + "Default": 255 }, { - "Name": "EarlyStoppingRound", - "Type": "Int", - "Desc": "Rounds of early stopping, 0 will disable it.", + "Name": "Verbose", + "Type": "Bool", + "Desc": "Verbose", "Aliases": [ - "es" + "v" ], "Required": false, "SortOrder": 150.0, "IsNullable": false, - "Default": 0 + "Default": false }, { - "Name": "CustomGains", - "Type": "String", - "Desc": "Comma separated list of gains associated to each relevance label.", + "Name": "Silent", + "Type": "Bool", + "Desc": "Printing running messages.", + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": true + 
}, + { + "Name": "NumberOfThreads", + "Type": "Int", + "Desc": "Number of parallel threads used to run LightGBM.", "Aliases": [ - "gains" + "nt" ], "Required": false, "SortOrder": 150.0, - "IsNullable": false, - "Default": "0,3,7,15,31,63,127,255,511,1023,2047,4095" + "IsNullable": true, + "Default": null }, { - "Name": "Sigmoid", - "Type": "Float", - "Desc": "Parameter for the sigmoid function. Used only in LightGbmBinaryTrainer, LightGbmMulticlassTrainer and in LightGbmRankingTrainer.", + "Name": "EarlyStoppingRound", + "Type": "Int", + "Desc": "Rounds of early stopping, 0 will disable it.", "Aliases": [ - "sigmoid" + "es" ], "Required": false, "SortOrder": 150.0, "IsNullable": false, - "Default": 0.5 + "Default": 0 }, { "Name": "BatchSize", @@ -11400,7 +11391,7 @@ "Required": false, "SortOrder": 150.0, "IsNullable": false, - "Default": false, + "Default": true, "SweepRange": { "RangeType": "Discrete", "Values": [ @@ -11462,7 +11453,7 @@ { "Name": "CategoricalSmoothing", "Type": "Float", - "Desc": "Laplace smooth term in categorical feature split. Avoid the bias of small categories.", + "Desc": "Lapalace smooth term in categorical feature spilt. 
Avoid the bias of small categories.", "Required": false, "SortOrder": 150.0, "IsNullable": false, @@ -11745,6 +11736,54 @@ "IsNullable": false, "Default": "Auto" }, + { + "Name": "UseSoftmax", + "Type": "Bool", + "Desc": "Use softmax loss for the multi classification.", + "Required": false, + "SortOrder": 150.0, + "IsNullable": true, + "Default": null, + "SweepRange": { + "RangeType": "Discrete", + "Values": [ + true, + false + ] + } + }, + { + "Name": "Sigmoid", + "Type": "Float", + "Desc": "Parameter for the sigmoid function.", + "Aliases": [ + "sigmoid" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": 0.5 + }, + { + "Name": "EvaluationMetric", + "Type": { + "Kind": "Enum", + "Values": [ + "None", + "Default", + "Error", + "LogLoss" + ] + }, + "Desc": "Evaluation metrics.", + "Aliases": [ + "em" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": "Error" + }, { "Name": "MaximumBinCountPerFeature", "Type": "Int", @@ -11790,48 +11829,6 @@ "IsNullable": true, "Default": null }, - { - "Name": "EvaluationMetric", - "Type": { - "Kind": "Enum", - "Values": [ - "DefaultMetric", - "Rmse", - "Mae", - "Logloss", - "Error", - "Merror", - "Mlogloss", - "Auc", - "Ndcg", - "Map" - ] - }, - "Desc": "Evaluation metrics.", - "Aliases": [ - "em" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": "DefaultMetric" - }, - { - "Name": "UseSoftmax", - "Type": "Bool", - "Desc": "Use softmax loss for the multi classification.", - "Required": false, - "SortOrder": 150.0, - "IsNullable": true, - "Default": null, - "SweepRange": { - "RangeType": "Discrete", - "Values": [ - true, - false - ] - } - }, { "Name": "EarlyStoppingRound", "Type": "Int", @@ -11844,30 +11841,6 @@ "IsNullable": false, "Default": 0 }, - { - "Name": "CustomGains", - "Type": "String", - "Desc": "Comma separated list of gains associated to each relevance label.", - "Aliases": [ - "gains" - ], - "Required": false, - 
"SortOrder": 150.0, - "IsNullable": false, - "Default": "0,3,7,15,31,63,127,255,511,1023,2047,4095" - }, - { - "Name": "Sigmoid", - "Type": "Float", - "Desc": "Parameter for the sigmoid function. Used only in LightGbmBinaryTrainer, LightGbmMulticlassTrainer and in LightGbmRankingTrainer.", - "Aliases": [ - "sigmoid" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": 0.5 - }, { "Name": "BatchSize", "Type": "Int", @@ -11903,7 +11876,7 @@ "Required": false, "SortOrder": 150.0, "IsNullable": false, - "Default": false, + "Default": true, "SweepRange": { "RangeType": "Discrete", "Values": [ @@ -11965,7 +11938,7 @@ { "Name": "CategoricalSmoothing", "Type": "Float", - "Desc": "Laplace smooth term in categorical feature split. Avoid the bias of small categories.", + "Desc": "Lapalace smooth term in categorical feature spilt. Avoid the bias of small categories.", "Required": false, "SortOrder": 150.0, "IsNullable": false, @@ -12248,6 +12221,66 @@ "IsNullable": false, "Default": "Auto" }, + { + "Name": "CustomGains", + "Type": { + "Kind": "Array", + "ItemType": "Int" + }, + "Desc": "An array of gains associated to each relevance label.", + "Aliases": [ + "gains" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": [ + 0, + 3, + 7, + 15, + 31, + 63, + 127, + 255, + 511, + 1023, + 2047, + 4095 + ] + }, + { + "Name": "Sigmoid", + "Type": "Float", + "Desc": "Parameter for the sigmoid function.", + "Aliases": [ + "sigmoid" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": 0.5 + }, + { + "Name": "EvaluationMetric", + "Type": { + "Kind": "Enum", + "Values": [ + "None", + "Default", + "MeanAveragedPrecision", + "NormalizedDiscountedCumulativeGain" + ] + }, + "Desc": "Evaluation metrics.", + "Aliases": [ + "em" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": "NormalizedDiscountedCumulativeGain" + }, { "Name": "MaximumBinCountPerFeature", "Type": "Int", 
@@ -12293,48 +12326,6 @@ "IsNullable": true, "Default": null }, - { - "Name": "EvaluationMetric", - "Type": { - "Kind": "Enum", - "Values": [ - "DefaultMetric", - "Rmse", - "Mae", - "Logloss", - "Error", - "Merror", - "Mlogloss", - "Auc", - "Ndcg", - "Map" - ] - }, - "Desc": "Evaluation metrics.", - "Aliases": [ - "em" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": "DefaultMetric" - }, - { - "Name": "UseSoftmax", - "Type": "Bool", - "Desc": "Use softmax loss for the multi classification.", - "Required": false, - "SortOrder": 150.0, - "IsNullable": true, - "Default": null, - "SweepRange": { - "RangeType": "Discrete", - "Values": [ - true, - false - ] - } - }, { "Name": "EarlyStoppingRound", "Type": "Int", @@ -12347,30 +12338,6 @@ "IsNullable": false, "Default": 0 }, - { - "Name": "CustomGains", - "Type": "String", - "Desc": "Comma separated list of gains associated to each relevance label.", - "Aliases": [ - "gains" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": "0,3,7,15,31,63,127,255,511,1023,2047,4095" - }, - { - "Name": "Sigmoid", - "Type": "Float", - "Desc": "Parameter for the sigmoid function. Used only in LightGbmBinaryTrainer, LightGbmMulticlassTrainer and in LightGbmRankingTrainer.", - "Aliases": [ - "sigmoid" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": 0.5 - }, { "Name": "BatchSize", "Type": "Int", @@ -12406,7 +12373,7 @@ "Required": false, "SortOrder": 150.0, "IsNullable": false, - "Default": false, + "Default": true, "SweepRange": { "RangeType": "Discrete", "Values": [ @@ -12468,7 +12435,7 @@ { "Name": "CategoricalSmoothing", "Type": "Float", - "Desc": "Laplace smooth term in categorical feature split. Avoid the bias of small categories.", + "Desc": "Lapalace smooth term in categorical feature spilt. 
Avoid the bias of small categories.", "Required": false, "SortOrder": 150.0, "IsNullable": false, @@ -12751,6 +12718,27 @@ "IsNullable": false, "Default": "Auto" }, + { + "Name": "EvaluationMetric", + "Type": { + "Kind": "Enum", + "Values": [ + "None", + "Default", + "MeanAbsoluteError", + "RootMeanSquaredError", + "MeanSquaredError" + ] + }, + "Desc": "Evaluation metrics.", + "Aliases": [ + "em" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": "RootMeanSquaredError" + }, { "Name": "MaximumBinCountPerFeature", "Type": "Int", @@ -12796,48 +12784,6 @@ "IsNullable": true, "Default": null }, - { - "Name": "EvaluationMetric", - "Type": { - "Kind": "Enum", - "Values": [ - "DefaultMetric", - "Rmse", - "Mae", - "Logloss", - "Error", - "Merror", - "Mlogloss", - "Auc", - "Ndcg", - "Map" - ] - }, - "Desc": "Evaluation metrics.", - "Aliases": [ - "em" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": "DefaultMetric" - }, - { - "Name": "UseSoftmax", - "Type": "Bool", - "Desc": "Use softmax loss for the multi classification.", - "Required": false, - "SortOrder": 150.0, - "IsNullable": true, - "Default": null, - "SweepRange": { - "RangeType": "Discrete", - "Values": [ - true, - false - ] - } - }, { "Name": "EarlyStoppingRound", "Type": "Int", @@ -12850,30 +12796,6 @@ "IsNullable": false, "Default": 0 }, - { - "Name": "CustomGains", - "Type": "String", - "Desc": "Comma separated list of gains associated to each relevance label.", - "Aliases": [ - "gains" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": "0,3,7,15,31,63,127,255,511,1023,2047,4095" - }, - { - "Name": "Sigmoid", - "Type": "Float", - "Desc": "Parameter for the sigmoid function. 
Used only in LightGbmBinaryTrainer, LightGbmMulticlassTrainer and in LightGbmRankingTrainer.", - "Aliases": [ - "sigmoid" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": 0.5 - }, { "Name": "BatchSize", "Type": "Int", @@ -12909,7 +12831,7 @@ "Required": false, "SortOrder": 150.0, "IsNullable": false, - "Default": false, + "Default": true, "SweepRange": { "RangeType": "Discrete", "Values": [ @@ -12971,7 +12893,7 @@ { "Name": "CategoricalSmoothing", "Type": "Float", - "Desc": "Laplace smooth term in categorical feature split. Avoid the bias of small categories.", + "Desc": "Lapalace smooth term in categorical feature spilt. Avoid the bias of small categories.", "Required": false, "SortOrder": 150.0, "IsNullable": false, @@ -23517,13 +23439,13 @@ "Components": [ { "Name": "dart", - "Desc": "Dropouts meet Multiple Additive Regression Trees. See https://arxiv.org/abs/1505.01866", + "Desc": "Dropouts meet Multiple Additive Regresion Trees. See https://arxiv.org/abs/1505.01866", "FriendlyName": "Tree Dropout Tree Booster", "Settings": [ { "Name": "TreeDropFraction", "Type": "Float", - "Desc": "The drop ratio for trees. Range:[0,1].", + "Desc": "The drop ratio for trees. Range:(0,1).", "Required": false, "SortOrder": 150.0, "IsNullable": false, @@ -23577,18 +23499,6 @@ "IsNullable": false, "Default": false }, - { - "Name": "UnbalancedSets", - "Type": "Bool", - "Desc": "Use for binary classification when training data is not balanced.", - "Aliases": [ - "us" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": false - }, { "Name": "MinimumSplitGain", "Type": "Float", @@ -23713,18 +23623,6 @@ 1.0 ] } - }, - { - "Name": "WeightOfPositiveExamples", - "Type": "Float", - "Desc": "Control the balance of positive and negative weights, useful for unbalanced classes. 
A typical value to consider: sum(negative cases) / sum(positive cases).", - "Aliases": [ - "ScalePosWeight" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": 1.0 } ] }, @@ -23733,18 +23631,6 @@ "Desc": "Traditional Gradient Boosting Decision Tree.", "FriendlyName": "Tree Booster", "Settings": [ - { - "Name": "UnbalancedSets", - "Type": "Bool", - "Desc": "Use for binary classification when training data is not balanced.", - "Aliases": [ - "us" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": false - }, { "Name": "MinimumSplitGain", "Type": "Float", @@ -23869,18 +23755,6 @@ 1.0 ] } - }, - { - "Name": "WeightOfPositiveExamples", - "Type": "Float", - "Desc": "Control the balance of positive and negative weights, useful for unbalanced classes. A typical value to consider: sum(negative cases) / sum(positive cases).", - "Aliases": [ - "ScalePosWeight" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": 1.0 } ] }, @@ -23915,18 +23789,6 @@ "Max": 1.0 } }, - { - "Name": "UnbalancedSets", - "Type": "Bool", - "Desc": "Use for binary classification when training data is not balanced.", - "Aliases": [ - "us" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": false - }, { "Name": "MinimumSplitGain", "Type": "Float", @@ -24051,18 +23913,6 @@ 1.0 ] } - }, - { - "Name": "WeightOfPositiveExamples", - "Type": "Float", - "Desc": "Control the balance of positive and negative weights, useful for unbalanced classes. 
A typical value to consider: sum(negative cases) / sum(positive cases).", - "Aliases": [ - "ScalePosWeight" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": 1.0 } ] } diff --git a/test/BaselineOutput/Common/LightGBMR/LightGBMReg-CV-generatedRegressionDataset-out.txt b/test/BaselineOutput/Common/LightGBMR/LightGBMReg-CV-generatedRegressionDataset-out.txt index 1fc6084997..c81802feb6 100644 --- a/test/BaselineOutput/Common/LightGBMR/LightGBMReg-CV-generatedRegressionDataset-out.txt +++ b/test/BaselineOutput/Common/LightGBMR/LightGBMReg-CV-generatedRegressionDataset-out.txt @@ -35,10 +35,10 @@ Virtual memory usage(MB): %Number% [1] 'Loading data for LightGBM' started. [1] 'Loading data for LightGBM' finished in %Time%. [2] 'Training with LightGBM' started. -[2] (%Time%) Iteration: 50 Training-l2: 37.107605006517 +[2] (%Time%) Iteration: 50 Training-rmse: 6.09160118577349 [2] 'Training with LightGBM' finished in %Time%. [3] 'Loading data for LightGBM #2' started. [3] 'Loading data for LightGBM #2' finished in %Time%. [4] 'Training with LightGBM #2' started. -[4] (%Time%) Iteration: 50 Training-l2: 27.7037679135951 +[4] (%Time%) Iteration: 50 Training-rmse: 5.26343689176522 [4] 'Training with LightGBM #2' finished in %Time%. diff --git a/test/BaselineOutput/Common/LightGBMR/LightGBMReg-TrainTest-generatedRegressionDataset-out.txt b/test/BaselineOutput/Common/LightGBMR/LightGBMReg-TrainTest-generatedRegressionDataset-out.txt index 909d9f0012..c88cad62b7 100644 --- a/test/BaselineOutput/Common/LightGBMR/LightGBMReg-TrainTest-generatedRegressionDataset-out.txt +++ b/test/BaselineOutput/Common/LightGBMR/LightGBMReg-TrainTest-generatedRegressionDataset-out.txt @@ -26,7 +26,7 @@ Virtual memory usage(MB): %Number% [1] 'Loading data for LightGBM' started. [1] 'Loading data for LightGBM' finished in %Time%. [2] 'Training with LightGBM' started. 
-[2] (%Time%) Iteration: 50 Training-l2: 26.0644295080124 +[2] (%Time%) Iteration: 50 Training-rmse: 5.10533343749577 [2] 'Training with LightGBM' finished in %Time%. [3] 'Saving model' started. [3] 'Saving model' finished in %Time%. diff --git a/test/BaselineOutput/Common/LightGBMR/LightGBMRegMae-CV-generatedRegressionDataset.MAE-out.txt b/test/BaselineOutput/Common/LightGBMR/LightGBMRegMae-CV-generatedRegressionDataset.MAE-out.txt index 4550a80d3c..62d41c06a5 100644 --- a/test/BaselineOutput/Common/LightGBMR/LightGBMRegMae-CV-generatedRegressionDataset.MAE-out.txt +++ b/test/BaselineOutput/Common/LightGBMR/LightGBMRegMae-CV-generatedRegressionDataset.MAE-out.txt @@ -1,4 +1,4 @@ -maml.exe CV tr=LightGBMR{nt=1 iter=50 em=mae v=+ lr=0.2 mil=10 nl=20} threads=- dout=%Output% loader=Text{col=Label:R4:11 col=Features:R4:0-10 sep=; header+} data=%Data% seed=1 +maml.exe CV tr=LightGBMR{nt=1 iter=50 em=MeanAbsoluteError v=+ lr=0.2 mil=10 nl=20} threads=- dout=%Output% loader=Text{col=Label:R4:11 col=Features:R4:0-10 sep=; header+} data=%Data% seed=1 Not adding a normalizer. Auto-tuning parameters: UseCategoricalSplit = False LightGBM objective=regression @@ -35,10 +35,10 @@ Virtual memory usage(MB): %Number% [1] 'Loading data for LightGBM' started. [1] 'Loading data for LightGBM' finished in %Time%. [2] 'Training with LightGBM' started. -[2] (%Time%) Iteration: 50 Training-l1: 2.72125878362794 +[2] (%Time%) Iteration: 50 Training-mae: 2.72125878362794 [2] 'Training with LightGBM' finished in %Time%. [3] 'Loading data for LightGBM #2' started. [3] 'Loading data for LightGBM #2' finished in %Time%. [4] 'Training with LightGBM #2' started. -[4] (%Time%) Iteration: 50 Training-l1: 2.24116204430926 +[4] (%Time%) Iteration: 50 Training-mae: 2.24116204430926 [4] 'Training with LightGBM #2' finished in %Time%. 
diff --git a/test/BaselineOutput/Common/LightGBMR/LightGBMRegMae-CV-generatedRegressionDataset.MAE-rp.txt b/test/BaselineOutput/Common/LightGBMR/LightGBMRegMae-CV-generatedRegressionDataset.MAE-rp.txt index d7359a5e25..8a563605cf 100644 --- a/test/BaselineOutput/Common/LightGBMR/LightGBMRegMae-CV-generatedRegressionDataset.MAE-rp.txt +++ b/test/BaselineOutput/Common/LightGBMR/LightGBMRegMae-CV-generatedRegressionDataset.MAE-rp.txt @@ -1,4 +1,4 @@ LightGBMR -L1(avg) L2(avg) RMS(avg) Loss-fn(avg) R Squared /iter /lr /nl /mil /v /nt /em Learner Name Train Dataset Test Dataset Results File Run Time Physical Memory Virtual Memory Command Line Settings -26.59978 1393.326 37.32081 1393.326 0.923402 50 0.2 20 10 + 1 Mae LightGBMR %Data% %Output% 99 0 0 maml.exe CV tr=LightGBMR{nt=1 iter=50 em=mae v=+ lr=0.2 mil=10 nl=20} threads=- dout=%Output% loader=Text{col=Label:R4:11 col=Features:R4:0-10 sep=; header+} data=%Data% seed=1 /iter:50;/lr:0.2;/nl:20;/mil:10;/v:+;/nt:1;/em:Mae +L1(avg) L2(avg) RMS(avg) Loss-fn(avg) R Squared /em /iter /lr /nl /mil /v /nt Learner Name Train Dataset Test Dataset Results File Run Time Physical Memory Virtual Memory Command Line Settings +26.59978 1393.326 37.32081 1393.326 0.923402 MeanAbsoluteError 50 0.2 20 10 + 1 LightGBMR %Data% %Output% 99 0 0 maml.exe CV tr=LightGBMR{nt=1 iter=50 em=MeanAbsoluteError v=+ lr=0.2 mil=10 nl=20} threads=- dout=%Output% loader=Text{col=Label:R4:11 col=Features:R4:0-10 sep=; header+} data=%Data% seed=1 /em:MeanAbsoluteError;/iter:50;/lr:0.2;/nl:20;/mil:10;/v:+;/nt:1 diff --git a/test/BaselineOutput/Common/LightGBMR/LightGBMRegMae-TrainTest-generatedRegressionDataset.MAE-out.txt b/test/BaselineOutput/Common/LightGBMR/LightGBMRegMae-TrainTest-generatedRegressionDataset.MAE-out.txt index 59d2ceaa05..11a1112241 100644 --- a/test/BaselineOutput/Common/LightGBMR/LightGBMRegMae-TrainTest-generatedRegressionDataset.MAE-out.txt +++ 
b/test/BaselineOutput/Common/LightGBMR/LightGBMRegMae-TrainTest-generatedRegressionDataset.MAE-out.txt @@ -1,4 +1,4 @@ -maml.exe TrainTest test=%Data% tr=LightGBMR{nt=1 iter=50 em=mae v=+ lr=0.2 mil=10 nl=20} dout=%Output% loader=Text{col=Label:R4:11 col=Features:R4:0-10 sep=; header+} data=%Data% out=%Output% seed=1 +maml.exe TrainTest test=%Data% tr=LightGBMR{nt=1 iter=50 em=MeanAbsoluteError v=+ lr=0.2 mil=10 nl=20} dout=%Output% loader=Text{col=Label:R4:11 col=Features:R4:0-10 sep=; header+} data=%Data% out=%Output% seed=1 Not adding a normalizer. Auto-tuning parameters: UseCategoricalSplit = False LightGBM objective=regression @@ -26,7 +26,7 @@ Virtual memory usage(MB): %Number% [1] 'Loading data for LightGBM' started. [1] 'Loading data for LightGBM' finished in %Time%. [2] 'Training with LightGBM' started. -[2] (%Time%) Iteration: 50 Training-l1: 3.42889585604196 +[2] (%Time%) Iteration: 50 Training-mae: 3.42889585604196 [2] 'Training with LightGBM' finished in %Time%. [3] 'Saving model' started. [3] 'Saving model' finished in %Time%. 
diff --git a/test/BaselineOutput/Common/LightGBMR/LightGBMRegMae-TrainTest-generatedRegressionDataset.MAE-rp.txt b/test/BaselineOutput/Common/LightGBMR/LightGBMRegMae-TrainTest-generatedRegressionDataset.MAE-rp.txt index 59ac27dc61..d96adb798e 100644 --- a/test/BaselineOutput/Common/LightGBMR/LightGBMRegMae-TrainTest-generatedRegressionDataset.MAE-rp.txt +++ b/test/BaselineOutput/Common/LightGBMR/LightGBMRegMae-TrainTest-generatedRegressionDataset.MAE-rp.txt @@ -1,4 +1,4 @@ LightGBMR -L1(avg) L2(avg) RMS(avg) Loss-fn(avg) R Squared /iter /lr /nl /mil /v /nt /em Learner Name Train Dataset Test Dataset Results File Run Time Physical Memory Virtual Memory Command Line Settings -3.428896 25.23601 5.023546 25.23601 0.998616 50 0.2 20 10 + 1 Mae LightGBMR %Data% %Data% %Output% 99 0 0 maml.exe TrainTest test=%Data% tr=LightGBMR{nt=1 iter=50 em=mae v=+ lr=0.2 mil=10 nl=20} dout=%Output% loader=Text{col=Label:R4:11 col=Features:R4:0-10 sep=; header+} data=%Data% out=%Output% seed=1 /iter:50;/lr:0.2;/nl:20;/mil:10;/v:+;/nt:1;/em:Mae +L1(avg) L2(avg) RMS(avg) Loss-fn(avg) R Squared /em /iter /lr /nl /mil /v /nt Learner Name Train Dataset Test Dataset Results File Run Time Physical Memory Virtual Memory Command Line Settings +3.428896 25.23601 5.023546 25.23601 0.998616 MeanAbsoluteError 50 0.2 20 10 + 1 LightGBMR %Data% %Data% %Output% 99 0 0 maml.exe TrainTest test=%Data% tr=LightGBMR{nt=1 iter=50 em=MeanAbsoluteError v=+ lr=0.2 mil=10 nl=20} dout=%Output% loader=Text{col=Label:R4:11 col=Features:R4:0-10 sep=; header+} data=%Data% out=%Output% seed=1 /em:MeanAbsoluteError;/iter:50;/lr:0.2;/nl:20;/mil:10;/v:+;/nt:1 diff --git a/test/BaselineOutput/Common/LightGBMR/LightGBMRegRmse-CV-generatedRegressionDataset.RMSE-out.txt b/test/BaselineOutput/Common/LightGBMR/LightGBMRegRmse-CV-generatedRegressionDataset.RMSE-out.txt index 71d131bb5a..31c4280d50 100644 --- a/test/BaselineOutput/Common/LightGBMR/LightGBMRegRmse-CV-generatedRegressionDataset.RMSE-out.txt +++ 
b/test/BaselineOutput/Common/LightGBMR/LightGBMRegRmse-CV-generatedRegressionDataset.RMSE-out.txt @@ -1,4 +1,4 @@ -maml.exe CV tr=LightGBMR{nt=1 iter=50 em=rmse v=+ lr=0.2 mil=10 nl=20} threads=- dout=%Output% loader=Text{col=Label:R4:11 col=Features:R4:0-10 sep=; header+} data=%Data% seed=1 +maml.exe CV tr=LightGBMR{nt=1 iter=50 em=RootMeanSquaredError v=+ lr=0.2 mil=10 nl=20} threads=- dout=%Output% loader=Text{col=Label:R4:11 col=Features:R4:0-10 sep=; header+} data=%Data% seed=1 Not adding a normalizer. Auto-tuning parameters: UseCategoricalSplit = False LightGBM objective=regression diff --git a/test/BaselineOutput/Common/LightGBMR/LightGBMRegRmse-CV-generatedRegressionDataset.RMSE-rp.txt b/test/BaselineOutput/Common/LightGBMR/LightGBMRegRmse-CV-generatedRegressionDataset.RMSE-rp.txt index 1b893d09dc..d855ef6c79 100644 --- a/test/BaselineOutput/Common/LightGBMR/LightGBMRegRmse-CV-generatedRegressionDataset.RMSE-rp.txt +++ b/test/BaselineOutput/Common/LightGBMR/LightGBMRegRmse-CV-generatedRegressionDataset.RMSE-rp.txt @@ -1,4 +1,4 @@ LightGBMR -L1(avg) L2(avg) RMS(avg) Loss-fn(avg) R Squared /iter /lr /nl /mil /v /nt /em Learner Name Train Dataset Test Dataset Results File Run Time Physical Memory Virtual Memory Command Line Settings -26.59978 1393.326 37.32081 1393.326 0.923402 50 0.2 20 10 + 1 Rmse LightGBMR %Data% %Output% 99 0 0 maml.exe CV tr=LightGBMR{nt=1 iter=50 em=rmse v=+ lr=0.2 mil=10 nl=20} threads=- dout=%Output% loader=Text{col=Label:R4:11 col=Features:R4:0-10 sep=; header+} data=%Data% seed=1 /iter:50;/lr:0.2;/nl:20;/mil:10;/v:+;/nt:1;/em:Rmse +L1(avg) L2(avg) RMS(avg) Loss-fn(avg) R Squared /iter /lr /nl /mil /v /nt Learner Name Train Dataset Test Dataset Results File Run Time Physical Memory Virtual Memory Command Line Settings +26.59978 1393.326 37.32081 1393.326 0.923402 50 0.2 20 10 + 1 LightGBMR %Data% %Output% 99 0 0 maml.exe CV tr=LightGBMR{nt=1 iter=50 em=RootMeanSquaredError v=+ lr=0.2 mil=10 nl=20} threads=- dout=%Output% 
loader=Text{col=Label:R4:11 col=Features:R4:0-10 sep=; header+} data=%Data% seed=1 /iter:50;/lr:0.2;/nl:20;/mil:10;/v:+;/nt:1 diff --git a/test/BaselineOutput/Common/LightGBMR/LightGBMRegRmse-TrainTest-generatedRegressionDataset.RMSE-out.txt b/test/BaselineOutput/Common/LightGBMR/LightGBMRegRmse-TrainTest-generatedRegressionDataset.RMSE-out.txt index c919475347..cd21c6e9d4 100644 --- a/test/BaselineOutput/Common/LightGBMR/LightGBMRegRmse-TrainTest-generatedRegressionDataset.RMSE-out.txt +++ b/test/BaselineOutput/Common/LightGBMR/LightGBMRegRmse-TrainTest-generatedRegressionDataset.RMSE-out.txt @@ -1,4 +1,4 @@ -maml.exe TrainTest test=%Data% tr=LightGBMR{nt=1 iter=50 em=rmse v=+ lr=0.2 mil=10 nl=20} dout=%Output% loader=Text{col=Label:R4:11 col=Features:R4:0-10 sep=; header+} data=%Data% out=%Output% seed=1 +maml.exe TrainTest test=%Data% tr=LightGBMR{nt=1 iter=50 em=RootMeanSquaredError v=+ lr=0.2 mil=10 nl=20} dout=%Output% loader=Text{col=Label:R4:11 col=Features:R4:0-10 sep=; header+} data=%Data% out=%Output% seed=1 Not adding a normalizer. 
Auto-tuning parameters: UseCategoricalSplit = False LightGBM objective=regression diff --git a/test/BaselineOutput/Common/LightGBMR/LightGBMRegRmse-TrainTest-generatedRegressionDataset.RMSE-rp.txt b/test/BaselineOutput/Common/LightGBMR/LightGBMRegRmse-TrainTest-generatedRegressionDataset.RMSE-rp.txt index b8f135e448..bf64ad5ed2 100644 --- a/test/BaselineOutput/Common/LightGBMR/LightGBMRegRmse-TrainTest-generatedRegressionDataset.RMSE-rp.txt +++ b/test/BaselineOutput/Common/LightGBMR/LightGBMRegRmse-TrainTest-generatedRegressionDataset.RMSE-rp.txt @@ -1,4 +1,4 @@ LightGBMR -L1(avg) L2(avg) RMS(avg) Loss-fn(avg) R Squared /iter /lr /nl /mil /v /nt /em Learner Name Train Dataset Test Dataset Results File Run Time Physical Memory Virtual Memory Command Line Settings -3.428896 25.23601 5.023546 25.23601 0.998616 50 0.2 20 10 + 1 Rmse LightGBMR %Data% %Data% %Output% 99 0 0 maml.exe TrainTest test=%Data% tr=LightGBMR{nt=1 iter=50 em=rmse v=+ lr=0.2 mil=10 nl=20} dout=%Output% loader=Text{col=Label:R4:11 col=Features:R4:0-10 sep=; header+} data=%Data% out=%Output% seed=1 /iter:50;/lr:0.2;/nl:20;/mil:10;/v:+;/nt:1;/em:Rmse +L1(avg) L2(avg) RMS(avg) Loss-fn(avg) R Squared /iter /lr /nl /mil /v /nt Learner Name Train Dataset Test Dataset Results File Run Time Physical Memory Virtual Memory Command Line Settings +3.428896 25.23601 5.023546 25.23601 0.998616 50 0.2 20 10 + 1 LightGBMR %Data% %Data% %Output% 99 0 0 maml.exe TrainTest test=%Data% tr=LightGBMR{nt=1 iter=50 em=RootMeanSquaredError v=+ lr=0.2 mil=10 nl=20} dout=%Output% loader=Text{col=Label:R4:11 col=Features:R4:0-10 sep=; header+} data=%Data% out=%Output% seed=1 /iter:50;/lr:0.2;/nl:20;/mil:10;/v:+;/nt:1 diff --git a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs index ce141dbafd..70346af906 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs @@ -768,7 +768,7 @@ public void 
TestMulticlassEnsembleCombiner() var predictors = new PredictorModel[] { - LightGbm.TrainMulticlass(Env, new Options + LightGbm.TrainMulticlass(Env, new LightGbmMulticlassTrainer.Options { FeatureColumnName = "Features", NumberOfIterations = 5, diff --git a/test/Microsoft.ML.TestFramework/Learners.cs b/test/Microsoft.ML.TestFramework/Learners.cs index f369264db3..876710da6d 100644 --- a/test/Microsoft.ML.TestFramework/Learners.cs +++ b/test/Microsoft.ML.TestFramework/Learners.cs @@ -333,14 +333,14 @@ static TestLearnersBase() public static PredictorAndArgs LightGBMRegMae = new PredictorAndArgs { - Trainer = new SubComponent("LightGBMR", "nt=1 iter=50 em=mae v=+ lr=0.2 mil=10 nl=20"), + Trainer = new SubComponent("LightGBMR", "nt=1 iter=50 em=MeanAbsoluteError v=+ lr=0.2 mil=10 nl=20"), Tag = "LightGBMRegMae", BaselineProgress = true, }; public static PredictorAndArgs LightGBMRegRmse = new PredictorAndArgs { - Trainer = new SubComponent("LightGBMR", "nt=1 iter=50 em=rmse v=+ lr=0.2 mil=10 nl=20"), + Trainer = new SubComponent("LightGBMR", "nt=1 iter=50 em=RootMeanSquaredError v=+ lr=0.2 mil=10 nl=20"), Tag = "LightGBMRegRmse", BaselineProgress = true, }; diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index ff5f07379b..bc22299fb4 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -677,7 +677,7 @@ private void ExecuteTFTransformMNISTConvTrainingTest(bool shuffle, int? 
shuffleS batchSize: 20)) .Append(mlContext.Transforms.Concatenate("Features", "Prediction")) .AppendCacheCheckpoint(mlContext) - .Append(mlContext.MulticlassClassification.Trainers.LightGbm(new Trainers.LightGbm.Options() + .Append(mlContext.MulticlassClassification.Trainers.LightGbm(new Trainers.LightGbm.LightGbmMulticlassTrainer.Options() { LabelColumnName = "Label", FeatureColumnName = "Features", diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs index d8ec78e585..beaec1a3a3 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs @@ -49,7 +49,7 @@ public void LightGBMBinaryEstimator() { var (pipe, dataView) = GetBinaryClassificationPipeline(); - var trainer = ML.BinaryClassification.Trainers.LightGbm(new Options + var trainer = ML.BinaryClassification.Trainers.LightGbm(new LightGbmBinaryTrainer.Options { NumberOfLeaves = 10, NumberOfThreads = 1, @@ -136,7 +136,7 @@ public void LightGBMRankerEstimator() { var (pipe, dataView) = GetRankingPipeline(); - var trainer = ML.Ranking.Trainers.LightGbm(new Options() { LabelColumnName = "Label0", FeatureColumnName = "NumericFeatures", RowGroupColumnName = "Group", LearningRate = 0.4 }); + var trainer = ML.Ranking.Trainers.LightGbm(new LightGbmRankingTrainer.Options() { LabelColumnName = "Label0", FeatureColumnName = "NumericFeatures", RowGroupColumnName = "Group", LearningRate = 0.4 }); var pipeWithTrainer = pipe.Append(trainer); TestEstimatorCore(pipeWithTrainer, dataView); @@ -162,13 +162,13 @@ public void FastTreeRegressorEstimator() } /// - /// LightGbmRegressorTrainer TrainerEstimator test + /// LightGbmRegressionTrainer TrainerEstimator test /// [LightGBMFact] public void LightGBMRegressorEstimator() { var dataView = GetRegressionPipeline(); - var trainer = ML.Regression.Trainers.LightGbm(new Options + var trainer = ML.Regression.Trainers.LightGbm(new 
LightGbmRegressionTrainer.Options { NumberOfThreads = 1, NormalizeFeatures = NormalizeOption.Warn, @@ -294,13 +294,14 @@ private void LightGbmHelper(bool useSoftmax, out string modelString, out List Date: Tue, 19 Mar 2019 12:19:49 -0700 Subject: [PATCH 04/18] Hiding of ColumnOptions (#2959) --- .../Dynamic/Normalizer.cs | 3 +- .../Dynamic/TensorFlow/TextClassification.cs | 5 +- .../ImageAnalytics/ConvertToGrayScale.cs | 4 +- .../ImageAnalytics/DnnFeaturizeImage.cs | 2 +- .../ImageAnalytics/ExtractPixels.cs | 2 +- .../Transforms/ImageAnalytics/LoadImages.cs | 2 +- .../Transforms/ImageAnalytics/ResizeImages.cs | 2 +- ...nOptions.cs => VectorWhitenWithOptions.cs} | 5 +- .../Transforms/ReplaceMissingValues.cs | 6 +- .../Dynamic/ValueMapping.cs | 24 +-- .../Dynamic/ValueMappingFloatToString.cs | 22 +-- .../Dynamic/ValueMappingStringToArray.cs | 22 +-- .../Dynamic/ValueMappingStringToKeyType.cs | 24 +-- .../ConversionsExtensionsCatalog.cs | 185 ++++++++++++++---- .../Transforms/ExtensionsCatalog.cs | 14 +- src/Microsoft.ML.Data/Transforms/Hashing.cs | 3 +- .../Transforms/KeyToVector.cs | 3 +- .../Transforms/Normalizer.cs | 23 ++- .../Transforms/TypeConverting.cs | 5 +- .../Transforms/ValueToKeyMappingEstimator.cs | 7 +- .../ExtensionsCatalog.cs | 52 ++++- .../ImagePixelExtractor.cs | 3 +- .../ImageResizer.cs | 3 +- .../VectorToImageTransform.cs | 3 +- .../MklComponentsCatalog.cs | 7 +- .../VectorWhitening.cs | 3 +- src/Microsoft.ML.PCA/PCACatalog.cs | 3 +- src/Microsoft.ML.PCA/PcaTransformer.cs | 3 +- .../TransformsStatic.cs | 16 +- .../CategoricalCatalog.cs | 39 ++-- .../ConversionsCatalog.cs | 3 +- .../CountFeatureSelection.cs | 3 +- .../ExtensionsCatalog.cs | 20 +- .../FeatureSelectionCatalog.cs | 6 +- .../HashJoiningTransform.cs | 3 +- src/Microsoft.ML.Transforms/KernelCatalog.cs | 12 +- .../MissingValueHandlingTransformer.cs | 4 +- .../MissingValueReplacing.cs | 59 +++--- .../NormalizerCatalog.cs | 12 +- src/Microsoft.ML.Transforms/OneHotEncoding.cs | 3 +- 
.../OneHotHashEncoding.cs | 3 +- .../Properties/AssemblyInfo.cs | 1 + .../RandomFourierFeaturizing.cs | 3 +- .../Text/LdaTransform.cs | 3 +- .../Text/NgramHashingTransformer.cs | 3 +- .../Text/NgramTransform.cs | 3 +- .../Text/StopWordsRemovingTransformer.cs | 3 +- .../Text/TextCatalog.cs | 94 ++++++--- .../Text/WordEmbeddingsExtractor.cs | 3 +- .../Text/WordTokenizing.cs | 3 +- .../Debugging.cs | 2 - .../TensorflowTests.cs | 3 +- .../Transformers/CategoricalTests.cs | 8 +- .../Transformers/NAReplaceTests.cs | 24 +-- .../Transformers/ValueMappingTests.cs | 32 +-- 55 files changed, 525 insertions(+), 283 deletions(-) rename docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/{VectorWhitenWithColumnOptions.cs => VectorWhitenWithOptions.cs} (90%) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs index c5421ac305..55f3c89845 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs @@ -58,7 +58,8 @@ public static void Example() // Composing a different pipeline if we wanted to normalize more than one column at a time. // Using log scale as the normalization mode. - var multiColPipeline = ml.Transforms.Normalize(NormalizingEstimator.NormalizationMode.LogMeanVariance, new ColumnOptions[] { ("LogInduced", "Induced"), ("LogSpontaneous", "Spontaneous") }); + var multiColPipeline = ml.Transforms.Normalize("LogInduced", "Induced", NormalizingEstimator.NormalizationMode.LogMeanVariance) + .Append(ml.Transforms.Normalize("LogSpontaneous", "Spontaneous", NormalizingEstimator.NormalizationMode.LogMeanVariance)); // The transformed data. 
var multiColtransformer = multiColPipeline.Fit(trainData); var multiColtransformedData = multiColtransformer.Transform(trainData); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/TextClassification.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/TextClassification.cs index 30310b755b..3fe3f169ed 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/TextClassification.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/TextClassification.cs @@ -69,10 +69,11 @@ public static void Example() }; var model = mlContext.Transforms.Text.TokenizeIntoWords("TokenizedWords", "Sentiment_Text") - .Append(mlContext.Transforms.Conversion.MapValue(lookupMap, "Words", "Ids", new ColumnOptions[] { ("VariableLenghtFeatures", "TokenizedWords") })) + .Append(mlContext.Transforms.Conversion.MapValue("VariableLenghtFeatures", lookupMap, + lookupMap.Schema["Words"], lookupMap.Schema["Ids"], "TokenizedWords")) .Append(mlContext.Transforms.CustomMapping(ResizeFeaturesAction, "Resize")) .Append(tensorFlowModel.ScoreTensorFlowModel(new[] { "Prediction/Softmax" }, new[] { "Features" })) - .Append(mlContext.Transforms.CopyColumns(("Prediction", "Prediction/Softmax"))) + .Append(mlContext.Transforms.CopyColumns("Prediction", "Prediction/Softmax")) .Fit(dataView); var engine = mlContext.Model.CreatePredictionEngine(model); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ConvertToGrayScale.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ConvertToGrayScale.cs index 1a1dcb7a54..d5cba7120b 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ConvertToGrayScale.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ConvertToGrayScale.cs @@ -36,8 +36,8 @@ public static void Example() var imagesFolder = Path.GetDirectoryName(imagesDataFile); // Image loading pipeline. 
- var pipeline = mlContext.Transforms.LoadImages(imagesFolder, ("ImageObject", "ImagePath")) - .Append(mlContext.Transforms.ConvertToGrayscale(("Grayscale", "ImageObject"))); + var pipeline = mlContext.Transforms.LoadImages(imagesFolder, "ImageObject", "ImagePath") + .Append(mlContext.Transforms.ConvertToGrayscale("Grayscale", "ImageObject")); var transformedData = pipeline.Fit(data).Transform(data); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/DnnFeaturizeImage.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/DnnFeaturizeImage.cs index 7f3e5d3c62..af69a3578c 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/DnnFeaturizeImage.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/DnnFeaturizeImage.cs @@ -40,7 +40,7 @@ public static void Example() // Installing the Microsoft.ML.DNNImageFeaturizer packages copies the models in the // `DnnImageModels` folder. // Image loading pipeline. 
- var pipeline = mlContext.Transforms.LoadImages(imagesFolder, ("ImageObject", "ImagePath")) + var pipeline = mlContext.Transforms.LoadImages(imagesFolder, "ImageObject", "ImagePath") .Append(mlContext.Transforms.ResizeImages("ImageObject", imageWidth: 224, imageHeight: 224)) .Append(mlContext.Transforms.ExtractPixels("Pixels", "ImageObject")) .Append(mlContext.Transforms.DnnFeaturizeImage("FeaturizedImage", m => m.ModelSelector.ResNet18(mlContext, m.OutputColumn, m.InputColumn), "Pixels")); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ExtractPixels.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ExtractPixels.cs index da6c583e13..188e36ca15 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ExtractPixels.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ExtractPixels.cs @@ -37,7 +37,7 @@ public static void Example() var imagesFolder = Path.GetDirectoryName(imagesDataFile); // Image loading pipeline. - var pipeline = mlContext.Transforms.LoadImages(imagesFolder, ("ImageObject", "ImagePath")) + var pipeline = mlContext.Transforms.LoadImages(imagesFolder, "ImageObject", "ImagePath") .Append(mlContext.Transforms.ResizeImages("ImageObject", imageWidth: 100, imageHeight: 100 )) .Append(mlContext.Transforms.ExtractPixels("Pixels", "ImageObject")); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/LoadImages.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/LoadImages.cs index 80404e3ae7..f6fb4cae29 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/LoadImages.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/LoadImages.cs @@ -36,7 +36,7 @@ public static void Example() var imagesFolder = Path.GetDirectoryName(imagesDataFile); // Image loading pipeline. 
- var pipeline = mlContext.Transforms.LoadImages(imagesFolder, ("ImageReal", "ImagePath")); + var pipeline = mlContext.Transforms.LoadImages(imagesFolder, "ImageReal", "ImagePath"); var transformedData = pipeline.Fit(data).Transform(data); // The transformedData IDataView contains the loaded images now diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ResizeImages.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ResizeImages.cs index b792aa9a8e..ca0d642e14 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ResizeImages.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ResizeImages.cs @@ -36,7 +36,7 @@ public static void Example() var imagesFolder = Path.GetDirectoryName(imagesDataFile); // Image loading pipeline. - var pipeline = mlContext.Transforms.LoadImages(imagesFolder, ("ImageReal", "ImagePath")) + var pipeline = mlContext.Transforms.LoadImages(imagesFolder, "ImageReal", "ImagePath") .Append(mlContext.Transforms.ResizeImages("ImageReal", imageWidth: 100, imageHeight: 100)); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhitenWithColumnOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhitenWithOptions.cs similarity index 90% rename from docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhitenWithColumnOptions.cs rename to docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhitenWithOptions.cs index eaec189bc4..bf314064e1 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhitenWithColumnOptions.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Projection/VectorWhitenWithOptions.cs @@ -5,7 +5,7 @@ namespace Microsoft.ML.Samples.Dynamic { - public sealed class VectorWhitenWithColumnOptions + public sealed class VectorWhitenWithOptions { /// This example requires installation of additional nuget 
package Microsoft.ML.Mkl.Components. public static void Example() @@ -39,8 +39,7 @@ public static void Example() // A pipeline to project Features column into white noise vector. - var whiteningPipeline = ml.Transforms.VectorWhiten(new Transforms.VectorWhiteningEstimator.ColumnOptions( - nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features), kind: Transforms.WhiteningKind.PrincipalComponentAnalysis, rank: 4)); + var whiteningPipeline = ml.Transforms.VectorWhiten(nameof(SamplesUtils.DatasetUtils.SampleVectorOfNumbersData.Features), kind: Transforms.WhiteningKind.PrincipalComponentAnalysis, rank: 4); // The transformed (projected) data. var transformedData = whiteningPipeline.Fit(trainData).Transform(trainData); // Getting the data of the newly created column, so we can preview it. diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ReplaceMissingValues.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ReplaceMissingValues.cs index 1bcc4ef5f5..de241dceda 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ReplaceMissingValues.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ReplaceMissingValues.cs @@ -2,7 +2,7 @@ using System.Collections.Generic; using System.Linq; using Microsoft.ML.Data; -using static Microsoft.ML.Transforms.MissingValueReplacingEstimator.ColumnOptions; +using Microsoft.ML.Transforms; namespace Microsoft.ML.Samples.Dynamic { @@ -25,7 +25,7 @@ public static void Example() var data = mlContext.Data.LoadFromEnumerable(samples); // ReplaceMissingValues is used to create a column where missing values are replaced according to the ReplacementMode. 
- var meanPipeline = mlContext.Transforms.ReplaceMissingValues("MissingReplaced", "Features", ReplacementMode.Mean); + var meanPipeline = mlContext.Transforms.ReplaceMissingValues("MissingReplaced", "Features", MissingValueReplacingEstimator.ReplacementMode.Mean); // Now we can transform the data and look at the output to confirm the behavior of the estimator. // This operation doesn't actually evaluate data until we read the data below. @@ -36,7 +36,7 @@ public static void Example() var meanRowEnumerable = mlContext.Data.CreateEnumerable(meanTransformedData, reuseRowObject: false); // ReplaceMissingValues is used to create a column where missing values are replaced according to the ReplacementMode. - var defaultPipeline = mlContext.Transforms.ReplaceMissingValues("MissingReplaced", "Features", ReplacementMode.DefaultValue); + var defaultPipeline = mlContext.Transforms.ReplaceMissingValues("MissingReplaced", "Features", MissingValueReplacingEstimator.ReplacementMode.DefaultValue); // Now we can transform the data and look at the output to confirm the behavior of the estimator. // This operation doesn't actually evaluate data until we read the data below. diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMapping.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMapping.cs index d84d69db6f..2df356760a 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMapping.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMapping.cs @@ -37,24 +37,14 @@ public static void Example() // 35.0 1.0 6-11yrs 1.0 3.0 32.0 5.0 ... // If the list of keys and values are known, they can be passed to the API. The ValueMappingEstimator can also get the mapping through an IDataView - // Creating a list of keys based on the Education values from the dataset. 
- var educationKeys = new List() - { - "0-5yrs", - "6-11yrs", - "12+yrs" - }; - - // Creating a list of associated values that will map respectively to each educationKey - var educationValues = new List() - { - "Undergraduate", - "Postgraduate", - "Postgraduate" - }; - + // Creating a list of key-value pairs based on the Education values from the dataset. + var educationMap = new Dictionary (); + educationMap["0-5yrs"] = "Undergraduate"; + educationMap["6-11yrs"] = "Postgraduate"; + educationMap["12+yrs"] = "Postgraduate"; + // Constructs the ValueMappingEstimator making the ML.net pipeline - var pipeline = mlContext.Transforms.Conversion.MapValue(educationKeys, educationValues, ("EducationCategory", "Education")); + var pipeline = mlContext.Transforms.Conversion.MapValue("EducationCategory", educationMap, "Education"); // Fits the ValueMappingEstimator and transforms the data converting the Education to EducationCategory. IDataView transformedData = pipeline.Fit(trainData).Transform(trainData); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMappingFloatToString.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMappingFloatToString.cs index c32d7efdd5..5cf34572ba 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMappingFloatToString.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMappingFloatToString.cs @@ -28,24 +28,14 @@ public static void Example() IDataView trainData = mlContext.Data.LoadFromEnumerable(data); // If the list of keys and values are known, they can be passed to the API. The ValueMappingEstimator can also get the mapping through an IDataView - // Creating a list of keys based on the induced value from the dataset - var temperatureKeys = new List() - { - 36.0f, - 35.0f, - 34.0f - }; - - // Creating a list of values, these strings will map accordingly to each key. 
- var classificationValues = new List() - { - "T1", - "T2", - "T3" - }; + // Creating a list of key-value pairs based on the induced value from the dataset + var temperatureMap = new Dictionary(); + temperatureMap[36.0f] = "T1"; + temperatureMap[35.0f] = "T2"; + temperatureMap[34.0f] = "T3"; // Constructs the ValueMappingEstimator making the ML.net pipeline - var pipeline = mlContext.Transforms.Conversion.MapValue(temperatureKeys, classificationValues, ("TemperatureCategory", "Temperature")); + var pipeline = mlContext.Transforms.Conversion.MapValue("TemperatureCategory", temperatureMap, "Temperature"); // Fits the ValueMappingEstimator and transforms the data adding the TemperatureCategory column. IDataView transformedData = pipeline.Fit(trainData).Transform(trainData); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMappingStringToArray.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMappingStringToArray.cs index cfafe4c336..f008d559d8 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMappingStringToArray.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMappingStringToArray.cs @@ -31,24 +31,14 @@ public static void Example() IDataView trainData = mlContext.Data.LoadFromEnumerable(data); // If the list of keys and values are known, they can be passed to the API. 
The ValueMappingEstimator can also get the mapping through an IDataView - // Creating a list of keys based on the Education values from the dataset - var educationKeys = new List() - { - "0-5yrs", - "6-11yrs", - "12+yrs" - }; - - // Sample list of associated array values - var educationValues = new List() - { - new int[] { 1,2,3 }, - new int[] { 5,6,7 }, - new int[] { 42,32,64 } - }; + // Creating a list of key-value pairs based on the Education values from the dataset + var educationMap = new Dictionary(); + educationMap["0-5yrs"] = new int[] { 1, 2, 3 }; + educationMap["6-11yrs"] = new int[] { 5, 6, 7 }; + educationMap["12+yrs"] = new int[] { 42, 32, 64 }; // Constructs the ValueMappingEstimator making the ML.net pipeline - var pipeline = mlContext.Transforms.Conversion.MapValue(educationKeys, educationValues, ("EducationFeature", "Education")); + var pipeline = mlContext.Transforms.Conversion.MapValue("EducationFeature", educationMap, "Education"); // Fits the ValueMappingEstimator and transforms the data adding the EducationFeature column. IDataView transformedData = pipeline.Fit(trainData).Transform(trainData); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMappingStringToKeyType.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMappingStringToKeyType.cs index 11cffba54c..8c01d35e78 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMappingStringToKeyType.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMappingStringToKeyType.cs @@ -34,28 +34,18 @@ public static void Example() IEnumerable data = SamplesUtils.DatasetUtils.GetInfertData(); IDataView trainData = mlContext.Data.LoadFromEnumerable(data); - // Creating a list of keys based on the Education values from the dataset + // Creating a list of key-value pairs based on the Education values from the dataset // These lists are created by hand for the demonstration, but the ValueMappingEstimator does take an IEnumerable. 
- var educationKeys = new List() - { - "0-5yrs", - "6-11yrs", - "12+yrs" - }; - - // Creating a list of values that are sample strings. These will be converted to KeyTypes - var educationValues = new List() - { - "Undergraduate", - "Postgraduate", - "Postgraduate" - }; + var educationMap = new Dictionary(); + educationMap["0-5yrs"] = "Undergraduate"; + educationMap["6-11yrs"] = "Postgraduate"; + educationMap["12+yrs"] = "Postgraduate"; // Generate the ValueMappingEstimator that will output KeyTypes even though our values are strings. // The KeyToValueMappingEstimator is added to provide a reverse lookup of the KeyType, converting the KeyType value back // to the original value. - var pipeline = mlContext.Transforms.Conversion.MapValue(educationKeys, educationValues, true, ("EducationKeyType", "Education")) - .Append(mlContext.Transforms.Conversion.MapKeyToValue(("EducationCategory", "EducationKeyType"))); + var pipeline = mlContext.Transforms.Conversion.MapValue("EducationKeyType", educationMap, "Education", true) + .Append(mlContext.Transforms.Conversion.MapKeyToValue("EducationCategory", "EducationKeyType")); // Fits the ValueMappingEstimator and transforms the data adding the EducationKeyType column. IDataView transformedData = pipeline.Fit(trainData).Transform(trainData); diff --git a/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs b/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs index 2aa61a1b14..80f24389fb 100644 --- a/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs +++ b/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System.Collections.Generic; +using System.Linq; using Microsoft.ML.Data; using Microsoft.ML.Transforms; @@ -36,7 +37,8 @@ public static HashingEstimator Hash(this TransformsCatalog.ConversionTransforms /// /// The conversion transform's catalog. 
/// Description of dataset columns and how to process them. - public static HashingEstimator Hash(this TransformsCatalog.ConversionTransforms catalog, params HashingEstimator.ColumnOptions[] columns) + [BestFriend] + internal static HashingEstimator Hash(this TransformsCatalog.ConversionTransforms catalog, params HashingEstimator.ColumnOptions[] columns) => new HashingEstimator(CatalogUtils.GetEnvironment(catalog), columns); /// @@ -54,14 +56,15 @@ public static HashingEstimator Hash(this TransformsCatalog.ConversionTransforms /// public static TypeConvertingEstimator ConvertType(this TransformsCatalog.ConversionTransforms catalog, string outputColumnName, string inputColumnName = null, DataKind outputKind = ConvertDefaults.DefaultOutputKind) - => new TypeConvertingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, outputKind); + => new TypeConvertingEstimator(CatalogUtils.GetEnvironment(catalog), new[] { new TypeConvertingEstimator.ColumnOptions(outputColumnName, outputKind, inputColumnName) }); /// /// Changes column type of the input column. /// /// The conversion transform's catalog. /// Description of dataset columns and how to process them. 
- public static TypeConvertingEstimator ConvertType(this TransformsCatalog.ConversionTransforms catalog, params TypeConvertingEstimator.ColumnOptions[] columns) + [BestFriend] + internal static TypeConvertingEstimator ConvertType(this TransformsCatalog.ConversionTransforms catalog, params TypeConvertingEstimator.ColumnOptions[] columns) => new TypeConvertingEstimator(CatalogUtils.GetEnvironment(catalog), columns); /// @@ -91,7 +94,8 @@ public static KeyToValueMappingEstimator MapKeyToValue(this TransformsCatalog.Co /// [!code-csharp[KeyToValueMappingEstimator](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMappingStringToKeyType.cs)] /// ]]> /// - public static KeyToValueMappingEstimator MapKeyToValue(this TransformsCatalog.ConversionTransforms catalog, params ColumnOptions[] columns) + [BestFriend] + internal static KeyToValueMappingEstimator MapKeyToValue(this TransformsCatalog.ConversionTransforms catalog, params ColumnOptions[] columns) => new KeyToValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), ColumnOptions.ConvertToValueTuples(columns)); /// @@ -99,7 +103,8 @@ public static KeyToValueMappingEstimator MapKeyToValue(this TransformsCatalog.Co /// /// The conversion transform's catalog. /// The input column to map back to vectors. - public static KeyToVectorMappingEstimator MapKeyToVector(this TransformsCatalog.ConversionTransforms catalog, + [BestFriend] + internal static KeyToVectorMappingEstimator MapKeyToVector(this TransformsCatalog.ConversionTransforms catalog, params KeyToVectorMappingEstimator.ColumnOptions[] columns) => new KeyToVectorMappingEstimator(CatalogUtils.GetEnvironment(catalog), columns); @@ -124,6 +129,10 @@ public static KeyToVectorMappingEstimator MapKeyToVector(this TransformsCatalog. /// Maximum number of keys to keep per column when auto-training. /// How items should be ordered when vectorized. If choosen they will be in the order encountered. 
/// If , items are sorted according to their default comparison, for example, text sorting will be case sensitive (for example, 'A' then 'Z' then 'a'). + /// Whether key value annotations should be text, regardless of the actual input type. + /// The data view containing the terms. If specified, this should be a single column data + /// view, and the key-values will be taken from that column. If unspecified, the key-values will be determined + /// from the input data upon fitting. /// /// /// new ValueToKeyMappingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, maximumNumberOfKeys, keyOrdinality); + ValueToKeyMappingEstimator.KeyOrdinality keyOrdinality = ValueToKeyMappingEstimator.Defaults.Ordinality, + bool addKeyValueAnnotationsAsText = ValueToKeyMappingEstimator.Defaults.AddKeyValueAnnotationsAsText, + IDataView keyData = null) + => new ValueToKeyMappingEstimator(CatalogUtils.GetEnvironment(catalog), + new[] { new ValueToKeyMappingEstimator.ColumnOptions(outputColumnName, inputColumnName, maximumNumberOfKeys, keyOrdinality, addKeyValueAnnotationsAsText) }, keyData); /// /// Converts value types into , optionally loading the keys to use from . @@ -153,7 +165,8 @@ public static ValueToKeyMappingEstimator MapValueToKey(this TransformsCatalog.Co /// ]]> /// /// - public static ValueToKeyMappingEstimator MapValueToKey(this TransformsCatalog.ConversionTransforms catalog, + [BestFriend] + internal static ValueToKeyMappingEstimator MapValueToKey(this TransformsCatalog.ConversionTransforms catalog, ValueToKeyMappingEstimator.ColumnOptions[] columns, IDataView keyData = null) => new ValueToKeyMappingEstimator(CatalogUtils.GetEnvironment(catalog), columns, keyData); @@ -163,10 +176,10 @@ public static ValueToKeyMappingEstimator MapValueToKey(this TransformsCatalog.Co /// The key type. /// The value type. /// The conversion transform's catalog - /// The list of keys to use for the mapping. The mapping is 1-1 with . 
The length of this list must be the same length as and - /// cannot contain duplicate keys. - /// The list of values to pair with the keys for the mapping. The length of this list must be equal to the same length as . - /// The columns to apply this transform on. + /// Name of the column resulting from the transformation of . + /// Specifies the mapping that will be perfomed. The keys will be mapped to the values as specified in the . + /// Name of the column to transform. If set to , the value of the will be used as source. + /// Whether to treat the values as a . /// An instance of the /// /// @@ -179,10 +192,45 @@ public static ValueToKeyMappingEstimator MapValueToKey(this TransformsCatalog.Co /// public static ValueMappingEstimator MapValue( this TransformsCatalog.ConversionTransforms catalog, - IEnumerable keys, - IEnumerable values, + string outputColumnName, + IEnumerable> keyValuePairs, + string inputColumnName = null, + bool treatValuesAsKeyType = false) + { + var keys = keyValuePairs.Select(pair => pair.Key); + var values = keyValuePairs.Select(pair => pair.Value); + return new ValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), keys, values, treatValuesAsKeyType, + new[] { (outputColumnName, inputColumnName ?? outputColumnName) }); + } + + /// + /// + /// + /// The key type. + /// The value type. + /// The conversion transform's catalog + /// Specifies the mapping that will be perfomed. The keys will be mapped to the values as specified in the . + /// The columns to apply this transform on. 
+ /// An instance of the + /// + /// + /// + /// + [BestFriend] + internal static ValueMappingEstimator MapValue( + this TransformsCatalog.ConversionTransforms catalog, + IEnumerable> keyValuePairs, params ColumnOptions[] columns) - => new ValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), keys, values, ColumnOptions.ConvertToValueTuples(columns)); + { + var keys = keyValuePairs.Select(pair => pair.Key); + var values = keyValuePairs.Select(pair => pair.Value); + return new ValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), keys, values, ColumnOptions.ConvertToValueTuples(columns)); + } /// /// @@ -190,9 +238,7 @@ public static ValueMappingEstimator MapValueThe key type. /// The value type. /// The conversion transform's catalog - /// The list of keys to use for the mapping. The mapping is 1-1 with . The length of this list must be the same length as and - /// cannot contain duplicate keys. - /// The list of values to pair with the keys for the mapping. The length of this list must be equal to the same length as . + /// Specifies the mapping that will be perfomed. The keys will be mapped to the values as specified in the . /// Whether to treat the values as a . /// The columns to apply this transform on. 
/// An instance of the @@ -202,14 +248,18 @@ public static ValueMappingEstimator MapValue /// - public static ValueMappingEstimator MapValue( + [BestFriend] + internal static ValueMappingEstimator MapValue( this TransformsCatalog.ConversionTransforms catalog, - IEnumerable keys, - IEnumerable values, + IEnumerable> keyValuePairs, bool treatValuesAsKeyType, params ColumnOptions[] columns) - => new ValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), keys, values, treatValuesAsKeyType, - ColumnOptions.ConvertToValueTuples(columns)); + { + var keys = keyValuePairs.Select(pair => pair.Key); + var values = keyValuePairs.Select(pair => pair.Value); + return new ValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), keys, values, treatValuesAsKeyType, + ColumnOptions.ConvertToValueTuples(columns)); + } /// /// @@ -217,10 +267,9 @@ public static ValueMappingEstimator MapValueThe key type. /// The value type. /// The conversion transform's catalog - /// The list of keys to use for the mapping. The mapping is 1-1 with . The length of this list must be the same length as and - /// cannot contain duplicate keys. - /// The list of values to pair with the keys for the mapping of TOutputType[]. The length of this list must be equal to the same length as . - /// The columns to apply this transform on. + /// Name of the column resulting from the transformation of . + /// Specifies the mapping that will be perfomed. The keys will be mapped to the values as specified in the . + /// Name of the column to transform. If set to , the value of the will be used as source. 
/// An instance of the /// /// @@ -233,20 +282,55 @@ public static ValueMappingEstimator MapValue public static ValueMappingEstimator MapValue( this TransformsCatalog.ConversionTransforms catalog, - IEnumerable keys, - IEnumerable values, - params ColumnOptions[] columns) - => new ValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), keys, values, - ColumnOptions.ConvertToValueTuples(columns)); + string outputColumnName, + IEnumerable> keyValuePairs, + string inputColumnName = null) + { + var keys = keyValuePairs.Select(pair => pair.Key); + var values = keyValuePairs.Select(pair => pair.Value); + return new ValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), keys, values, + new[] { (outputColumnName, inputColumnName ?? outputColumnName) }); + } /// /// /// + /// The key type. + /// The value type. /// The conversion transform's catalog - /// An instance of that contains the key and value columns. - /// Name of the key column in . - /// Name of the value column in . + /// Specifies the mapping that will be perfomed. The keys will be mapped to the values as specified in the . /// The columns to apply this transform on. + /// An instance of the + /// + /// + /// + /// + [BestFriend] + internal static ValueMappingEstimator MapValue( + this TransformsCatalog.ConversionTransforms catalog, + IEnumerable> keyValuePairs, + params ColumnOptions[] columns) + { + var keys = keyValuePairs.Select(pair => pair.Key); + var values = keyValuePairs.Select(pair => pair.Value); + return new ValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), keys, values, + ColumnOptions.ConvertToValueTuples(columns)); + } + + /// + /// + /// + /// The conversion transform's catalog + /// Name of the column resulting from the transformation of . + /// An instance of that contains the and columns. + /// The key column in . + /// The value column in . + /// Name of the column to transform. If set to , the value of the will be used as source. 
/// A instance of the ValueMappingEstimator /// /// @@ -259,8 +343,35 @@ public static ValueMappingEstimator MapValue public static ValueMappingEstimator MapValue( this TransformsCatalog.ConversionTransforms catalog, - IDataView lookupMap, string keyColumnName, string valueColumnName, params ColumnOptions[] columns) - => new ValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), lookupMap, keyColumnName, valueColumnName, + string outputColumnName, IDataView lookupMap, DataViewSchema.Column keyColumn, DataViewSchema.Column valueColumn, string inputColumnName = null) + { + return new ValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), lookupMap, keyColumn.Name, valueColumn.Name, + new[] { (outputColumnName, inputColumnName ?? outputColumnName) }); + } + + /// + /// + /// + /// The conversion transform's catalog + /// An instance of that contains the and columns. + /// The key column in . + /// The value column in . + /// The columns to apply this transform on. + /// A instance of the ValueMappingEstimator + /// + /// + /// + /// + [BestFriend] + internal static ValueMappingEstimator MapValue( + this TransformsCatalog.ConversionTransforms catalog, + IDataView lookupMap, DataViewSchema.Column keyColumn, DataViewSchema.Column valueColumn, params ColumnOptions[] columns) + => new ValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), lookupMap, keyColumn.Name, valueColumn.Name, ColumnOptions.ConvertToValueTuples(columns)); } } diff --git a/src/Microsoft.ML.Data/Transforms/ExtensionsCatalog.cs b/src/Microsoft.ML.Data/Transforms/ExtensionsCatalog.cs index 1972eb9bb2..c4b97d0dea 100644 --- a/src/Microsoft.ML.Data/Transforms/ExtensionsCatalog.cs +++ b/src/Microsoft.ML.Data/Transforms/ExtensionsCatalog.cs @@ -11,7 +11,8 @@ namespace Microsoft.ML /// /// Specifies input and output column names for a transformation. 
/// - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { private readonly string _outputColumnName; private readonly string _inputColumnName; @@ -19,12 +20,12 @@ public sealed class ColumnOptions /// /// Specifies input and output column names for a transformation. /// - /// Name of output column resulting from the transformation of . - /// Name of input column. - public ColumnOptions(string outputColumnName, string inputColumnName) + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. + public ColumnOptions(string outputColumnName, string inputColumnName = null) { _outputColumnName = outputColumnName; - _inputColumnName = inputColumnName; + _inputColumnName = inputColumnName ?? outputColumnName; } /// @@ -76,7 +77,8 @@ public static ColumnCopyingEstimator CopyColumns(this TransformsCatalog catalog, /// ]]> /// /// - public static ColumnCopyingEstimator CopyColumns(this TransformsCatalog catalog, params ColumnOptions[] columns) + [BestFriend] + internal static ColumnCopyingEstimator CopyColumns(this TransformsCatalog catalog, params ColumnOptions[] columns) => new ColumnCopyingEstimator(CatalogUtils.GetEnvironment(catalog), ColumnOptions.ConvertToValueTuples(columns)); /// diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs index 0a53c8d6fd..93a6207609 100644 --- a/src/Microsoft.ML.Data/Transforms/Hashing.cs +++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs @@ -1123,7 +1123,8 @@ internal static class Defaults /// /// Describes how the transformer handles one column pair. /// - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { /// /// Name of the column resulting from the transformation of . 
diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs index dfa8607160..18ff259b4b 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs @@ -735,7 +735,8 @@ internal static class Defaults /// /// Describes how the transformer handles one column pair. /// - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { /// Name of the column resulting from the transformation of . public readonly string Name; diff --git a/src/Microsoft.ML.Data/Transforms/Normalizer.cs b/src/Microsoft.ML.Data/Transforms/Normalizer.cs index 27d6700341..100a65609d 100644 --- a/src/Microsoft.ML.Data/Transforms/Normalizer.cs +++ b/src/Microsoft.ML.Data/Transforms/Normalizer.cs @@ -63,7 +63,8 @@ public enum NormalizationMode SupervisedBinning = 4 } - public abstract class ColumnOptionsBase + [BestFriend] + internal abstract class ColumnOptionsBase { public readonly string Name; public readonly string InputColumnName; @@ -102,7 +103,7 @@ internal static ColumnOptionsBase Create(string outputColumnName, string inputCo } } - public abstract class ControlZeroColumnOptionsBase : ColumnOptionsBase + internal abstract class ControlZeroColumnOptionsBase : ColumnOptionsBase { public readonly bool EnsureZeroUntouched; @@ -113,7 +114,8 @@ private protected ControlZeroColumnOptionsBase(string outputColumnName, string i } } - public sealed class MinMaxColumnOptions : ControlZeroColumnOptionsBase + [BestFriend] + internal sealed class MinMaxColumnOptions : ControlZeroColumnOptionsBase { public MinMaxColumnOptions(string outputColumnName, string inputColumnName = null, long maximumExampleCount = Defaults.MaximumExampleCount, bool ensureZeroUntouched = Defaults.EnsureZeroUntouched) : base(outputColumnName, inputColumnName ?? 
outputColumnName, maximumExampleCount, ensureZeroUntouched) @@ -124,7 +126,8 @@ internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, D => NormalizeTransform.MinMaxUtils.CreateBuilder(this, host, srcIndex, srcType, cursor); } - public sealed class MeanVarianceColumnOptions : ControlZeroColumnOptionsBase + [BestFriend] + internal sealed class MeanVarianceColumnOptions : ControlZeroColumnOptionsBase { public readonly bool UseCdf; @@ -139,7 +142,8 @@ internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, D => NormalizeTransform.MeanVarUtils.CreateBuilder(this, host, srcIndex, srcType, cursor); } - public sealed class LogMeanVarianceColumnOptions : ColumnOptionsBase + [BestFriend] + internal sealed class LogMeanVarianceColumnOptions : ColumnOptionsBase { public readonly bool UseCdf; @@ -154,7 +158,8 @@ internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, D => NormalizeTransform.LogMeanVarUtils.CreateBuilder(this, host, srcIndex, srcType, cursor); } - public sealed class BinningColumnOptions : ControlZeroColumnOptionsBase + [BestFriend] + internal sealed class BinningColumnOptions : ControlZeroColumnOptionsBase { public readonly int MaximumBinCount; @@ -169,7 +174,8 @@ internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, D => NormalizeTransform.BinUtils.CreateBuilder(this, host, srcIndex, srcType, cursor); } - public sealed class SupervisedBinningColumOptions : ControlZeroColumnOptionsBase + [BestFriend] + internal sealed class SupervisedBinningColumOptions : ControlZeroColumnOptionsBase { public readonly int MaximumBinCount; public readonly string LabelColumnName; @@ -308,7 +314,8 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(NormalizingTransformer).Assembly.FullName); } - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { public readonly string Name; public readonly string InputColumnName; diff --git 
a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs index 75df8a6475..95e08aa23d 100644 --- a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs +++ b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs @@ -172,7 +172,7 @@ private static VersionInfo GetVersionInfo() /// /// A collection of describing the settings of the transformation. /// - public IReadOnlyCollection Columns => _columns.AsReadOnly(); + internal IReadOnlyCollection Columns => _columns.AsReadOnly(); private readonly TypeConvertingEstimator.ColumnOptions[] _columns; @@ -526,7 +526,8 @@ internal sealed class Defaults /// /// Describes how the transformer handles one column pair. /// - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { /// /// Name of the column resulting from the transformation of . diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs index cf2946e1e5..4775f2fe1f 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs @@ -17,6 +17,7 @@ internal static class Defaults { public const int MaximumNumberOfKeys = 1000000; public const KeyOrdinality Ordinality = KeyOrdinality.ByOccurrence; + public const bool AddKeyValueAnnotationsAsText = false; } /// @@ -40,7 +41,8 @@ public enum KeyOrdinality : byte /// /// Describes how the transformer handles one column pair. /// - public abstract class ColumnOptionsBase + [BestFriend] + internal abstract class ColumnOptionsBase { public readonly string OutputColumnName; public readonly string InputColumnName; @@ -70,7 +72,8 @@ private protected ColumnOptionsBase(string outputColumnName, string inputColumnN /// /// Describes how the transformer handles one column pair. 
/// - public sealed class ColumnOptions : ColumnOptionsBase + [BestFriend] + internal sealed class ColumnOptions : ColumnOptionsBase { /// /// Describes how the transformer handles one column pair. diff --git a/src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs b/src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs index 6868e82c63..dfcac547fb 100644 --- a/src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs +++ b/src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs @@ -9,6 +9,19 @@ namespace Microsoft.ML { public static class ImageEstimatorsCatalog { + /// + /// The transform's catalog. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. + /// + /// + /// + /// + public static ImageGrayscalingEstimator ConvertToGrayscale(this TransformsCatalog catalog, string outputColumnName, string inputColumnName = null) + => new ImageGrayscalingEstimator(CatalogUtils.GetEnvironment(catalog), new[] { (outputColumnName, inputColumnName ?? outputColumnName) }); + /// /// The transform's catalog. /// Specifies the names of the input columns for the transformation, and their respective output column names. @@ -18,9 +31,34 @@ public static class ImageEstimatorsCatalog /// [!code-csharp[ConvertToGrayscale](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ConvertToGrayscale.cs)] /// ]]> /// - public static ImageGrayscalingEstimator ConvertToGrayscale(this TransformsCatalog catalog, params ColumnOptions[] columns) + [BestFriend] + internal static ImageGrayscalingEstimator ConvertToGrayscale(this TransformsCatalog catalog, params ColumnOptions[] columns) => new ImageGrayscalingEstimator(CatalogUtils.GetEnvironment(catalog), ColumnOptions.ConvertToValueTuples(columns)); + /// + /// Loads the images from the into memory. + /// + /// + /// The image get loaded in memory as a type. 
+ /// Loading is the first step of almost every pipeline that does image processing, and further analysis on images. + /// The images to load need to be in the formats supported by . + /// For end-to-end image processing pipelines, and scenarios in your applications, see the + /// examples in the machinelearning-samples github repository. + /// + /// + /// The transform's catalog. + /// Name of the column resulting from the transformation of . + /// The images folder. + /// Name of the column to transform. If set to , the value of the will be used as source. + /// + /// + /// + /// + public static ImageLoadingEstimator LoadImages(this TransformsCatalog catalog, string outputColumnName, string imageFolder, string inputColumnName = null) + => new ImageLoadingEstimator(CatalogUtils.GetEnvironment(catalog), imageFolder, new[] { (outputColumnName, inputColumnName ?? outputColumnName) }); + /// /// Loads the images from the into memory. /// @@ -41,7 +79,8 @@ public static ImageGrayscalingEstimator ConvertToGrayscale(this TransformsCatalo /// [!code-csharp[LoadImages](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/LoadImages.cs)] /// ]]> /// - public static ImageLoadingEstimator LoadImages(this TransformsCatalog catalog, string imageFolder, params ColumnOptions[] columns) + [BestFriend] + internal static ImageLoadingEstimator LoadImages(this TransformsCatalog catalog, string imageFolder, params ColumnOptions[] columns) => new ImageLoadingEstimator(CatalogUtils.GetEnvironment(catalog), imageFolder, ColumnOptions.ConvertToValueTuples(columns)); /// @@ -75,7 +114,8 @@ public static ImagePixelExtractingEstimator ExtractPixels(this TransformsCatalog /// /// The transform's catalog. /// The describing how the transform handles each image pixel extraction output input column pair. 
- public static ImagePixelExtractingEstimator ExtractPixels(this TransformsCatalog catalog, params ImagePixelExtractingEstimator.ColumnOptions[] columnOptions) + [BestFriend] + internal static ImagePixelExtractingEstimator ExtractPixels(this TransformsCatalog catalog, params ImagePixelExtractingEstimator.ColumnOptions[] columnOptions) => new ImagePixelExtractingEstimator(CatalogUtils.GetEnvironment(catalog), columnOptions); /// @@ -133,7 +173,8 @@ public static ImageResizingEstimator ResizeImages(this TransformsCatalog catalog /// [!code-csharp[ResizeImages](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ResizeImages.cs)] /// ]]> /// - public static ImageResizingEstimator ResizeImages(this TransformsCatalog catalog, params ImageResizingEstimator.ColumnOptions[] columnOptions) + [BestFriend] + internal static ImageResizingEstimator ResizeImages(this TransformsCatalog catalog, params ImageResizingEstimator.ColumnOptions[] columnOptions) => new ImageResizingEstimator(CatalogUtils.GetEnvironment(catalog), columnOptions); /// @@ -141,7 +182,8 @@ public static ImageResizingEstimator ResizeImages(this TransformsCatalog catalog /// /// The transform's catalog. /// The describing how the transform handles each vector to image conversion column pair. 
- public static VectorToImageConvertingEstimator ConvertToImage(this TransformsCatalog catalog, params VectorToImageConvertingEstimator.ColumnOptions[] columnOptions) + [BestFriend] + internal static VectorToImageConvertingEstimator ConvertToImage(this TransformsCatalog catalog, params VectorToImageConvertingEstimator.ColumnOptions[] columnOptions) => new VectorToImageConvertingEstimator(CatalogUtils.GetEnvironment(catalog), columnOptions); /// diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs index 2fa255bc1e..533918b336 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs @@ -562,7 +562,8 @@ internal static void GetOrder(ColorsOrder order, ColorBits colors, out int a, ou /// /// Describes how the transformer handles one image pixel extraction column pair. /// - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { /// Name of the column resulting from the transformation of . public readonly string Name; diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizer.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizer.cs index f73f541b50..01d897093a 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageResizer.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageResizer.cs @@ -460,7 +460,8 @@ public enum Anchor : byte /// /// Describes how the transformer handles one image resize column. 
/// - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { /// Name of the column resulting from the transformation of public readonly string Name; diff --git a/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs b/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs index 8ee8d6bc17..5e33c0176c 100644 --- a/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs @@ -449,7 +449,8 @@ internal static class Defaults /// /// Describes how the transformer handles one vector to image conversion column pair. /// - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { /// Name of the column resulting from the transformation of . public readonly string Name; diff --git a/src/Microsoft.ML.Mkl.Components/MklComponentsCatalog.cs b/src/Microsoft.ML.Mkl.Components/MklComponentsCatalog.cs index 7ba5c81031..10e144907b 100644 --- a/src/Microsoft.ML.Mkl.Components/MklComponentsCatalog.cs +++ b/src/Microsoft.ML.Mkl.Components/MklComponentsCatalog.cs @@ -142,6 +142,7 @@ public static SymbolicSgdTrainer SymbolicSgd( /// /// /// /// @@ -161,12 +162,12 @@ public static VectorWhiteningEstimator VectorWhiten(this TransformsCatalog catal /// /// /// /// /// - public static VectorWhiteningEstimator VectorWhiten(this TransformsCatalog catalog, params VectorWhiteningEstimator.ColumnOptions[] columns) + [BestFriend] + internal static VectorWhiteningEstimator VectorWhiten(this TransformsCatalog catalog, params VectorWhiteningEstimator.ColumnOptions[] columns) => new VectorWhiteningEstimator(CatalogUtils.GetEnvironment(catalog), columns); - } } diff --git a/src/Microsoft.ML.Mkl.Components/VectorWhitening.cs b/src/Microsoft.ML.Mkl.Components/VectorWhitening.cs index 862031bdbc..abafdfba4e 100644 --- a/src/Microsoft.ML.Mkl.Components/VectorWhitening.cs +++ b/src/Microsoft.ML.Mkl.Components/VectorWhitening.cs @@ -682,7 +682,8 @@ internal static class 
Defaults /// /// Describes how the transformer handles one column pair. /// - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { /// /// Name of the column resulting from the transformation of . diff --git a/src/Microsoft.ML.PCA/PCACatalog.cs b/src/Microsoft.ML.PCA/PCACatalog.cs index 2cf6f5e1bb..3c1aafe6f7 100644 --- a/src/Microsoft.ML.PCA/PCACatalog.cs +++ b/src/Microsoft.ML.PCA/PCACatalog.cs @@ -35,7 +35,8 @@ public static PrincipalComponentAnalyzer ProjectToPrincipalComponents(this Trans /// Initializes a new instance of . /// The transform's catalog. /// Input columns to apply PrincipalComponentAnalysis on. - public static PrincipalComponentAnalyzer ProjectToPrincipalComponents(this TransformsCatalog catalog, params PrincipalComponentAnalyzer.ColumnOptions[] columns) + [BestFriend] + internal static PrincipalComponentAnalyzer ProjectToPrincipalComponents(this TransformsCatalog catalog, params PrincipalComponentAnalyzer.ColumnOptions[] columns) => new PrincipalComponentAnalyzer(CatalogUtils.GetEnvironment(catalog), columns); /// diff --git a/src/Microsoft.ML.PCA/PcaTransformer.cs b/src/Microsoft.ML.PCA/PcaTransformer.cs index 57ebd98892..1860ebaa17 100644 --- a/src/Microsoft.ML.PCA/PcaTransformer.cs +++ b/src/Microsoft.ML.PCA/PcaTransformer.cs @@ -630,7 +630,8 @@ internal static class Defaults /// /// Describes how the transformer handles one column pair. /// - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { /// /// Name of the column resulting from the transformation of . 
diff --git a/src/Microsoft.ML.StaticPipe/TransformsStatic.cs b/src/Microsoft.ML.StaticPipe/TransformsStatic.cs index 607a3439fa..b2f9596cd0 100644 --- a/src/Microsoft.ML.StaticPipe/TransformsStatic.cs +++ b/src/Microsoft.ML.StaticPipe/TransformsStatic.cs @@ -731,9 +731,9 @@ public static class NAReplacerStaticExtensions private readonly struct Config { public readonly bool ImputeBySlot; - public readonly MissingValueReplacingEstimator.ColumnOptions.ReplacementMode ReplacementMode; + public readonly MissingValueReplacingEstimator.ReplacementMode ReplacementMode; - public Config(MissingValueReplacingEstimator.ColumnOptions.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode, + public Config(MissingValueReplacingEstimator.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.Mode, bool imputeBySlot = MissingValueReplacingEstimator.Defaults.ImputeBySlot) { ImputeBySlot = imputeBySlot; @@ -814,7 +814,7 @@ public override IEstimator Reconcile(IHostEnvironment env, /// /// Incoming data. /// How NaN should be replaced - public static Scalar ReplaceNaNValues(this Scalar input, MissingValueReplacingEstimator.ColumnOptions.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode) + public static Scalar ReplaceNaNValues(this Scalar input, MissingValueReplacingEstimator.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.Mode) { Contracts.CheckValue(input, nameof(input)); return new OutScalar(input, new Config(replacementMode, false)); @@ -825,7 +825,7 @@ public static Scalar ReplaceNaNValues(this Scalar input, MissingVa /// /// Incoming data. 
/// How NaN should be replaced - public static Scalar ReplaceNaNValues(this Scalar input, MissingValueReplacingEstimator.ColumnOptions.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode) + public static Scalar ReplaceNaNValues(this Scalar input, MissingValueReplacingEstimator.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.Mode) { Contracts.CheckValue(input, nameof(input)); return new OutScalar(input, new Config(replacementMode, false)); @@ -838,7 +838,7 @@ public static Scalar ReplaceNaNValues(this Scalar input, Missing /// If true, per-slot imputation of replacement is performed. /// Otherwise, replacement value is imputed for the entire vector column. This setting is ignored for scalars and variable vectors, /// where imputation is always for the entire column. - public static Vector ReplaceNaNValues(this Vector input, MissingValueReplacingEstimator.ColumnOptions.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode, bool imputeBySlot = MissingValueReplacingEstimator.Defaults.ImputeBySlot) + public static Vector ReplaceNaNValues(this Vector input, MissingValueReplacingEstimator.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.Mode, bool imputeBySlot = MissingValueReplacingEstimator.Defaults.ImputeBySlot) { Contracts.CheckValue(input, nameof(input)); return new OutVectorColumn(input, new Config(replacementMode, imputeBySlot)); @@ -852,7 +852,7 @@ public static Vector ReplaceNaNValues(this Vector input, MissingVa /// If true, per-slot imputation of replacement is performed. /// Otherwise, replacement value is imputed for the entire vector column. This setting is ignored for scalars and variable vectors, /// where imputation is always for the entire column. 
- public static Vector ReplaceNaNValues(this Vector input, MissingValueReplacingEstimator.ColumnOptions.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode, bool imputeBySlot = MissingValueReplacingEstimator.Defaults.ImputeBySlot) + public static Vector ReplaceNaNValues(this Vector input, MissingValueReplacingEstimator.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.Mode, bool imputeBySlot = MissingValueReplacingEstimator.Defaults.ImputeBySlot) { Contracts.CheckValue(input, nameof(input)); return new OutVectorColumn(input, new Config(replacementMode, imputeBySlot)); @@ -863,7 +863,7 @@ public static Vector ReplaceNaNValues(this Vector input, Missing /// /// Incoming data. /// How NaN should be replaced - public static VarVector ReplaceNaNValues(this VarVector input, MissingValueReplacingEstimator.ColumnOptions.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode) + public static VarVector ReplaceNaNValues(this VarVector input, MissingValueReplacingEstimator.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.Mode) { Contracts.CheckValue(input, nameof(input)); return new OutVarVectorColumn(input, new Config(replacementMode, false)); @@ -873,7 +873,7 @@ public static VarVector ReplaceNaNValues(this VarVector input, Mis /// /// Incoming data. 
/// How NaN should be replaced - public static VarVector ReplaceNaNValues(this VarVector input, MissingValueReplacingEstimator.ColumnOptions.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode) + public static VarVector ReplaceNaNValues(this VarVector input, MissingValueReplacingEstimator.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.Mode) { Contracts.CheckValue(input, nameof(input)); return new OutVarVectorColumn(input, new Config(replacementMode, false)); diff --git a/src/Microsoft.ML.Transforms/CategoricalCatalog.cs b/src/Microsoft.ML.Transforms/CategoricalCatalog.cs index 85175e5268..22ebfe7890 100644 --- a/src/Microsoft.ML.Transforms/CategoricalCatalog.cs +++ b/src/Microsoft.ML.Transforms/CategoricalCatalog.cs @@ -13,12 +13,17 @@ namespace Microsoft.ML public static class CategoricalCatalog { /// - /// Convert a text column into one-hot encoded vector. + /// Convert text columns into one-hot encoded vectors. /// /// The transform catalog /// Name of the column resulting from the transformation of . /// Name of column to transform. If set to , the value of the will be used as source. - /// The conversion mode. + /// Output kind: Bag (multi-set vector), Ind (indicator vector), Key (index), or Binary encoded indicator vector. + /// Maximum number of terms to keep per column when auto-training. + /// How items should be ordered when vectorized. If choosen they will be in the order encountered. + /// If , items are sorted according to their default comparison, for example, text sorting will be case sensitive (for example, 'A' then 'Z' then 'a'). + /// Specifies an ordering for the encoding. If specified, this should be a single column data view, + /// and the key-values will be taken from that column. If unspecified, the ordering will be determined from the input data upon fitting. 
/// /// /// new OneHotEncodingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, outputKind); + OneHotEncodingEstimator.OutputKind outputKind = OneHotEncodingEstimator.Defaults.OutKind, + int maximumNumberOfKeys = ValueToKeyMappingEstimator.Defaults.MaximumNumberOfKeys, + ValueToKeyMappingEstimator.KeyOrdinality keyOrdinality = ValueToKeyMappingEstimator.Defaults.Ordinality, + IDataView keyData = null) + => new OneHotEncodingEstimator(CatalogUtils.GetEnvironment(catalog), + new[] { new OneHotEncodingEstimator.ColumnOptions(outputColumnName, inputColumnName, outputKind, maximumNumberOfKeys, keyOrdinality) }, keyData); /// /// Convert several text column into one-hot encoded vectors. /// /// The transform catalog /// The column settings. - public static OneHotEncodingEstimator OneHotEncoding(this TransformsCatalog.CategoricalTransforms catalog, + [BestFriend] + internal static OneHotEncodingEstimator OneHotEncoding(this TransformsCatalog.CategoricalTransforms catalog, params OneHotEncodingEstimator.ColumnOptions[] columns) => new OneHotEncodingEstimator(CatalogUtils.GetEnvironment(catalog), columns); @@ -47,7 +57,8 @@ public static OneHotEncodingEstimator OneHotEncoding(this TransformsCatalog.Cate /// The column settings. /// Specifies an ordering for the encoding. If specified, this should be a single column data view, /// and the key-values will be taken from that column. If unspecified, the ordering will be determined from the input data upon fitting. 
- public static OneHotEncodingEstimator OneHotEncoding(this TransformsCatalog.CategoricalTransforms catalog, + [BestFriend] + internal static OneHotEncodingEstimator OneHotEncoding(this TransformsCatalog.CategoricalTransforms catalog, OneHotEncodingEstimator.ColumnOptions[] columns, IDataView keyData = null) => new OneHotEncodingEstimator(CatalogUtils.GetEnvironment(catalog), columns, keyData); @@ -58,26 +69,32 @@ public static OneHotEncodingEstimator OneHotEncoding(this TransformsCatalog.Cate /// The transform catalog /// Name of the column resulting from the transformation of . /// Name of column to transform. If set to , the value of the will be used as source. + /// The conversion mode. /// Number of bits to hash into. Must be between 1 and 30, inclusive. + /// Hashing seed. + /// Whether the position of each term should be included in the hash. /// During hashing we constuct mappings between original values and the produced hash values. /// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. - /// The conversion mode. public static OneHotHashEncodingEstimator OneHotHashEncoding(this TransformsCatalog.CategoricalTransforms catalog, string outputColumnName, string inputColumnName = null, + OneHotEncodingEstimator.OutputKind outputKind = OneHotEncodingEstimator.OutputKind.Indicator, int numberOfBits = OneHotHashEncodingEstimator.Defaults.NumberOfBits, - int maximumNumberOfInverts = OneHotHashEncodingEstimator.Defaults.MaximumNumberOfInverts, - OneHotEncodingEstimator.OutputKind outputKind = OneHotEncodingEstimator.OutputKind.Indicator) - => new OneHotHashEncodingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName ?? 
outputColumnName, numberOfBits, maximumNumberOfInverts, outputKind); + uint seed = OneHotHashEncodingEstimator.Defaults.Seed, + bool useOrderedHashing = OneHotHashEncodingEstimator.Defaults.UseOrderedHashing, + int maximumNumberOfInverts = OneHotHashEncodingEstimator.Defaults.MaximumNumberOfInverts) + => new OneHotHashEncodingEstimator(CatalogUtils.GetEnvironment(catalog), + new[] { new OneHotHashEncodingEstimator.ColumnOptions(outputColumnName, inputColumnName, outputKind, numberOfBits, seed, useOrderedHashing, maximumNumberOfInverts) }); /// /// Convert several text column into hash-based one-hot encoded vectors. /// /// The transform catalog /// The column settings. - public static OneHotHashEncodingEstimator OneHotHashEncoding(this TransformsCatalog.CategoricalTransforms catalog, + [BestFriend] + internal static OneHotHashEncodingEstimator OneHotHashEncoding(this TransformsCatalog.CategoricalTransforms catalog, params OneHotHashEncodingEstimator.ColumnOptions[] columns) => new OneHotHashEncodingEstimator(CatalogUtils.GetEnvironment(catalog), columns); } diff --git a/src/Microsoft.ML.Transforms/ConversionsCatalog.cs b/src/Microsoft.ML.Transforms/ConversionsCatalog.cs index 6b4de2fbba..406cef8d2d 100644 --- a/src/Microsoft.ML.Transforms/ConversionsCatalog.cs +++ b/src/Microsoft.ML.Transforms/ConversionsCatalog.cs @@ -18,7 +18,8 @@ public static class ConversionsCatalog /// /// The categorical transform's catalog. /// Specifies the output and input columns on which the transformation should be applied. 
- public static KeyToBinaryVectorMappingEstimator MapKeyToBinaryVector(this TransformsCatalog.ConversionTransforms catalog, + [BestFriend] + internal static KeyToBinaryVectorMappingEstimator MapKeyToBinaryVector(this TransformsCatalog.ConversionTransforms catalog, params ColumnOptions[] columns) => new KeyToBinaryVectorMappingEstimator(CatalogUtils.GetEnvironment(catalog), ColumnOptions.ConvertToValueTuples(columns)); diff --git a/src/Microsoft.ML.Transforms/CountFeatureSelection.cs b/src/Microsoft.ML.Transforms/CountFeatureSelection.cs index 197f9b9568..4623b04406 100644 --- a/src/Microsoft.ML.Transforms/CountFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/CountFeatureSelection.cs @@ -47,7 +47,8 @@ internal sealed class Options : TransformInputBase /// /// Describes how the transformer handles one column pair. /// - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { /// Name of the column resulting from the transformation of . public readonly string Name; diff --git a/src/Microsoft.ML.Transforms/ExtensionsCatalog.cs b/src/Microsoft.ML.Transforms/ExtensionsCatalog.cs index 2e283ad89f..55659fbcb9 100644 --- a/src/Microsoft.ML.Transforms/ExtensionsCatalog.cs +++ b/src/Microsoft.ML.Transforms/ExtensionsCatalog.cs @@ -15,7 +15,8 @@ public static class ExtensionsCatalog /// /// The transform extensions' catalog. /// The names of the input columns of the transformation and the corresponding names for the output columns. 
- public static MissingValueIndicatorEstimator IndicateMissingValues(this TransformsCatalog catalog, + [BestFriend] + internal static MissingValueIndicatorEstimator IndicateMissingValues(this TransformsCatalog catalog, params ColumnOptions[] columns) => new MissingValueIndicatorEstimator(CatalogUtils.GetEnvironment(catalog), ColumnOptions.ConvertToValueTuples(columns)); @@ -45,13 +46,16 @@ public static MissingValueIndicatorEstimator IndicateMissingValues(this Transfor /// (depending on whether the is given a value, or left to null) /// identical to the input column for everything but the missing values. The missing values of the input column, in this new column are replaced with /// one of the values specifid in the . The default for the is - /// . + /// . /// /// The transform extensions' catalog. /// Name of the column resulting from the transformation of . /// Name of column to transform. If set to , the value of the will be used as source. /// If not provided, the will be replaced with the results of the transforms. - /// The type of replacement to use as specified in + /// The type of replacement to use as specified in + /// If true, per-slot imputation of replacement is performed. + /// Otherwise, replacement value is imputed for the entire vector column. This setting is ignored for scalars and variable vectors, + /// where imputation is always for the entire column. 
/// /// /// new MissingValueReplacingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, replacementMode); + MissingValueReplacingEstimator.ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.Mode, + bool imputeBySlot = MissingValueReplacingEstimator.Defaults.ImputeBySlot) + => new MissingValueReplacingEstimator(CatalogUtils.GetEnvironment(catalog), new[] { new MissingValueReplacingEstimator.ColumnOptions(outputColumnName, inputColumnName, replacementMode, imputeBySlot) }); /// /// Creates a new output column, identical to the input column for everything but the missing values. - /// The missing values of the input column, in this new column are replaced with . + /// The missing values of the input column, in this new column are replaced with . /// /// The transform extensions' catalog. /// The name of the columns to use, and per-column transformation configuraiton. - public static MissingValueReplacingEstimator ReplaceMissingValues(this TransformsCatalog catalog, params MissingValueReplacingEstimator.ColumnOptions[] columns) + [BestFriend] + internal static MissingValueReplacingEstimator ReplaceMissingValues(this TransformsCatalog catalog, params MissingValueReplacingEstimator.ColumnOptions[] columns) => new MissingValueReplacingEstimator(CatalogUtils.GetEnvironment(catalog), columns); } } diff --git a/src/Microsoft.ML.Transforms/FeatureSelectionCatalog.cs b/src/Microsoft.ML.Transforms/FeatureSelectionCatalog.cs index 6e15b9fa0b..3c59b738a4 100644 --- a/src/Microsoft.ML.Transforms/FeatureSelectionCatalog.cs +++ b/src/Microsoft.ML.Transforms/FeatureSelectionCatalog.cs @@ -25,7 +25,8 @@ public static class FeatureSelectionCatalog /// ]]> /// /// - public static MutualInformationFeatureSelectingEstimator SelectFeaturesBasedOnMutualInformation(this TransformsCatalog.FeatureSelectionTransforms catalog, + [BestFriend] + internal static MutualInformationFeatureSelectingEstimator 
SelectFeaturesBasedOnMutualInformation(this TransformsCatalog.FeatureSelectionTransforms catalog, string labelColumnName = MutualInfoSelectDefaults.LabelColumn, int slotsInOutput = MutualInfoSelectDefaults.SlotsInOutput, int numberOfBins = MutualInfoSelectDefaults.NumBins, @@ -64,7 +65,8 @@ public static MutualInformationFeatureSelectingEstimator SelectFeaturesBasedOnMu /// ]]> /// /// - public static CountFeatureSelectingEstimator SelectFeaturesBasedOnCount(this TransformsCatalog.FeatureSelectionTransforms catalog, + [BestFriend] + internal static CountFeatureSelectingEstimator SelectFeaturesBasedOnCount(this TransformsCatalog.FeatureSelectionTransforms catalog, params CountFeatureSelectingEstimator.ColumnOptions[] columns) => new CountFeatureSelectingEstimator(CatalogUtils.GetEnvironment(catalog), columns); diff --git a/src/Microsoft.ML.Transforms/HashJoiningTransform.cs b/src/Microsoft.ML.Transforms/HashJoiningTransform.cs index f5b63d99e4..651a758c2c 100644 --- a/src/Microsoft.ML.Transforms/HashJoiningTransform.cs +++ b/src/Microsoft.ML.Transforms/HashJoiningTransform.cs @@ -105,7 +105,8 @@ internal bool TryUnparse(StringBuilder sb) } } - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { // Either VBuffer> or a single Key. // Note that if CustomSlotMap contains only one array, the output type of the transform will a single Key. diff --git a/src/Microsoft.ML.Transforms/KernelCatalog.cs b/src/Microsoft.ML.Transforms/KernelCatalog.cs index 52c2d2d072..e399038833 100644 --- a/src/Microsoft.ML.Transforms/KernelCatalog.cs +++ b/src/Microsoft.ML.Transforms/KernelCatalog.cs @@ -21,6 +21,8 @@ public static class KernelExpansionCatalog /// The number of random Fourier features to create. /// If , use both of cos and sin basis functions to create two features for every random Fourier frequency. /// Otherwise, only cos bases would be used. + /// Which fourier generator to use. 
+ /// The seed of the random number generator for generating the new features (if unspecified, the global random is used). /// /// /// new ApproximatedKernelMappingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, rank, useCosAndSinBases); + bool useCosAndSinBases = ApproximatedKernelMappingEstimator.Defaults.UseCosAndSinBases, + KernelBase generator = null, + int? seed = null) + => new ApproximatedKernelMappingEstimator(CatalogUtils.GetEnvironment(catalog), + new[] { new ApproximatedKernelMappingEstimator.ColumnOptions(outputColumnName, rank, useCosAndSinBases, inputColumnName, generator, seed) }); /// /// Takes columns filled with a vector of floats and maps its to a random low-dimensional feature space. /// /// The transform's catalog. /// The input columns to use for the transformation. - public static ApproximatedKernelMappingEstimator ApproximatedKernelMap(this TransformsCatalog catalog, params ApproximatedKernelMappingEstimator.ColumnOptions[] columns) + [BestFriend] + internal static ApproximatedKernelMappingEstimator ApproximatedKernelMap(this TransformsCatalog catalog, params ApproximatedKernelMappingEstimator.ColumnOptions[] columns) => new ApproximatedKernelMappingEstimator(CatalogUtils.GetEnvironment(catalog), columns); } } diff --git a/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs index 8dd9cc1733..820ebf0c2f 100644 --- a/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs @@ -153,7 +153,7 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa if (!addInd) { replaceCols.Add(new MissingValueReplacingEstimator.ColumnOptions(column.Name, column.Source, - (MissingValueReplacingEstimator.ColumnOptions.ReplacementMode)(column.Kind ?? options.ReplaceWith), column.ImputeBySlot ?? 
options.ImputeBySlot)); + (MissingValueReplacingEstimator.ReplacementMode)(column.Kind ?? options.ReplaceWith), column.ImputeBySlot ?? options.ImputeBySlot)); continue; } @@ -188,7 +188,7 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa // Add the NAReplaceTransform column. replaceCols.Add(new MissingValueReplacingEstimator.ColumnOptions(tmpReplacementColName, column.Source, - (MissingValueReplacingEstimator.ColumnOptions.ReplacementMode)(column.Kind ?? options.ReplaceWith), column.ImputeBySlot ?? options.ImputeBySlot)); + (MissingValueReplacingEstimator.ReplacementMode)(column.Kind ?? options.ReplaceWith), column.ImputeBySlot ?? options.ImputeBySlot)); // Add the ConcatTransform column. if (replaceType is VectorType) diff --git a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs index ea8341430e..6c10792562 100644 --- a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs +++ b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs @@ -117,7 +117,7 @@ internal sealed class Options : TransformInputBase public Column[] Columns; [Argument(ArgumentType.AtMostOnce, HelpText = "The replacement method to utilize", ShortName = "kind")] - public ReplacementKind ReplacementKind = (ReplacementKind)MissingValueReplacingEstimator.Defaults.ReplacementMode; + public ReplacementKind ReplacementKind = (ReplacementKind)MissingValueReplacingEstimator.Defaults.Mode; // Specifying by-slot imputation for vectors of unknown size will cause a warning, and the imputation will be global. [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to impute values by slot", ShortName = "slot")] @@ -441,7 +441,7 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa cols[i] = new MissingValueReplacingEstimator.ColumnOptions( item.Name, item.Source, - (MissingValueReplacingEstimator.ColumnOptions.ReplacementMode)(item.Kind ?? 
options.ReplacementKind), + (MissingValueReplacingEstimator.ReplacementMode)(item.Kind ?? options.ReplacementKind), item.Slot ?? options.ImputeBySlot, item.ReplacementString); }; @@ -890,41 +890,42 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src public sealed class MissingValueReplacingEstimator : IEstimator { + /// + /// The possible ways to replace missing values. + /// + public enum ReplacementMode : byte + { + /// + /// Replace with the default value of the column based on its type. For example, 'zero' for numeric and 'empty' for string/text columns. + /// + DefaultValue = 0, + /// + /// Replace with the mean value of the column. Supports only numeric/time span/ DateTime columns. + /// + Mean = 1, + /// + /// Replace with the minimum value of the column. Supports only numeric/time span/ DateTime columns. + /// + Minimum = 2, + /// + /// Replace with the maximum value of the column. Supports only numeric/time span/ DateTime columns. + /// + Maximum = 3, + } + [BestFriend] internal static class Defaults { - public const ColumnOptions.ReplacementMode ReplacementMode = ColumnOptions.ReplacementMode.DefaultValue; + public const ReplacementMode Mode = ReplacementMode.DefaultValue; public const bool ImputeBySlot = true; } /// /// Describes how the transformer handles one column pair. /// - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { - /// - /// The possible ways to replace missing values. - /// - public enum ReplacementMode : byte - { - /// - /// Replace with the default value of the column based on its type. For example, 'zero' for numeric and 'empty' for string/text columns. - /// - DefaultValue = 0, - /// - /// Replace with the mean value of the column. Supports only numeric/time span/ DateTime columns. - /// - Mean = 1, - /// - /// Replace with the minimum value of the column. Supports only numeric/time span/ DateTime columns. 
- /// - Minimum = 2, - /// - /// Replace with the maximum value of the column. Supports only numeric/time span/ DateTime columns. - /// - Maximum = 3, - } - /// Name of the column resulting from the transformation of . public readonly string Name; /// Name of column to transform. @@ -949,7 +950,7 @@ public enum ReplacementMode : byte /// If true, per-slot imputation of replacement is performed. /// Otherwise, replacement value is imputed for the entire vector column. This setting is ignored for scalars and variable vectors, /// where imputation is always for the entire column. - public ColumnOptions(string name, string inputColumnName = null, ReplacementMode replacementMode = Defaults.ReplacementMode, + public ColumnOptions(string name, string inputColumnName = null, ReplacementMode replacementMode = Defaults.Mode, bool imputeBySlot = Defaults.ImputeBySlot) { Contracts.CheckNonWhiteSpace(name, nameof(name)); @@ -973,7 +974,7 @@ internal ColumnOptions(string name, string inputColumnName, ReplacementMode repl private readonly IHost _host; private readonly ColumnOptions[] _columns; - internal MissingValueReplacingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, ColumnOptions.ReplacementMode replacementKind = Defaults.ReplacementMode) + internal MissingValueReplacingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, ReplacementMode replacementKind = Defaults.Mode) : this(env, new ColumnOptions(outputColumnName, inputColumnName ?? 
outputColumnName, replacementKind)) { diff --git a/src/Microsoft.ML.Transforms/NormalizerCatalog.cs b/src/Microsoft.ML.Transforms/NormalizerCatalog.cs index 04020355be..657e58bf32 100644 --- a/src/Microsoft.ML.Transforms/NormalizerCatalog.cs +++ b/src/Microsoft.ML.Transforms/NormalizerCatalog.cs @@ -40,7 +40,8 @@ public static NormalizingEstimator Normalize(this TransformsCatalog catalog, /// ]]> /// /// - public static NormalizingEstimator Normalize(this TransformsCatalog catalog, + [BestFriend] + internal static NormalizingEstimator Normalize(this TransformsCatalog catalog, NormalizingEstimator.NormalizationMode mode, params ColumnOptions[] columns) => new NormalizingEstimator(CatalogUtils.GetEnvironment(catalog), mode, ColumnOptions.ConvertToValueTuples(columns)); @@ -50,7 +51,8 @@ public static NormalizingEstimator Normalize(this TransformsCatalog catalog, /// /// The transform catalog /// The normalization settings for all the columns - public static NormalizingEstimator Normalize(this TransformsCatalog catalog, + [BestFriend] + internal static NormalizingEstimator Normalize(this TransformsCatalog catalog, params NormalizingEstimator.ColumnOptionsBase[] columns) => new NormalizingEstimator(CatalogUtils.GetEnvironment(catalog), columns); @@ -79,7 +81,8 @@ public static LpNormNormalizingEstimator NormalizeLpNorm(this TransformsCatalog /// /// The transform's catalog. /// Describes the parameters of the lp-normalization process for each column pair. 
- public static LpNormNormalizingEstimator NormalizeLpNorm(this TransformsCatalog catalog, params LpNormNormalizingEstimator.ColumnOptions[] columns) + [BestFriend] + internal static LpNormNormalizingEstimator NormalizeLpNorm(this TransformsCatalog catalog, params LpNormNormalizingEstimator.ColumnOptions[] columns) => new LpNormNormalizingEstimator(CatalogUtils.GetEnvironment(catalog), columns); /// @@ -110,7 +113,8 @@ public static GlobalContrastNormalizingEstimator NormalizeGlobalContrast(this Tr /// /// The transform's catalog. /// Describes the parameters of the gcn-normaliztion process for each column pair. - public static GlobalContrastNormalizingEstimator NormalizeGlobalContrast(this TransformsCatalog catalog, params GlobalContrastNormalizingEstimator.ColumnOptions[] columns) + [BestFriend] + internal static GlobalContrastNormalizingEstimator NormalizeGlobalContrast(this TransformsCatalog catalog, params GlobalContrastNormalizingEstimator.ColumnOptions[] columns) => new GlobalContrastNormalizingEstimator(CatalogUtils.GetEnvironment(catalog), columns); } } diff --git a/src/Microsoft.ML.Transforms/OneHotEncoding.cs b/src/Microsoft.ML.Transforms/OneHotEncoding.cs index a79294c613..a7b6373c51 100644 --- a/src/Microsoft.ML.Transforms/OneHotEncoding.cs +++ b/src/Microsoft.ML.Transforms/OneHotEncoding.cs @@ -181,7 +181,8 @@ public enum OutputKind : byte /// /// Describes how the transformer handles one column pair. 
/// - public sealed class ColumnOptions : ValueToKeyMappingEstimator.ColumnOptionsBase + [BestFriend] + internal sealed class ColumnOptions : ValueToKeyMappingEstimator.ColumnOptionsBase { public readonly OutputKind OutputKind; /// diff --git a/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs b/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs index 8a543619dc..957433a8d9 100644 --- a/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs +++ b/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs @@ -216,7 +216,8 @@ internal static class Defaults /// /// Describes how the transformer handles one column pair. /// - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { public readonly HashingEstimator.ColumnOptions HashingOptions; public readonly OneHotEncodingEstimator.OutputKind OutputKind; diff --git a/src/Microsoft.ML.Transforms/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Transforms/Properties/AssemblyInfo.cs index b9317da64c..3ef4273ed9 100644 --- a/src/Microsoft.ML.Transforms/Properties/AssemblyInfo.cs +++ b/src/Microsoft.ML.Transforms/Properties/AssemblyInfo.cs @@ -17,5 +17,6 @@ [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.RServerScoring.TextAnalytics" + InternalPublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TestFramework" + PublicKey.TestValue)] [assembly: WantsToBeBestFriends] diff --git a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs index 0d83665562..d4f80b26e4 100644 --- a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs +++ b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs @@ -617,7 +617,8 @@ internal static class Defaults /// /// Describes how the transformer handles one Gcn column pair. 
/// - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { /// /// Name of the column resulting from the transformation of . diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 66ce9278e2..81b581b12e 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -1004,7 +1004,8 @@ internal LatentDirichletAllocationEstimator(IHostEnvironment env, params ColumnO /// /// Describes how the transformer handles one column pair. /// - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { /// /// Name of the column resulting from the transformation of . diff --git a/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs b/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs index a3c7c3194d..d8945fdf3d 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs @@ -872,7 +872,8 @@ public sealed class NgramHashingEstimator : IEstimator /// /// Describes how the transformer handles one pair of mulitple inputs - singular output columns. /// - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { /// Name of the column resulting from the transformation of . public readonly string Name; diff --git a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs index 0f22e279bd..44b3c98bb3 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs @@ -793,7 +793,8 @@ internal static bool IsSchemaColumnValid(SchemaShape.Column col) /// /// Describes how the transformer handles one column pair. /// - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { /// Name of the column resulting from the transformation of . 
public readonly string Name; diff --git a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs index 22c0ab8b14..415b18077b 100644 --- a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs @@ -508,7 +508,8 @@ public Options() /// /// Describes how the transformer handles one column pair. /// - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { /// Name of the column resulting from the transformation of . public readonly string Name; diff --git a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs index ef56dbf065..4aa28da763 100644 --- a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs +++ b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs @@ -72,7 +72,8 @@ public static TokenizingByCharactersEstimator TokenizeIntoCharactersAsKeys(this /// and append another marker character, , to the end of the output vector of characters. /// Pairs of columns to run the tokenization on. - public static TokenizingByCharactersEstimator TokenizeIntoCharactersAsKeys(this TransformsCatalog.TextTransforms catalog, + [BestFriend] + internal static TokenizingByCharactersEstimator TokenizeIntoCharactersAsKeys(this TransformsCatalog.TextTransforms catalog, bool useMarkerCharacters = CharTokenizingDefaults.UseMarkerCharacters, params ColumnOptions[] columns) => new TokenizingByCharactersEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), useMarkerCharacters, ColumnOptions.ConvertToValueTuples(columns)); @@ -118,8 +119,8 @@ public static WordEmbeddingEstimator ApplyWordEmbedding(this TransformsCatalog.T /// /// The text-related transform's catalog. - /// Name of the column resulting from the transformation of . /// The path of the pre-trained embeedings model to use. + /// Name of the column resulting from the transformation of . 
/// Name of the column to transform. /// /// @@ -146,7 +147,8 @@ public static WordEmbeddingEstimator ApplyWordEmbedding(this TransformsCatalog.T /// ]]> /// /// - public static WordEmbeddingEstimator ApplyWordEmbedding(this TransformsCatalog.TextTransforms catalog, + [BestFriend] + internal static WordEmbeddingEstimator ApplyWordEmbedding(this TransformsCatalog.TextTransforms catalog, WordEmbeddingEstimator.PretrainedModelKind modelKind = WordEmbeddingEstimator.PretrainedModelKind.SentimentSpecificWordEmbedding, params WordEmbeddingEstimator.ColumnOptions[] columns) => new WordEmbeddingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), modelKind, columns); @@ -170,7 +172,8 @@ public static WordTokenizingEstimator TokenizeIntoWords(this TransformsCatalog.T /// /// The text-related transform's catalog. /// Pairs of columns to run the tokenization on. - public static WordTokenizingEstimator TokenizeIntoWords(this TransformsCatalog.TextTransforms catalog, + [BestFriend] + internal static WordTokenizingEstimator TokenizeIntoWords(this TransformsCatalog.TextTransforms catalog, params WordTokenizingEstimator.ColumnOptions[] columns) => new WordTokenizingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), columns); @@ -210,7 +213,8 @@ public static NgramExtractingEstimator ProduceNgrams(this TransformsCatalog.Text /// /// The text-related transform's catalog. /// Pairs of columns to run the ngram process on. - public static NgramExtractingEstimator ProduceNgrams(this TransformsCatalog.TextTransforms catalog, + [BestFriend] + internal static NgramExtractingEstimator ProduceNgrams(this TransformsCatalog.TextTransforms catalog, params NgramExtractingEstimator.ColumnOptions[] columns) => new NgramExtractingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), columns); @@ -384,6 +388,7 @@ public static WordHashBagEstimator ProduceHashedWordBags(this TransformsCatalog. 
/// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. + /// Whether to rehash unigrams. public static NgramHashingEstimator ProduceHashedNgrams(this TransformsCatalog.TextTransforms catalog, string outputColumnName, string inputColumnName = null, @@ -393,10 +398,47 @@ public static NgramHashingEstimator ProduceHashedNgrams(this TransformsCatalog.T bool useAllLengths = NgramHashingEstimator.Defaults.UseAllLengths, uint seed = NgramHashingEstimator.Defaults.Seed, bool useOrderedHashing = NgramHashingEstimator.Defaults.UseOrderedHashing, - int maximumNumberOfInverts = NgramHashingEstimator.Defaults.MaximumNumberOfInverts) + int maximumNumberOfInverts = NgramHashingEstimator.Defaults.MaximumNumberOfInverts, + bool rehashUnigrams = NgramHashingEstimator.Defaults.RehashUnigrams) + => new NgramHashingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), + new[] {new NgramHashingEstimator.ColumnOptions(outputColumnName, new[] { inputColumnName }, ngramLength: ngramLength, skipLength: skipLength, + useAllLengths: useAllLengths, numberOfBits: numberOfBits, seed: seed, useOrderedHashing: useOrderedHashing, maximumNumberOfInverts: maximumNumberOfInverts, rehashUnigrams) }); + + /// + /// Produces a bag of counts of hashed ngrams in + /// and outputs ngram vector as + /// + /// is different from in a way that + /// takes tokenized text as input while tokenizes text internally. + /// + /// The text-related transform's catalog. + /// Name of the column resulting from the transformation of . + /// Names of the columns to transform. If set to , the value of the will be used as source. + /// Number of bits to hash into. Must be between 1 and 30, inclusive. 
+ /// Ngram length. + /// Maximum number of tokens to skip when constructing an ngram. + /// Whether to include all ngram lengths up to or only . + /// Hashing seed. + /// Whether the position of each source column should be included in the hash (when there are multiple source columns). + /// During hashing we constuct mappings between original values and the produced hash values. + /// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. + /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. + /// 0 does not retain any input values. -1 retains all input values mapping to each hash. + /// Whether to rehash unigrams. + public static NgramHashingEstimator ProduceHashedNgrams(this TransformsCatalog.TextTransforms catalog, + string outputColumnName, + string[] inputColumnNames = null, + int numberOfBits = NgramHashingEstimator.Defaults.NumberOfBits, + int ngramLength = NgramHashingEstimator.Defaults.NgramLength, + int skipLength = NgramHashingEstimator.Defaults.SkipLength, + bool useAllLengths = NgramHashingEstimator.Defaults.UseAllLengths, + uint seed = NgramHashingEstimator.Defaults.Seed, + bool useOrderedHashing = NgramHashingEstimator.Defaults.UseOrderedHashing, + int maximumNumberOfInverts = NgramHashingEstimator.Defaults.MaximumNumberOfInverts, + bool rehashUnigrams = NgramHashingEstimator.Defaults.RehashUnigrams) => new NgramHashingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), - outputColumnName, inputColumnName, numberOfBits: numberOfBits, ngramLength: ngramLength, skipLength: skipLength, - useAllLengths: useAllLengths, seed: seed, useOrderedHashing: useOrderedHashing, maximumNumberOfInverts: maximumNumberOfInverts); + new[] {new NgramHashingEstimator.ColumnOptions(outputColumnName, inputColumnNames, ngramLength: ngramLength, skipLength: skipLength, + useAllLengths: 
useAllLengths, numberOfBits: numberOfBits, seed: seed, useOrderedHashing: useOrderedHashing, maximumNumberOfInverts: maximumNumberOfInverts, rehashUnigrams) }); /// /// Produces a bag of counts of hashed ngrams for each . For each column, @@ -407,7 +449,8 @@ public static NgramHashingEstimator ProduceHashedNgrams(this TransformsCatalog.T /// /// The text-related transform's catalog. /// Pairs of columns to compute n-grams. Note that gram indices are generated by hashing. - public static NgramHashingEstimator ProduceHashedNgrams(this TransformsCatalog.TextTransforms catalog, + [BestFriend] + internal static NgramHashingEstimator ProduceHashedNgrams(this TransformsCatalog.TextTransforms catalog, NgramHashingEstimator.ColumnOptions[] columns) => new NgramHashingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), columns); @@ -419,9 +462,16 @@ public static NgramHashingEstimator ProduceHashedNgrams(this TransformsCatalog.T /// Name of the column resulting from the transformation of . /// Name of the column to transform. If set to , the value of the will be used as source. /// The number of topics. + /// Dirichlet prior on document-topic vectors. + /// Dirichlet prior on vocab-topic vectors. + /// Number of Metropolis Hasting step. /// Number of iterations. + /// Compute log likelihood over local dataset on this iteration interval. + /// The number of training threads. Default value depends on number of logical processors. /// The threshold of maximum count of tokens per doc. /// The number of words to summarize the topic. + /// The number of burn-in iterations. + /// Reset the random number generator for each document. 
/// /// /// new LatentDirichletAllocationEstimator(CatalogUtils.GetEnvironment(catalog), - outputColumnName, inputColumnName, numberOfTopics, - LatentDirichletAllocationEstimator.Defaults.AlphaSum, - LatentDirichletAllocationEstimator.Defaults.Beta, - LatentDirichletAllocationEstimator.Defaults.SamplingStepCount, - maximumNumberOfIterations, - LatentDirichletAllocationEstimator.Defaults.NumberOfThreads, - maximumTokenCountPerDocument, - numberOfSummaryTermsPerTopic, - LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval, - LatentDirichletAllocationEstimator.Defaults.NumberOfBurninIterations, - LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator); + outputColumnName, inputColumnName, numberOfTopics, alphaSum, beta, samplingStepCount, + maximumNumberOfIterations, numberOfThreads, maximumTokenCountPerDocument, numberOfSummaryTermsPerTopic, + likelihoodInterval, numberOfBurninIterations, resetRandomGenerator); /// /// Uses LightLDA to transform a document (represented as a vector of floats) @@ -455,7 +504,8 @@ public static LatentDirichletAllocationEstimator LatentDirichletAllocation(this /// /// The transform's catalog. /// Describes the parameters of LDA for each column pair. - public static LatentDirichletAllocationEstimator LatentDirichletAllocation( + [BestFriend] + internal static LatentDirichletAllocationEstimator LatentDirichletAllocation( this TransformsCatalog.TextTransforms catalog, params LatentDirichletAllocationEstimator.ColumnOptions[] columns) => new LatentDirichletAllocationEstimator(CatalogUtils.GetEnvironment(catalog), columns); diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs index 3094936ba9..9eff2abd7d 100644 --- a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs +++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs @@ -828,7 +828,8 @@ public enum PretrainedModelKind /// /// Information for each column pair. 
/// - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { /// /// Name of the column resulting from the transformation of . diff --git a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs index 8c29215b94..f54db6d352 100644 --- a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs +++ b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs @@ -438,7 +438,8 @@ internal WordTokenizingEstimator(IHostEnvironment env, params ColumnOptions[] co : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(WordTokenizingEstimator)), new WordTokenizingTransformer(env, columns)) { } - public sealed class ColumnOptions + [BestFriend] + internal sealed class ColumnOptions { /// /// Output column name that will be used to store the tokenization result of column. diff --git a/test/Microsoft.ML.Functional.Tests/Debugging.cs b/test/Microsoft.ML.Functional.Tests/Debugging.cs index 8682586059..5e25f7f3b0 100644 --- a/test/Microsoft.ML.Functional.Tests/Debugging.cs +++ b/test/Microsoft.ML.Functional.Tests/Debugging.cs @@ -105,7 +105,6 @@ public void InspectPipelineSchema() // Define a pipeline var pipeline = mlContext.Transforms.Concatenate("Features", HousingRegression.Features) - .Append(mlContext.Transforms.Normalize()) .AppendCacheCheckpoint(mlContext) .Append(mlContext.Regression.Trainers.Sdca( new SdcaRegressionTrainer.Options { NumberOfThreads = 1, MaximumNumberOfIterations = 20 })); @@ -173,7 +172,6 @@ public void ViewTrainingOutput() // Define a pipeline var pipeline = mlContext.Transforms.Concatenate("Features", HousingRegression.Features) - .Append(mlContext.Transforms.Normalize()) .AppendCacheCheckpoint(mlContext) .Append(mlContext.Regression.Trainers.Sdca( new SdcaRegressionTrainer.Options { NumberOfThreads = 1, MaximumNumberOfIterations = 20 })); diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs 
b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index bc22299fb4..324b732ac1 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -999,7 +999,8 @@ public void TensorFlowSentimentClassificationTest() // Then this integer vector is retrieved from the pipeline and resized to fixed length. // The second pipeline 'tfEnginePipe' takes the resized integer vector and passes it to TensoFlow and gets the classification scores. var estimator = mlContext.Transforms.Text.TokenizeIntoWords("TokenizedWords", "Sentiment_Text") - .Append(mlContext.Transforms.Conversion.MapValue(lookupMap, "Words", "Ids", new ColumnOptions[] { ("Features", "TokenizedWords") })); + .Append(mlContext.Transforms.Conversion.MapValue(lookupMap, lookupMap.Schema["Words"], lookupMap.Schema["Ids"], + new ColumnOptions[] { ("Features", "TokenizedWords") })); var model = estimator.Fit(dataView); var dataPipe = mlContext.Model.CreatePredictionEngine(model); diff --git a/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs b/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs index 0bb02cacd6..5161071a1a 100644 --- a/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs @@ -75,10 +75,10 @@ public void CategoricalOneHotHashEncoding() var mlContext = new MLContext(); var dataView = mlContext.Data.LoadFromEnumerable(data); - var pipe = mlContext.Transforms.Categorical.OneHotHashEncoding("CatA", "A", 3, 0, OneHotEncodingEstimator.OutputKind.Bag) - .Append(mlContext.Transforms.Categorical.OneHotHashEncoding("CatB", "A", 2, 0, OneHotEncodingEstimator.OutputKind.Key)) - .Append(mlContext.Transforms.Categorical.OneHotHashEncoding("CatC", "A", 3, 0, OneHotEncodingEstimator.OutputKind.Indicator)) - .Append(mlContext.Transforms.Categorical.OneHotHashEncoding("CatD", "A", 2, 0, 
OneHotEncodingEstimator.OutputKind.Binary)); + var pipe = mlContext.Transforms.Categorical.OneHotHashEncoding("CatA", "A", OneHotEncodingEstimator.OutputKind.Bag, 3, 0) + .Append(mlContext.Transforms.Categorical.OneHotHashEncoding("CatB", "A", OneHotEncodingEstimator.OutputKind.Key, 2, 0)) + .Append(mlContext.Transforms.Categorical.OneHotHashEncoding("CatC", "A", OneHotEncodingEstimator.OutputKind.Indicator, 3, 0)) + .Append(mlContext.Transforms.Categorical.OneHotHashEncoding("CatD", "A", OneHotEncodingEstimator.OutputKind.Binary, 2, 0)); TestEstimatorCore(pipe, dataView); Done(); diff --git a/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs b/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs index ffe79ad181..fc55126dc6 100644 --- a/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs @@ -43,10 +43,10 @@ public void NAReplaceWorkout() var dataView = ML.Data.LoadFromEnumerable(data); var pipe = ML.Transforms.ReplaceMissingValues( - new MissingValueReplacingEstimator.ColumnOptions("NAA", "A", MissingValueReplacingEstimator.ColumnOptions.ReplacementMode.Mean), - new MissingValueReplacingEstimator.ColumnOptions("NAB", "B", MissingValueReplacingEstimator.ColumnOptions.ReplacementMode.Mean), - new MissingValueReplacingEstimator.ColumnOptions("NAC", "C", MissingValueReplacingEstimator.ColumnOptions.ReplacementMode.Mean), - new MissingValueReplacingEstimator.ColumnOptions("NAD", "D", MissingValueReplacingEstimator.ColumnOptions.ReplacementMode.Mean)); + new MissingValueReplacingEstimator.ColumnOptions("NAA", "A", MissingValueReplacingEstimator.ReplacementMode.Mean), + new MissingValueReplacingEstimator.ColumnOptions("NAB", "B", MissingValueReplacingEstimator.ReplacementMode.Mean), + new MissingValueReplacingEstimator.ColumnOptions("NAC", "C", MissingValueReplacingEstimator.ReplacementMode.Mean), + new MissingValueReplacingEstimator.ColumnOptions("NAD", "D", 
MissingValueReplacingEstimator.ReplacementMode.Mean)); TestEstimatorCore(pipe, dataView); Done(); } @@ -68,10 +68,10 @@ public void NAReplaceStatic() var est = data.MakeNewEstimator(). Append(row => ( - A: row.ScalarFloat.ReplaceNaNValues(MissingValueReplacingEstimator.ColumnOptions.ReplacementMode.Maximum), - B: row.ScalarDouble.ReplaceNaNValues(MissingValueReplacingEstimator.ColumnOptions.ReplacementMode.Mean), - C: row.VectorFloat.ReplaceNaNValues(MissingValueReplacingEstimator.ColumnOptions.ReplacementMode.Mean), - D: row.VectorDoulbe.ReplaceNaNValues(MissingValueReplacingEstimator.ColumnOptions.ReplacementMode.Minimum) + A: row.ScalarFloat.ReplaceNaNValues(MissingValueReplacingEstimator.ReplacementMode.Maximum), + B: row.ScalarDouble.ReplaceNaNValues(MissingValueReplacingEstimator.ReplacementMode.Mean), + C: row.VectorFloat.ReplaceNaNValues(MissingValueReplacingEstimator.ReplacementMode.Mean), + D: row.VectorDoulbe.ReplaceNaNValues(MissingValueReplacingEstimator.ReplacementMode.Minimum) )); TestEstimatorCore(est.AsDynamic, data.AsDynamic, invalidInput: invalidData); @@ -104,10 +104,10 @@ public void TestOldSavingAndLoading() var dataView = ML.Data.LoadFromEnumerable(data); var pipe = ML.Transforms.ReplaceMissingValues( - new MissingValueReplacingEstimator.ColumnOptions("NAA", "A", MissingValueReplacingEstimator.ColumnOptions.ReplacementMode.Mean), - new MissingValueReplacingEstimator.ColumnOptions("NAB", "B", MissingValueReplacingEstimator.ColumnOptions.ReplacementMode.Mean), - new MissingValueReplacingEstimator.ColumnOptions("NAC", "C", MissingValueReplacingEstimator.ColumnOptions.ReplacementMode.Mean), - new MissingValueReplacingEstimator.ColumnOptions("NAD", "D", MissingValueReplacingEstimator.ColumnOptions.ReplacementMode.Mean)); + new MissingValueReplacingEstimator.ColumnOptions("NAA", "A", MissingValueReplacingEstimator.ReplacementMode.Mean), + new MissingValueReplacingEstimator.ColumnOptions("NAB", "B", 
MissingValueReplacingEstimator.ReplacementMode.Mean), + new MissingValueReplacingEstimator.ColumnOptions("NAC", "C", MissingValueReplacingEstimator.ReplacementMode.Mean), + new MissingValueReplacingEstimator.ColumnOptions("NAD", "D", MissingValueReplacingEstimator.ReplacementMode.Mean)); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); diff --git a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs index 9638a558de..3b90d25e8b 100644 --- a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs @@ -507,11 +507,15 @@ public void ValueMappingWorkout() var badData = new[] { new TestWrong() { A = "bar", B = 1.2f } }; var badDataView = ML.Data.LoadFromEnumerable(badData); - var keys = new List() { "foo", "bar", "test", "wahoo" }; - var values = new List() { 1, 2, 3, 4 }; + var keyValuePairs = new List>() { + new KeyValuePair("foo", 1), + new KeyValuePair("bar", 2), + new KeyValuePair("test", 3), + new KeyValuePair("wahoo", 4) + }; // Workout on value mapping - var est = ML.Transforms.Conversion.MapValue(keys, values, new ColumnOptions[] { ("D", "A"), ("E", "B"), ("F", "C") }); + var est = ML.Transforms.Conversion.MapValue(keyValuePairs, new ColumnOptions[] { ("D", "A"), ("E", "B"), ("F", "C") }); TestEstimatorCore(est, validFitInput: dataView, invalidInput: badDataView); } @@ -523,14 +527,14 @@ public void ValueMappingValueTypeIsVectorWorkout() var badData = new[] { new TestWrong() { A = "bar", B = 1.2f } }; var badDataView = ML.Data.LoadFromEnumerable(badData); - var keys = new List() { "foo", "bar", "test" }; - var values = new List() { - new int[] {2, 3, 4 }, - new int[] {100, 200 }, - new int[] {400, 500, 600, 700 }}; + var keyValuePairs = new List>() { + new KeyValuePair("foo", new int[] {2, 3, 4 }), + new KeyValuePair("bar", new int[] {100, 200 }), + new KeyValuePair("test", new int[] 
{400, 500, 600, 700 }), + }; // Workout on value mapping - var est = ML.Transforms.Conversion.MapValue(keys, values, new ColumnOptions[] { ("D", "A"), ("E", "B"), ("F", "C") }); + var est = ML.Transforms.Conversion.MapValue(keyValuePairs, new ColumnOptions[] { ("D", "A"), ("E", "B"), ("F", "C") }); TestEstimatorCore(est, validFitInput: dataView, invalidInput: badDataView); } @@ -543,11 +547,15 @@ public void ValueMappingInputIsVectorWorkout() var badData = new[] { new TestWrong() { B = 1.2f } }; var badDataView = ML.Data.LoadFromEnumerable(badData); - var keys = new List>() { "foo".AsMemory(), "bar".AsMemory(), "test".AsMemory(), "wahoo".AsMemory() }; - var values = new List() { 1, 2, 3, 4 }; + var keyValuePairs = new List,int>>() { + new KeyValuePair,int>("foo".AsMemory(), 1), + new KeyValuePair,int>("bar".AsMemory(), 2), + new KeyValuePair,int>("test".AsMemory(), 3), + new KeyValuePair,int>("wahoo".AsMemory(), 4) + }; var est = ML.Transforms.Text.TokenizeIntoWords("TokenizeB", "B") - .Append(ML.Transforms.Conversion.MapValue(keys, values, new ColumnOptions[] { ("VecB", "TokenizeB") })); + .Append(ML.Transforms.Conversion.MapValue(keyValuePairs, new ColumnOptions[] { ("VecB", "TokenizeB") })); TestEstimatorCore(est, validFitInput: dataView, invalidInput: badDataView); } From f03c49d8045b0cc85bff104488610a21936a5195 Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Tue, 19 Mar 2019 15:58:56 -0500 Subject: [PATCH 05/18] Updating the FunctionalTests to clearly explain why they are not strong named signed. 
(#3010) --- .../Microsoft.ML.Functional.Tests.csproj | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/Microsoft.ML.Functional.Tests/Microsoft.ML.Functional.Tests.csproj b/test/Microsoft.ML.Functional.Tests/Microsoft.ML.Functional.Tests.csproj index 88c51526af..c1c7af9263 100644 --- a/test/Microsoft.ML.Functional.Tests/Microsoft.ML.Functional.Tests.csproj +++ b/test/Microsoft.ML.Functional.Tests/Microsoft.ML.Functional.Tests.csproj @@ -1,8 +1,9 @@  - - true + + false + false From 00a5b3533bd7033a0875f4c6b456fd54281c0e54 Mon Sep 17 00:00:00 2001 From: Shahab Moradi Date: Tue, 19 Mar 2019 17:14:52 -0400 Subject: [PATCH 06/18] Added samples for tree regression trainers. (#2999) * Added samples for tree regression trainers. * PR comments * Added tab. --- .../Dynamic/Trainers/Regression/FastForest.cs | 95 ++++++++++++++++ .../Regression/FastForestWithOptions.cs | 107 ++++++++++++++++++ .../Trainers/Regression/FastTreeTweedie.cs | 95 ++++++++++++++++ .../Regression/FastTreeTweedieWithOptions.cs | 107 ++++++++++++++++++ .../Regression/FastTreeWithOptions.cs | 107 ++++++++++++++++++ .../Regression/Gam.cs} | 64 ++++++----- .../Trainers/Regression/GamWithOptions.cs | 105 +++++++++++++++++ .../FastTreeArguments.cs | 1 + .../TreeTrainersCatalog.cs | 51 ++++++++- src/Microsoft.ML.SamplesUtils/ConsoleUtils.cs | 2 +- 10 files changed, 701 insertions(+), 33 deletions(-) create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastForest.cs create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastForestWithOptions.cs create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastTreeTweedie.cs create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastTreeTweedieWithOptions.cs create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastTreeWithOptions.cs rename docs/samples/Microsoft.ML.Samples/Dynamic/{GeneralizedAdditiveModels.cs 
=> Trainers/Regression/Gam.cs} (72%) create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/GamWithOptions.cs diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastForest.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastForest.cs new file mode 100644 index 0000000000..7263aef771 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastForest.cs @@ -0,0 +1,95 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Samples.Dynamic.Trainers.Regression +{ + public static class FastForest + { + // This example requires installation of additional NuGet package + // Microsoft.ML.FastTree. + public static void Example() + { + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + // Setting the seed to a fixed number in this example to make outputs deterministic. + var mlContext = new MLContext(seed: 0); + + // Create a list of training examples. + var examples = GenerateRandomDataPoints(1000); + + // Convert the examples list to an IDataView object, which is consumable by ML.NET API. + var trainingData = mlContext.Data.LoadFromEnumerable(examples); + + // Define the trainer. + var pipeline = mlContext.Regression.Trainers.FastForest(); + + // Train the model. + var model = pipeline.Fit(trainingData); + + // Create testing examples. Use different random seed to make it different from training data. + var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed:123)); + + // Run the model on test data set. + var transformedTestData = model.Transform(testData); + + // Convert IDataView object to a list. 
+ var predictions = mlContext.Data.CreateEnumerable(transformedTestData, reuseRowObject: false).ToList(); + + // Look at 5 predictions + foreach (var p in predictions.Take(5)) + Console.WriteLine($"Label: {p.Label:F3}, Prediction: {p.Score:F3}"); + + // Expected output: + // Label: 0.985, Prediction: 0.864 + // Label: 0.155, Prediction: 0.164 + // Label: 0.515, Prediction: 0.470 + // Label: 0.566, Prediction: 0.501 + // Label: 0.096, Prediction: 0.138 + + // Evaluate the overall metrics + var metrics = mlContext.Regression.Evaluate(transformedTestData); + SamplesUtils.ConsoleUtils.PrintMetrics(metrics); + + // Expected output: + // Mean Absolute Error: 0.06 + // Mean Squared Error: 0.01 + // Root Mean Squared Error: 0.07 + // RSquared: 0.93 + } + + private static IEnumerable GenerateRandomDataPoints(int count, int seed=0) + { + var random = new Random(seed); + float randomFloat() => (float)random.NextDouble(); + for (int i = 0; i < count; i++) + { + var label = randomFloat(); + yield return new DataPoint + { + Label = label, + // Create random features that are correlated with label. + Features = Enumerable.Repeat(label, 50).Select(x => x + randomFloat()).ToArray() + }; + } + } + + // Example with label and 50 feature values. A data set is a collection of such examples. + private class DataPoint + { + public float Label { get; set; } + [VectorType(50)] + public float[] Features { get; set; } + } + + // Class used to capture predictions. + private class Prediction + { + // Original label. + public float Label { get; set; } + // Predicted score from the trainer. 
+ public float Score { get; set; } + } + } +} \ No newline at end of file diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastForestWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastForestWithOptions.cs new file mode 100644 index 0000000000..4629d882e1 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastForestWithOptions.cs @@ -0,0 +1,107 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; +using Microsoft.ML.Trainers.FastTree; + +namespace Microsoft.ML.Samples.Dynamic.Trainers.Regression +{ + public static class FastForestWithOptions + { + // This example requires installation of additional NuGet package + // Microsoft.ML.FastTree. + public static void Example() + { + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + // Setting the seed to a fixed number in this example to make outputs deterministic. + var mlContext = new MLContext(seed: 0); + + // Create a list of training examples. + var examples = GenerateRandomDataPoints(1000); + + // Convert the examples list to an IDataView object, which is consumable by ML.NET API. + var trainingData = mlContext.Data.LoadFromEnumerable(examples); + + // Define trainer options. + var options = new FastForestRegressionTrainer.Options + { + // Only use 80% of features to reduce over-fitting. + FeatureFraction = 0.8, + // Create a simpler model by penalizing usage of new features. + FeatureFirstUsePenalty = 0.1, + // Reduce the number of trees to 50. + NumberOfTrees = 50 + }; + + // Define the trainer. + var pipeline = mlContext.Regression.Trainers.FastForest(options); + + // Train the model. + var model = pipeline.Fit(trainingData); + + // Create testing examples. Use different random seed to make it different from training data. 
+ var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed:123)); + + // Run the model on test data set. + var transformedTestData = model.Transform(testData); + + // Convert IDataView object to a list. + var predictions = mlContext.Data.CreateEnumerable(transformedTestData, reuseRowObject: false).ToList(); + + // Look at 5 predictions + foreach (var p in predictions.Take(5)) + Console.WriteLine($"Label: {p.Label:F3}, Prediction: {p.Score:F3}"); + + // Expected output: + // Label: 0.985, Prediction: 0.866 + // Label: 0.155, Prediction: 0.171 + // Label: 0.515, Prediction: 0.470 + // Label: 0.566, Prediction: 0.476 + // Label: 0.096, Prediction: 0.140 + + // Evaluate the overall metrics + var metrics = mlContext.Regression.Evaluate(transformedTestData); + SamplesUtils.ConsoleUtils.PrintMetrics(metrics); + + // Expected output: + // Mean Absolute Error: 0.06 + // Mean Squared Error: 0.01 + // Root Mean Squared Error: 0.08 + // RSquared: 0.93 + } + + private static IEnumerable GenerateRandomDataPoints(int count, int seed=0) + { + var random = new Random(seed); + float randomFloat() => (float)random.NextDouble(); + for (int i = 0; i < count; i++) + { + var label = randomFloat(); + yield return new DataPoint + { + Label = label, + // Create random features that are correlated with label. + Features = Enumerable.Repeat(label, 50).Select(x => x + randomFloat()).ToArray() + }; + } + } + + // Example with label and 50 feature values. A data set is a collection of such examples. + private class DataPoint + { + public float Label { get; set; } + [VectorType(50)] + public float[] Features { get; set; } + } + + // Class used to capture predictions. + private class Prediction + { + // Original label. + public float Label { get; set; } + // Predicted score from the trainer. 
+ public float Score { get; set; } + } + } +} \ No newline at end of file diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastTreeTweedie.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastTreeTweedie.cs new file mode 100644 index 0000000000..ead29a678f --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastTreeTweedie.cs @@ -0,0 +1,95 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Samples.Dynamic.Trainers.Regression +{ + public static class FastTreeTweedie + { + // This example requires installation of additional NuGet package + // Microsoft.ML.FastTree. + public static void Example() + { + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + // Setting the seed to a fixed number in this example to make outputs deterministic. + var mlContext = new MLContext(seed: 0); + + // Create a list of training examples. + var examples = GenerateRandomDataPoints(1000); + + // Convert the examples list to an IDataView object, which is consumable by ML.NET API. + var trainingData = mlContext.Data.LoadFromEnumerable(examples); + + // Define the trainer. + var pipeline = mlContext.Regression.Trainers.FastTreeTweedie(); + + // Train the model. + var model = pipeline.Fit(trainingData); + + // Create testing examples. Use different random seed to make it different from training data. + var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed:123)); + + // Run the model on test data set. + var transformedTestData = model.Transform(testData); + + // Convert IDataView object to a list. 
+ var predictions = mlContext.Data.CreateEnumerable(transformedTestData, reuseRowObject: false).ToList(); + + // Look at 5 predictions + foreach (var p in predictions.Take(5)) + Console.WriteLine($"Label: {p.Label:F3}, Prediction: {p.Score:F3}"); + + // Expected output: + // Label: 0.985, Prediction: 0.945 + // Label: 0.155, Prediction: 0.104 + // Label: 0.515, Prediction: 0.515 + // Label: 0.566, Prediction: 0.448 + // Label: 0.096, Prediction: 0.082 + + // Evaluate the overall metrics + var metrics = mlContext.Regression.Evaluate(transformedTestData); + SamplesUtils.ConsoleUtils.PrintMetrics(metrics); + + // Expected output: + // Mean Absolute Error: 0.05 + // Mean Squared Error: 0.00 + // Root Mean Squared Error: 0.06 + // RSquared: 0.95 + } + + private static IEnumerable GenerateRandomDataPoints(int count, int seed=0) + { + var random = new Random(seed); + float randomFloat() => (float)random.NextDouble(); + for (int i = 0; i < count; i++) + { + var label = randomFloat(); + yield return new DataPoint + { + Label = label, + // Create random features that are correlated with label. + Features = Enumerable.Repeat(label, 50).Select(x => x + randomFloat()).ToArray() + }; + } + } + + // Example with label and 50 feature values. A data set is a collection of such examples. + private class DataPoint + { + public float Label { get; set; } + [VectorType(50)] + public float[] Features { get; set; } + } + + // Class used to capture predictions. + private class Prediction + { + // Original label. + public float Label { get; set; } + // Predicted score from the trainer. 
+ public float Score { get; set; } + } + } +} \ No newline at end of file diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastTreeTweedieWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastTreeTweedieWithOptions.cs new file mode 100644 index 0000000000..dd75a9c4f4 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastTreeTweedieWithOptions.cs @@ -0,0 +1,107 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; +using Microsoft.ML.Trainers.FastTree; + +namespace Microsoft.ML.Samples.Dynamic.Trainers.Regression +{ + public static class FastTreeTweedieWithOptions + { + // This example requires installation of additional NuGet package + // Microsoft.ML.FastTree. + public static void Example() + { + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + // Setting the seed to a fixed number in this example to make outputs deterministic. + var mlContext = new MLContext(seed: 0); + + // Create a list of training examples. + var examples = GenerateRandomDataPoints(1000); + + // Convert the examples list to an IDataView object, which is consumable by ML.NET API. + var trainingData = mlContext.Data.LoadFromEnumerable(examples); + + // Define trainer options. + var options = new FastTreeTweedieTrainer.Options + { + // Use L2Norm for early stopping. + EarlyStoppingMetric = EarlyStoppingMetric.L2Norm, + // Create a simpler model by penalizing usage of new features. + FeatureFirstUsePenalty = 0.1, + // Reduce the number of trees to 50. + NumberOfTrees = 50 + }; + + // Define the trainer. + var pipeline = mlContext.Regression.Trainers.FastTreeTweedie(options); + + // Train the model. + var model = pipeline.Fit(trainingData); + + // Create testing examples. Use different random seed to make it different from training data. 
+ var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed:123)); + + // Run the model on test data set. + var transformedTestData = model.Transform(testData); + + // Convert IDataView object to a list. + var predictions = mlContext.Data.CreateEnumerable(transformedTestData, reuseRowObject: false).ToList(); + + // Look at 5 predictions + foreach (var p in predictions.Take(5)) + Console.WriteLine($"Label: {p.Label:F3}, Prediction: {p.Score:F3}"); + + // Expected output: + // Label: 0.985, Prediction: 0.954 + // Label: 0.155, Prediction: 0.103 + // Label: 0.515, Prediction: 0.450 + // Label: 0.566, Prediction: 0.515 + // Label: 0.096, Prediction: 0.078 + + // Evaluate the overall metrics + var metrics = mlContext.Regression.Evaluate(transformedTestData); + SamplesUtils.ConsoleUtils.PrintMetrics(metrics); + + // Expected output: + // Mean Absolute Error: 0.05 + // Mean Squared Error: 0.00 + // Root Mean Squared Error: 0.07 + // RSquared: 0.95 + } + + private static IEnumerable GenerateRandomDataPoints(int count, int seed=0) + { + var random = new Random(seed); + float randomFloat() => (float)random.NextDouble(); + for (int i = 0; i < count; i++) + { + var label = randomFloat(); + yield return new DataPoint + { + Label = label, + // Create random features that are correlated with label. + Features = Enumerable.Repeat(label, 50).Select(x => x + randomFloat()).ToArray() + }; + } + } + + // Example with label and 50 feature values. A data set is a collection of such examples. + private class DataPoint + { + public float Label { get; set; } + [VectorType(50)] + public float[] Features { get; set; } + } + + // Class used to capture predictions. + private class Prediction + { + // Original label. + public float Label { get; set; } + // Predicted score from the trainer. 
+ public float Score { get; set; } + } + } +} \ No newline at end of file diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastTreeWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastTreeWithOptions.cs new file mode 100644 index 0000000000..594c80868c --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/FastTreeWithOptions.cs @@ -0,0 +1,107 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; +using Microsoft.ML.Trainers.FastTree; + +namespace Microsoft.ML.Samples.Dynamic.Trainers.Regression +{ + public static class FastTreeWithOptions + { + // This example requires installation of additional NuGet package + // Microsoft.ML.FastTree. + public static void Example() + { + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + // Setting the seed to a fixed number in this example to make outputs deterministic. + var mlContext = new MLContext(seed: 0); + + // Create a list of training examples. + var examples = GenerateRandomDataPoints(1000); + + // Convert the examples list to an IDataView object, which is consumable by ML.NET API. + var trainingData = mlContext.Data.LoadFromEnumerable(examples); + + // Define trainer options. + var options = new FastTreeRegressionTrainer.Options + { + // Use L2Norm for early stopping. + EarlyStoppingMetric = EarlyStoppingMetric.L2Norm, + // Create a simpler model by penalizing usage of new features. + FeatureFirstUsePenalty = 0.1, + // Reduce the number of trees to 50. + NumberOfTrees = 50 + }; + + // Define the trainer. + var pipeline = mlContext.Regression.Trainers.FastTree(options); + + // Train the model. + var model = pipeline.Fit(trainingData); + + // Create testing examples. Use different random seed to make it different from training data. 
+ var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed:123)); + + // Run the model on test data set. + var transformedTestData = model.Transform(testData); + + // Convert IDataView object to a list. + var predictions = mlContext.Data.CreateEnumerable(transformedTestData, reuseRowObject: false).ToList(); + + // Look at 5 predictions + foreach (var p in predictions.Take(5)) + Console.WriteLine($"Label: {p.Label:F3}, Prediction: {p.Score:F3}"); + + // Expected output: + // Label: 0.985, Prediction: 0.950 + // Label: 0.155, Prediction: 0.111 + // Label: 0.515, Prediction: 0.475 + // Label: 0.566, Prediction: 0.575 + // Label: 0.096, Prediction: 0.093 + + // Evaluate the overall metrics + var metrics = mlContext.Regression.Evaluate(transformedTestData); + SamplesUtils.ConsoleUtils.PrintMetrics(metrics); + + // Expected output: + // Mean Absolute Error: 0.05 + // Mean Squared Error: 0.00 + // Root Mean Squared Error: 0.06 + // RSquared: 0.95 + } + + private static IEnumerable GenerateRandomDataPoints(int count, int seed=0) + { + var random = new Random(seed); + float randomFloat() => (float)random.NextDouble(); + for (int i = 0; i < count; i++) + { + var label = randomFloat(); + yield return new DataPoint + { + Label = label, + // Create random features that are correlated with label. + Features = Enumerable.Repeat(label, 50).Select(x => x + randomFloat()).ToArray() + }; + } + } + + // Example with label and 50 feature values. A data set is a collection of such examples. + private class DataPoint + { + public float Label { get; set; } + [VectorType(50)] + public float[] Features { get; set; } + } + + // Class used to capture predictions. + private class Prediction + { + // Original label. + public float Label { get; set; } + // Predicted score from the trainer. 
+ public float Score { get; set; } + } + } +} \ No newline at end of file diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/GeneralizedAdditiveModels.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/Gam.cs similarity index 72% rename from docs/samples/Microsoft.ML.Samples/Dynamic/GeneralizedAdditiveModels.cs rename to docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/Gam.cs index 5ab8fd3ad6..67e57697f2 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/GeneralizedAdditiveModels.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/Gam.cs @@ -2,10 +2,12 @@ using System.Linq; using Microsoft.ML.SamplesUtils; -namespace Microsoft.ML.Samples.Dynamic +namespace Microsoft.ML.Samples.Dynamic.Trainers.Regression { - public static class GeneralizedAdditiveModelsRegression + public static class Gam { + // This example requires installation of additional NuGet package + // Microsoft.ML.FastTree. public static void Example() { // Create a new context for ML.NET operations. It can be used for exception tracking and logging, @@ -15,31 +17,34 @@ public static void Example() // Read the Housing regression dataset var data = DatasetUtils.LoadHousingRegressionDataset(mlContext); - // Create a pipeline - // Concatenate the features to create a Feature vector. - // Then append a gam regressor, setting the "MedianHomeValue" column as the label of the dataset, - // the "Features" column produced by concatenation as the features column, - // and use a small number of bins to make it easy to visualize in the console window. - // For real appplications, it is recommended to start with the default number of bins. 
var labelName = "MedianHomeValue"; var featureNames = data.Schema .Select(column => column.Name) // Get the column names .Where(name => name != labelName) // Drop the Label .ToArray(); - var pipeline = mlContext.Transforms.Concatenate("Features", featureNames) - .Append(mlContext.Regression.Trainers.Gam( - labelColumnName: labelName, featureColumnName: "Features", maximumBinCountPerFeature: 16)); - var fitPipeline = pipeline.Fit(data); - // Extract the model from the pipeline - var gamModel = fitPipeline.LastTransformer.Model; + // Create a pipeline. + var pipeline = + // Concatenate the features to create a Feature vector. + mlContext.Transforms.Concatenate("Features", featureNames) + // Append a GAM regression trainer, setting the "MedianHomeValue" column as the label of the dataset, + // the "Features" column produced by concatenation as the features column, + // and use a small number of bins to make it easy to visualize in the console window. + // For real applications, it is recommended to start with the default number of bins. + .Append(mlContext.Regression.Trainers.Gam(labelColumnName: labelName, featureColumnName: "Features", maximumBinCountPerFeature: 16)); - // Now investigate the properties of the Generalized Additive Model: The intercept and shape functions. + // Train the pipeline. + var trainedPipeline = pipeline.Fit(data); - // The intercept for the GAM models represent the average prediction for the training data - var intercept = gamModel.Bias; - // Expected output: Average predicted cost: 22.53 - Console.WriteLine($"Average predicted cost: {intercept:0.00}"); + // Extract the model from the pipeline. + var gamModel = trainedPipeline.LastTransformer.Model; + + // Now investigate the bias and shape functions of the GAM model. + // The bias represents the average prediction for the training data. 
+ Console.WriteLine($"Average predicted cost: {gamModel.Bias:0.00}"); + + // Expected output: + // Average predicted cost: 22.53 // Let's take a look at the features that the model built. Similar to a linear model, we have // one response per feature. Unlike a linear model, this response is a function instead of a line. @@ -48,19 +53,23 @@ public static void Example() // Let's investigate the TeacherRatio variable. This is the ratio of students to teachers, // so the higher it is, the more students a teacher has in their classroom. - // First, let's get the index of the variable we want to look at + // First, let's get the index of the variable we want to look at. var studentTeacherRatioIndex = featureNames.ToList().FindIndex(str => str.Equals("TeacherRatio")); - // Next, let's get the array of histogram bin upper bounds from the model for this feature + // Next, let's get the array of histogram bin upper bounds from the model for this feature. // For each feature, the shape function is calculated at `MaxBins` locations along the range of // values that the feature takes, and the resulting shape function can be seen as a histogram of // effects. var teacherRatioBinUpperBounds = gamModel.GetBinUpperBounds(studentTeacherRatioIndex); - // And the array of bin effects; these are the effect size for each bin + // And the array of bin effects; these are the effect size for each bin. var teacherRatioBinEffects = gamModel.GetBinEffects(studentTeacherRatioIndex); // Now, write the function to the console. The function is a set of bins, and the corresponding // function values. You can think of GAMs as building a bar-chart lookup table. 
+ Console.WriteLine("Student-Teacher Ratio"); + for (int i = 0; i < teacherRatioBinUpperBounds.Count; i++) + Console.WriteLine($"x < {teacherRatioBinUpperBounds[i]:0.00} => {teacherRatioBinEffects[i]:0.000}"); + // Expected output: // Student-Teacher Ratio // x < 14.55 => 2.105 @@ -77,7 +86,7 @@ public static void Example() // x < 20.55 => -0.649 // x < 21.05 => -1.579 // x < ∞ => 0.318 - // + // Let's consider this output. To score a given example, we look up the first bin where the inequality // is satisfied for the feature value. We can look at the whole function to get a sense for how the // model responds to the variable on a global level. For the student-teacher-ratio variable, we can see @@ -85,19 +94,12 @@ public static void Example() // than about 18 lead to lower predictions in house value. This makes intuitive sense, as smaller class // sizes are desirable and also indicative of better-funded schools, which could make buyers likely to // pay more for the house. - // + // Another thing to notice is that these feature functions can be noisy. See student-teacher ratios > 21.05. // Common practice is to use resampling methods to estimate a confidence interval at each bin. This will // help to determine if the effect is real or just sampling noise. See for example // Tan, Caruana, Hooker, and Lou. "Distill-and-Compare: Auditing Black-Box Models Using Transparent Model // Distillation." arXiv:1710.06169." 
- Console.WriteLine(); - Console.WriteLine("Student-Teacher Ratio"); - for (int i = 0; i < teacherRatioBinUpperBounds.Count; i++) - { - Console.WriteLine($"x < {teacherRatioBinUpperBounds[i]:0.00} => {teacherRatioBinEffects[i]:0.000}"); - } - Console.WriteLine(); } } } diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/GamWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/GamWithOptions.cs new file mode 100644 index 0000000000..6545e27022 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/GamWithOptions.cs @@ -0,0 +1,105 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; +using Microsoft.ML.Trainers.FastTree; + +namespace Microsoft.ML.Samples.Dynamic.Trainers.Regression +{ + public static class GamWithOptions + { + // This example requires installation of additional NuGet package + // Microsoft.ML.FastTree. + public static void Example() + { + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + // Setting the seed to a fixed number in this example to make outputs deterministic. + var mlContext = new MLContext(seed: 0); + + // Create a list of training examples. + var examples = GenerateRandomDataPoints(1000); + + // Convert the examples list to an IDataView object, which is consumable by ML.NET API. + var trainingData = mlContext.Data.LoadFromEnumerable(examples); + + // Define trainer options. + var options = new GamRegressionTrainer.Options + { + // The entropy (regularization) coefficient. + EntropyCoefficient = 0.3, + // Reduce the number of iterations to 50. + NumberOfIterations = 50 + }; + + // Define the trainer. + var pipeline = mlContext.Regression.Trainers.Gam(options); + + // Train the model. + var model = pipeline.Fit(trainingData); + + // Create testing examples. 
Use different random seed to make it different from training data. + var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed:123)); + + // Run the model on test data set. + var transformedTestData = model.Transform(testData); + + // Convert IDataView object to a list. + var predictions = mlContext.Data.CreateEnumerable(transformedTestData, reuseRowObject: false).ToList(); + + // Look at 5 predictions + foreach (var p in predictions.Take(5)) + Console.WriteLine($"Label: {p.Label:F3}, Prediction: {p.Score:F3}"); + + // Expected output: + // Label: 0.985, Prediction: 0.841 + // Label: 0.155, Prediction: 0.187 + // Label: 0.515, Prediction: 0.496 + // Label: 0.566, Prediction: 0.467 + // Label: 0.096, Prediction: 0.144 + + // Evaluate the overall metrics + var metrics = mlContext.Regression.Evaluate(transformedTestData); + SamplesUtils.ConsoleUtils.PrintMetrics(metrics); + + // Expected output: + // Mean Absolute Error: 0.06 + // Mean Squared Error: 0.01 + // Root Mean Squared Error: 0.08 + // RSquared: 0.93 + } + + private static IEnumerable GenerateRandomDataPoints(int count, int seed=0) + { + var random = new Random(seed); + float randomFloat() => (float)random.NextDouble(); + for (int i = 0; i < count; i++) + { + var label = randomFloat(); + yield return new DataPoint + { + Label = label, + // Create random features that are correlated with label. + Features = Enumerable.Repeat(label, 50).Select(x => x + randomFloat()).ToArray() + }; + } + } + + // Example with label and 50 feature values. A data set is a collection of such examples. + private class DataPoint + { + public float Label { get; set; } + [VectorType(50)] + public float[] Features { get; set; } + } + + // Class used to capture predictions. + private class Prediction + { + // Original label. + public float Label { get; set; } + // Predicted score from the trainer. 
+ public float Score { get; set; } + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.FastTree/FastTreeArguments.cs b/src/Microsoft.ML.FastTree/FastTreeArguments.cs index 1b31bcd631..ddea3f21f8 100644 --- a/src/Microsoft.ML.FastTree/FastTreeArguments.cs +++ b/src/Microsoft.ML.FastTree/FastTreeArguments.cs @@ -524,6 +524,7 @@ public abstract class TreeOptions : TrainerInputBaseWithGroupId /// /// The fraction of features (chosen randomly) to use on each iteration. Use 0.9 if only 90% of features is needed. + /// Lower numbers help reduce over-fitting. /// [Argument(ArgumentType.AtMostOnce, HelpText = "The fraction of features (chosen randomly) to use on each iteration", ShortName = "ff")] public Double FeatureFraction = 1; diff --git a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs index 18a2737d02..5d14420148 100644 --- a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs @@ -50,6 +50,13 @@ public static FastTreeRegressionTrainer FastTree(this RegressionCatalog.Regressi /// /// The . /// Trainer options. + /// + /// + /// + /// + /// public static FastTreeRegressionTrainer FastTree(this RegressionCatalog.RegressionTrainers catalog, FastTreeRegressionTrainer.Options options) { @@ -188,8 +195,15 @@ public static GamBinaryTrainer Gam(this BinaryClassificationCatalog.BinaryClassi /// The number of iterations to use in learning the features. /// The maximum number of bins to use to approximate features. /// The learning rate. GAMs work best with a small learning rate. 
+ /// + /// + /// + /// + /// public static GamRegressionTrainer Gam(this RegressionCatalog.RegressionTrainers catalog, - string labelColumnName = DefaultColumnNames.Label, + string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, int numberOfIterations = GamDefaults.NumberOfIterations, @@ -206,6 +220,13 @@ public static GamRegressionTrainer Gam(this RegressionCatalog.RegressionTrainers /// /// The . /// Trainer options. + /// + /// + /// + /// + /// public static GamRegressionTrainer Gam(this RegressionCatalog.RegressionTrainers catalog, GamRegressionTrainer.Options options) { @@ -225,6 +246,13 @@ public static GamRegressionTrainer Gam(this RegressionCatalog.RegressionTrainers /// The maximum number of leaves per decision tree. /// The minimal number of data points required to form a new tree leaf. /// The learning rate. + /// + /// + /// + /// + /// public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionCatalog.RegressionTrainers catalog, string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, @@ -244,6 +272,13 @@ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionCatalog.Regr /// /// The . /// Trainer options. + /// + /// + /// + /// + /// public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionCatalog.RegressionTrainers catalog, FastTreeTweedieTrainer.Options options) { @@ -264,6 +299,13 @@ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionCatalog.Regr /// The maximum number of leaves per decision tree. /// Total number of decision trees to create in the ensemble. /// The minimal number of data points required to form a new tree leaf. 
+ /// + /// + /// + /// + /// public static FastForestRegressionTrainer FastForest(this RegressionCatalog.RegressionTrainers catalog, string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, @@ -282,6 +324,13 @@ public static FastForestRegressionTrainer FastForest(this RegressionCatalog.Regr /// /// The . /// Trainer options. + /// + /// + /// + /// + /// public static FastForestRegressionTrainer FastForest(this RegressionCatalog.RegressionTrainers catalog, FastForestRegressionTrainer.Options options) { diff --git a/src/Microsoft.ML.SamplesUtils/ConsoleUtils.cs b/src/Microsoft.ML.SamplesUtils/ConsoleUtils.cs index 4fbb489241..30aa8b59dd 100644 --- a/src/Microsoft.ML.SamplesUtils/ConsoleUtils.cs +++ b/src/Microsoft.ML.SamplesUtils/ConsoleUtils.cs @@ -55,7 +55,7 @@ public static void PrintMetrics(MulticlassClassificationMetrics metrics) public static void PrintMetrics(RegressionMetrics metrics) { Console.WriteLine($"Mean Absolute Error: {metrics.MeanAbsoluteError:F2}"); - Console.WriteLine($"Mean Square dError: {metrics.MeanSquaredError:F2}"); + Console.WriteLine($"Mean Squared Error: {metrics.MeanSquaredError:F2}"); Console.WriteLine($"Root Mean Squared Error: {metrics.RootMeanSquaredError:F2}"); Console.WriteLine($"RSquared: {metrics.RSquared:F2}"); } From fd1c700b33c3e0dcc97176faafa5dd22dc16b98e Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Tue, 19 Mar 2019 14:37:12 -0700 Subject: [PATCH 07/18] Cleanup the statistics usage API (#2048) * Refactoring the Old ModelStatistics to a class just for the modelStats, like MLR needs all the time, and lr needs when the stats are not calculated. * Separting the Bias from the weights coefficients APIs. 
--- src/Microsoft.ML.Core/Data/ModelLoading.cs | 1 + .../Standard/LinearModelParameters.cs | 44 +- .../LogisticRegression/LogisticRegression.cs | 19 +- .../MulticlassLogisticRegression.cs | 42 +- .../Standard/ModelStatistics.cs | 630 +++++++++++------- .../CommandTrainingLrWithStats-out.txt | 1 + .../Common/EntryPoints/ensemble-summary.txt | 4 +- ...etCore30OrNotNetCoreAndX64FactAttribute.cs | 16 + .../TrainerEstimators/LbfgsTests.cs | 176 ++++- test/data/backcompat/LrWithStats.zip | Bin 0 -> 4713 bytes test/data/backcompat/MlrWithStats.zip | Bin 0 -> 5203 bytes 11 files changed, 610 insertions(+), 323 deletions(-) create mode 100644 test/data/backcompat/LrWithStats.zip create mode 100644 test/data/backcompat/MlrWithStats.zip diff --git a/src/Microsoft.ML.Core/Data/ModelLoading.cs b/src/Microsoft.ML.Core/Data/ModelLoading.cs index 2fb7bd99ee..dfced6e610 100644 --- a/src/Microsoft.ML.Core/Data/ModelLoading.cs +++ b/src/Microsoft.ML.Core/Data/ModelLoading.cs @@ -167,6 +167,7 @@ private static bool TryLoadModel(IHostEnvironment env, out TRes resu // TryLoadModelCore should rewind on failure. 
Contracts.Assert(fp == ent.Stream.Position); + return false; } diff --git a/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs b/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs index b447207c97..6ddc44d2d1 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs @@ -141,7 +141,7 @@ private protected LinearModelParameters(IHostEnvironment env, string name, Model // int: number of weights // Float[]: weights // bool: has model stats - // (Conditional) LinearModelStatistics: stats + // (Conditional) LinearModelParameterStatistics: stats Bias = ctx.Reader.ReadFloat(); Host.CheckDecode(FloatUtils.IsFinite(Bias)); @@ -199,7 +199,7 @@ private protected override void SaveCore(ModelSaveContext ctx) // int: number of weights // Float[]: weights // bool: has model stats - // (Conditional) LinearModelStatistics: stats + // (Conditional) LinearModelParameterStatistics: stats ctx.Writer.Write(Bias); ctx.Writer.Write(Weight.Length); @@ -413,9 +413,7 @@ public sealed partial class LinearBinaryModelParameters : LinearModelParameters, internal const string RegistrationName = "LinearBinaryPredictor"; private const string ModelStatsSubModelFilename = "ModelStats"; - private readonly LinearModelStatistics _stats; - - public LinearModelStatistics Statistics { get { return _stats; } } + public readonly ModelStatisticsBase Statistics; private static VersionInfo GetVersionInfo() { @@ -438,11 +436,11 @@ private static VersionInfo GetVersionInfo() /// of the i-th feature. Note that this will take ownership of the . /// The bias added to every output score. 
/// - internal LinearBinaryModelParameters(IHostEnvironment env, in VBuffer weights, float bias, LinearModelStatistics stats = null) + internal LinearBinaryModelParameters(IHostEnvironment env, in VBuffer weights, float bias, ModelStatisticsBase stats = null) : base(env, RegistrationName, in weights, bias) { Contracts.AssertValueOrNull(stats); - _stats = stats; + Statistics = stats; } private LinearBinaryModelParameters(IHostEnvironment env, ModelLoadContext ctx) @@ -454,9 +452,19 @@ private LinearBinaryModelParameters(IHostEnvironment env, ModelLoadContext ctx) // *** Binary format *** // (Base class) - // LinearModelStatistics: model statistics (optional, in a separate stream) - - ctx.LoadModelOrNull(Host, out _stats, ModelStatsSubModelFilename); + // LinearModelParameterStatistics: model statistics (optional, in a separate stream) + try + { + LinearModelParameterStatistics stats; + ctx.LoadModelOrNull(Host, out stats, ModelStatsSubModelFilename); + Statistics = stats; + } + catch (Exception) + { + ModelStatisticsBase stats; + ctx.LoadModelOrNull(Host, out stats, ModelStatsSubModelFilename); + Statistics = stats; + } } private static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) @@ -478,14 +486,14 @@ private protected override void SaveCore(ModelSaveContext ctx) { // *** Binary format *** // (Base class) - // LinearModelStatistics: model statistics (optional, in a separate stream) + // LinearModelParameterStatistics: model statistics (optional, in a separate stream) base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); - Contracts.AssertValueOrNull(_stats); - if (_stats != null) - ctx.SaveModel(_stats, ModelStatsSubModelFilename); + Contracts.AssertValueOrNull(Statistics); + if (Statistics != null) + ctx.SaveModel(Statistics, ModelStatsSubModelFilename); } private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification; @@ -510,7 +518,7 @@ private protected override void SaveSummary(TextWriter writer, 
RoleMappedSchema writer.WriteLine(LinearPredictorUtils.LinearModelAsText("Linear Binary Classification Predictor", null, null, in weights, Bias, schema)); - _stats?.SaveText(writer, this, schema.Feature.Value, 20); + Statistics?.SaveText(writer, schema.Feature.Value, 20); } /// @@ -521,17 +529,17 @@ IList> ICanGetSummaryInKeyValuePairs.GetSummaryInKe var weights = Weight; List> results = new List>(); LinearPredictorUtils.SaveLinearModelWeightsInKeyValuePairs(in weights, Bias, schema, results); - _stats?.SaveSummaryInKeyValuePairs(this, schema.Feature.Value, int.MaxValue, results); + Statistics?.SaveSummaryInKeyValuePairs(schema.Feature.Value, int.MaxValue, results); return results; } private protected override DataViewRow GetStatsIRowOrNull(RoleMappedSchema schema) { - if (_stats == null) + if (Statistics == null) return null; var names = default(VBuffer>); AnnotationUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, Weight.Length, ref names); - var meta = _stats.MakeStatisticsMetadata(this, schema, in names); + var meta = Statistics.MakeStatisticsMetadata(schema, in names); return AnnotationUtils.AnnotationsAsRow(meta); } diff --git a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/LogisticRegression.cs b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/LogisticRegression.cs index 2d5b6634c4..3795fbafb7 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/LogisticRegression.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/LogisticRegression.cs @@ -61,7 +61,7 @@ public sealed class Options : OptionsBase } private double _posWeight; - private LinearModelStatistics _stats; + private ModelStatisticsBase _stats; /// /// Initializes a new instance of @@ -250,12 +250,12 @@ private protected override void ComputeTrainingStatistics(IChannel ch, FloatLabe // Compute the standard error of coefficients. 
long hessianDimension = (long)numParams * (numParams + 1) / 2; - if (hessianDimension > int.MaxValue) + if (hessianDimension > int.MaxValue || LbfgsTrainerOptions.ComputeStandardDeviation == null) { - ch.Warning("The number of parameter is too large. Cannot hold the variance-covariance matrix in memory. " + + ch.Warning("The number of parameters is too large. Cannot hold the variance-covariance matrix in memory. " + "Skipping computation of standard errors and z-statistics of coefficients. Consider choosing a larger L1 regularizer" + "to reduce the number of parameters."); - _stats = new LinearModelStatistics(Host, NumGoodRows, numParams, deviance, nullDeviance); + _stats = new ModelStatisticsBase(Host, NumGoodRows, numParams, deviance, nullDeviance); return; } @@ -354,13 +354,10 @@ private protected override void ComputeTrainingStatistics(IChannel ch, FloatLabe } } - if (LbfgsTrainerOptions.ComputeStandardDeviation == null) - _stats = new LinearModelStatistics(Host, NumGoodRows, numParams, deviance, nullDeviance); - else - { - var std = LbfgsTrainerOptions.ComputeStandardDeviation.ComputeStandardDeviation(hessian, weightIndices, numParams, CurrentWeights.Length, ch, L2Weight); - _stats = new LinearModelStatistics(Host, NumGoodRows, numParams, deviance, nullDeviance, std); - } + VBuffer weightsOnly = default(VBuffer); + CurrentWeights.CopyTo(ref weightsOnly, 1, CurrentWeights.Length - 1); + var std = LbfgsTrainerOptions.ComputeStandardDeviation.ComputeStandardDeviation(hessian, weightIndices, numParams, CurrentWeights.Length, ch, L2Weight); + _stats = new LinearModelParameterStatistics(Host, NumGoodRows, numParams, deviance, nullDeviance, std, weightsOnly, bias); } private protected override void ProcessPriorDistribution(float label, float weight) diff --git a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 
8487ab6b79..f56bc329d2 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -66,7 +66,7 @@ public sealed class Options : OptionsBase // After training, it stores the total weights of training examples in each class. private Double[] _prior; - private LinearModelStatistics _stats; + private ModelStatisticsBase _stats; private protected override int ClassCount => _numClasses; @@ -300,7 +300,7 @@ private protected override void ComputeTrainingStatistics(IChannel ch, FloatLabe ch.Info("AIC: \t{0}", 2 * numParams + deviance); // REVIEW: Figure out how to compute the statistics for the coefficients. - _stats = new LinearModelStatistics(Host, NumGoodRows, numParams, deviance, nullDeviance); + _stats = new ModelStatisticsBase(Host, NumGoodRows, numParams, deviance, nullDeviance); } private protected override void ProcessPriorDistribution(float label, float weight) @@ -376,7 +376,7 @@ private static VersionInfo GetVersionInfo() private readonly float[] _biases; private readonly VBuffer[] _weights; - private readonly LinearModelStatistics _stats; + public readonly ModelStatisticsBase Statistics; // This stores the _weights matrix in dense format for performance. 
// It is used to make efficient predictions when the instance is sparse, so we get @@ -395,7 +395,7 @@ private static VersionInfo GetVersionInfo() bool ICanSavePfa.CanSavePfa => true; bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; - internal MulticlassLogisticRegressionModelParameters(IHostEnvironment env, in VBuffer weights, int numClasses, int numFeatures, string[] labelNames, LinearModelStatistics stats = null) + internal MulticlassLogisticRegressionModelParameters(IHostEnvironment env, in VBuffer weights, int numClasses, int numFeatures, string[] labelNames, ModelStatisticsBase stats = null) : base(env, RegistrationName) { Contracts.Assert(weights.Length == numClasses + numClasses * numFeatures); @@ -424,7 +424,7 @@ internal MulticlassLogisticRegressionModelParameters(IHostEnvironment env, in VB _labelNames = labelNames; Contracts.AssertValueOrNull(stats); - _stats = stats; + Statistics = stats; } /// @@ -438,7 +438,7 @@ internal MulticlassLogisticRegressionModelParameters(IHostEnvironment env, in VB /// The length of the feature vector. /// The optional label names. If specified not null, it should have the same length as . /// The model statistics. 
- internal MulticlassLogisticRegressionModelParameters(IHostEnvironment env, VBuffer[] weights, float[] bias, int numClasses, int numFeatures, string[] labelNames, LinearModelStatistics stats = null) + internal MulticlassLogisticRegressionModelParameters(IHostEnvironment env, VBuffer[] weights, float[] bias, int numClasses, int numFeatures, string[] labelNames, ModelStatisticsBase stats = null) : base(env, RegistrationName) { Contracts.CheckValue(weights, nameof(weights)); @@ -468,7 +468,7 @@ internal MulticlassLogisticRegressionModelParameters(IHostEnvironment env, VBuff _labelNames = labelNames; Contracts.AssertValueOrNull(stats); - _stats = stats; + Statistics = stats; } private MulticlassLogisticRegressionModelParameters(IHostEnvironment env, ModelLoadContext ctx) @@ -487,7 +487,7 @@ private MulticlassLogisticRegressionModelParameters(IHostEnvironment env, ModelL // int: total number of non-zero weights (same as number of column indices if sparse, num of classes * num of features if dense) // float[]: non-zero weights // int[]: Id of label names (optional, in a separate stream) - // LinearModelStatistics: model statistics (optional, in a separate stream) + // ModelStatisticsBase: model statistics (optional, in a separate stream) _numFeatures = ctx.Reader.ReadInt32(); Host.CheckDecode(_numFeatures >= 1); @@ -552,7 +552,12 @@ private MulticlassLogisticRegressionModelParameters(IHostEnvironment env, ModelL if (ctx.TryLoadBinaryStream(LabelNamesSubModelFilename, r => labelNames = LoadLabelNames(ctx, r))) _labelNames = labelNames; - ctx.LoadModelOrNull(Host, out _stats, ModelStatsSubModelFilename); + // backwards compatibility: MLR used to serialize a LinearModelStatistics object, before there existed two separate classes + // for ModelStatisticsBase and LinearModelParameterStatistics. + // It always populated only the fields now found on ModelStatisticsBase. 
+ ModelStatisticsBase stats; + ctx.LoadModelOrNull(Host, out stats, ModelStatsSubModelFilename); + Statistics = stats; } private static MulticlassLogisticRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) @@ -588,7 +593,7 @@ private protected override void SaveCore(ModelSaveContext ctx) // float[]: non-zero weights // bool: whether label names are present // int[]: Id of label names (optional, in a separate stream) - // LinearModelStatistics: model statistics (optional, in a separate stream) + // LinearModelParameterStatistics: model statistics (optional, in a separate stream) ctx.Writer.Write(_numFeatures); ctx.Writer.Write(_numClasses); @@ -688,9 +693,9 @@ private protected override void SaveCore(ModelSaveContext ctx) if (_labelNames != null) ctx.SaveBinaryStream(LabelNamesSubModelFilename, w => SaveLabelNames(ctx, w)); - Contracts.AssertValueOrNull(_stats); - if (_stats != null) - ctx.SaveModel(_stats, ModelStatsSubModelFilename); + Contracts.AssertValueOrNull(Statistics); + if (Statistics != null) + ctx.SaveModel(Statistics, ModelStatsSubModelFilename); } // REVIEW: Destroy. 
@@ -787,8 +792,8 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) writer.WriteLine("\t{0}\t{1}", namedValues.Key, (float)namedValues.Value); } - if (_stats != null) - _stats.SaveText(writer, null, schema.Feature.Value, 20); + if (Statistics != null) + Statistics.SaveText(writer, schema.Feature.Value, 20); } /// @@ -992,11 +997,12 @@ DataViewRow ICanGetSummaryAsIRow.GetSummaryIRowOrNull(RoleMappedSchema schema) DataViewRow ICanGetSummaryAsIRow.GetStatsIRowOrNull(RoleMappedSchema schema) { - if (_stats == null) + if (Statistics == null) return null; - VBuffer> names = default; - var meta = _stats.MakeStatisticsMetadata(null, schema, in names); + var names = default(VBuffer>); + AnnotationUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, _weights.Length, ref names); + var meta = Statistics.MakeStatisticsMetadata(schema, in names); return AnnotationUtils.AnnotationsAsRow(meta); } } diff --git a/src/Microsoft.ML.StandardTrainers/Standard/ModelStatistics.cs b/src/Microsoft.ML.StandardTrainers/Standard/ModelStatistics.cs index d070d91ce6..9f83d5b6b8 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/ModelStatistics.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/ModelStatistics.cs @@ -15,27 +15,51 @@ using Microsoft.ML.Trainers; // This is for deserialization from a model repository. -[assembly: LoadableClass(typeof(LinearModelStatistics), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(ModelStatisticsBase), typeof(LinearModelParameterStatistics), null, typeof(SignatureLoadModel), "Linear Model Statistics", - LinearModelStatistics.LoaderSignature)] + LinearModelParameterStatistics.LoaderSignature)] + +// This is for deserialization from a model repository. 
+[assembly: LoadableClass(typeof(ModelStatisticsBase), typeof(ModelStatisticsBase), null, typeof(SignatureLoadModel), + "Model Statistics", + ModelStatisticsBase.LoaderSignature)] namespace Microsoft.ML.Trainers { /// - /// Represents a coefficient statistics object. + /// Represents a coefficient statistics object containing statistics about the calculated model parameters. /// - public readonly struct CoefficientStatistics + public sealed class CoefficientStatistics { - public readonly string Name; + /// + /// The model parameter (bias or weight) for which the statistics are generated. + /// public readonly float Estimate; + + /// + /// The standard deviation of the estimate of this model parameter (bias or weight). + /// public readonly float StandardError; + + /// + /// The standard score of the estimate of this model parameter (bias or weight). + /// Quantifies by how much the estimate is above or below the mean. + /// public readonly float ZScore; + + /// + /// The probability value of the estimate of this model parameter (bias or weight). + /// public readonly float PValue; - internal CoefficientStatistics(string name, float estimate, float stdError, float zScore, float pValue) + /// + /// The index of the feature, in the Features vector, to which this model parameter (bias or weight) corresponds. + /// + public readonly int Index; + + internal CoefficientStatistics(int featureIndex, float estimate, float stdError, float zScore, float pValue) { - Contracts.AssertNonEmpty(name); - Name = name; + Index = featureIndex; Estimate = estimate; StandardError = stdError; ZScore = zScore; @@ -43,221 +67,315 @@ internal CoefficientStatistics(string name, float estimate, float stdError, floa } } - // REVIEW: Make this class a loadable class and implement ICanSaveModel. // REVIEW: Reconcile with the stats in OLS learner. /// /// The statistics for linear predictor. 
/// - public sealed class LinearModelStatistics : ICanSaveModel + public class ModelStatisticsBase : ICanSaveModel { - internal const string LoaderSignature = "LinearModelStats"; - - private static VersionInfo GetVersionInfo() - { - return new VersionInfo( - modelSignature: "LMODSTAT", - verWrittenCur: 0x00010001, // Initial - verReadableCur: 0x00010001, - verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(LinearModelStatistics).Assembly.FullName); - } - - private readonly IHostEnvironment _env; + private protected IHostEnvironment Env; // Total count of training examples used to train the model. - private readonly long _trainingExampleCount; + public readonly long TrainingExampleCount; // The deviance of this model. - private readonly float _deviance; + public readonly float Deviance; // The deviance of the null hypothesis. - private readonly float _nullDeviance; + public readonly float NullDeviance; // Total count of parameters. - private readonly int _paramCount; + public readonly int ParametersCount; - // The standard errors of coefficients, including the bias. - // The standard error of bias is placed at index zero. - // It could be null when there are too many non-zero weights so that - // the memory is insufficient to hold the Hessian matrix necessary for the computation - // of the variance-covariance matrix. - private readonly VBuffer? 
_coeffStdError; - - public long TrainingExampleCount => _trainingExampleCount; + internal const string LoaderSignature = "ModelStats"; - public float Deviance => _deviance; + internal ModelStatisticsBase(IHostEnvironment env, long trainingExampleCount, int paramCount, float deviance, float nullDeviance) + { + Contracts.CheckValue(env, nameof(env)); + Env = env; - public float NullDeviance => _nullDeviance; + Env.Assert(trainingExampleCount > 0); + Env.Assert(paramCount > 0); - public int ParametersCount => _paramCount; + ParametersCount = paramCount; + TrainingExampleCount = trainingExampleCount; + Deviance = deviance; + NullDeviance = nullDeviance; + } - internal LinearModelStatistics(IHostEnvironment env, long trainingExampleCount, int paramCount, float deviance, float nullDeviance) + internal ModelStatisticsBase(IHostEnvironment env, ModelLoadContext ctx) { - Contracts.AssertValue(env); - env.Assert(trainingExampleCount > 0); - env.Assert(paramCount > 0); - _env = env; - _paramCount = paramCount; - _trainingExampleCount = trainingExampleCount; - _deviance = deviance; - _nullDeviance = nullDeviance; + Contracts.CheckValue(env, nameof(env)); + Env = env; + Env.AssertValue(ctx); + + // *** Binary Format *** + // int: count of parameters + // long: count of training examples + // float: deviance + // float: null deviance + + ParametersCount = ctx.Reader.ReadInt32(); + Env.CheckDecode(ParametersCount > 0); + + TrainingExampleCount = ctx.Reader.ReadInt64(); + Env.CheckDecode(TrainingExampleCount > 0); + + Deviance = ctx.Reader.ReadFloat(); + NullDeviance = ctx.Reader.ReadFloat(); } - internal LinearModelStatistics(IHostEnvironment env, long trainingExampleCount, int paramCount, float deviance, float nullDeviance, in VBuffer coeffStdError) - : this(env, trainingExampleCount, paramCount, deviance, nullDeviance) + void ICanSaveModel.Save(ModelSaveContext ctx) { - _env.Assert(coeffStdError.GetValues().Length == _paramCount); - _coeffStdError = coeffStdError; + 
Contracts.AssertValue(Env); + Env.CheckValue(ctx, nameof(ctx)); + SaveCore(ctx); + ctx.SetVersionInfo(GetVersionInfo()); } - internal LinearModelStatistics(IHostEnvironment env, ModelLoadContext ctx) + private protected virtual void SaveCore(ModelSaveContext ctx) { - Contracts.CheckValue(env, nameof(env)); - _env = env; - _env.AssertValue(ctx); - // *** Binary Format *** // int: count of parameters // long: count of training examples // float: deviance // float: null deviance - // bool: whether standard error is included - // (Conditional) float[_paramCount]: values of std errors of coefficients - // (Conditional) int: length of std errors of coefficients - // (Conditional) int[_paramCount]: indices of std errors of coefficients - _paramCount = ctx.Reader.ReadInt32(); - _env.CheckDecode(_paramCount > 0); + Env.Assert(ParametersCount > 0); + ctx.Writer.Write(ParametersCount); + + Env.Assert(TrainingExampleCount > 0); + ctx.Writer.Write(TrainingExampleCount); + + ctx.Writer.Write(Deviance); + ctx.Writer.Write(NullDeviance); + } + + internal virtual void SaveText(TextWriter writer, DataViewSchema.Column featureColumn, int paramCountCap) + { + Contracts.AssertValue(Env); + Env.CheckValue(writer, nameof(writer)); + + writer.WriteLine(); + writer.WriteLine("*** MODEL STATISTICS SUMMARY *** "); + writer.WriteLine("Count of training examples:\t{0}", TrainingExampleCount); + writer.WriteLine("Residual Deviance: \t{0}", Deviance); + writer.WriteLine("Null Deviance: \t{0}", NullDeviance); + writer.WriteLine("AIC: \t{0}", 2 * ParametersCount + Deviance); + } + + /// + /// Support method for linear models and . 
+ /// + internal virtual void SaveSummaryInKeyValuePairs(DataViewSchema.Column featureColumn, int paramCountCap, List> resultCollection) + { + Contracts.AssertValue(Env); + Env.AssertValue(resultCollection); + + resultCollection.Add(new KeyValuePair("Count of training examples", TrainingExampleCount)); + resultCollection.Add(new KeyValuePair("Residual Deviance", Deviance)); + resultCollection.Add(new KeyValuePair("Null Deviance", NullDeviance)); + resultCollection.Add(new KeyValuePair("AIC", 2 * ParametersCount + Deviance)); + } + + internal virtual DataViewSchema.Annotations MakeStatisticsMetadata(RoleMappedSchema schema, in VBuffer> names) + { + var builder = new DataViewSchema.Annotations.Builder(); + + builder.AddPrimitiveValue("Count of training examples", NumberDataViewType.Int64, TrainingExampleCount); + builder.AddPrimitiveValue("Residual Deviance", NumberDataViewType.Single, Deviance); + builder.AddPrimitiveValue("Null Deviance", NumberDataViewType.Single, NullDeviance); + builder.AddPrimitiveValue("AIC", NumberDataViewType.Single, 2 * ParametersCount + Deviance); + + return builder.ToAnnotations(); + } - _trainingExampleCount = ctx.Reader.ReadInt64(); - _env.CheckDecode(_trainingExampleCount > 0); + private protected virtual VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "MOD STAT", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(ModelStatisticsBase).Assembly.FullName); + } + } - _deviance = ctx.Reader.ReadFloat(); - _nullDeviance = ctx.Reader.ReadFloat(); + // REVIEW: Reconcile with the stats in OLS learner. + /// + /// The statistics for linear predictor. 
+ /// + public sealed class LinearModelParameterStatistics : ModelStatisticsBase + { + internal new const string LoaderSignature = "LinearModelStats"; + + private const int CoeffStatsRefactorVersion = 0x00010002; + + private protected override VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "LMODSTAT", + // verWrittenCur: 0x00010001, // Initial + verWrittenCur: 0x00010002, // Refactored the stats for the parameters in the base class. + verReadableCur: 0x00010002, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(LinearModelParameterStatistics).Assembly.FullName); + } - var hasStdErrors = ctx.Reader.ReadBoolean(); - if (!hasStdErrors) + // The standard errors of coefficients, including the bias. + // The standard error of bias is placed at index zero. + // It could be null when there are too many non-zero weights so that + // the memory is insufficient to hold the Hessian matrix necessary for the computation + // of the variance-covariance matrix. + private readonly VBuffer _coeffStdError; + + /// + /// The weights of the LinearModelParams trained. + /// + private readonly VBuffer _weights; + + /// + /// The bias of the LinearModelParams trained. 
+ /// + private readonly float _bias; + + internal LinearModelParameterStatistics(IHostEnvironment env, long trainingExampleCount, int paramCount, float deviance, float nullDeviance, + in VBuffer coeffStdError, VBuffer weights, float bias) + : base(env, trainingExampleCount, paramCount, deviance, nullDeviance) + { + Env.Assert(trainingExampleCount > 0); + Env.Assert(paramCount > 0); + Env.Assert(coeffStdError.Length > 0, nameof(coeffStdError)); + Env.Assert(weights.Length > 0, nameof(weights)); + + _coeffStdError = coeffStdError; + _weights = weights; + _bias = bias; + } + + private LinearModelParameterStatistics(IHostEnvironment env, ModelLoadContext ctx) + : base(env, ctx) + { + // *** Binary format *** + // + // bool: whether standard error is included + // float[_paramCount]: values of std errors of coefficients + // int: length of std errors of coefficients + // (Conditional)int[_paramCount]: indices of std errors of coefficients + + //backwards compatibility + if (ctx.Header.ModelVerWritten < CoeffStatsRefactorVersion) { - _env.Assert(_coeffStdError == null); - return; + if (!ctx.Reader.ReadBoolean()) // this was used in the old model to denote whether there were stdErrorValues or not. 
+ return; } - float[] stdErrorValues = ctx.Reader.ReadFloatArray(_paramCount); + float[] stdErrorValues = ctx.Reader.ReadFloatArray(ParametersCount); int length = ctx.Reader.ReadInt32(); - _env.CheckDecode(length >= _paramCount); - if (length == _paramCount) + env.CheckDecode(length >= ParametersCount); + + if (length == ParametersCount) { _coeffStdError = new VBuffer(length, stdErrorValues); - return; + } + else + { + env.Assert(length > ParametersCount); + int[] stdErrorIndices = ctx.Reader.ReadIntArray(ParametersCount); + _coeffStdError = new VBuffer(length, ParametersCount, stdErrorValues, stdErrorIndices); } - _env.Assert(length > _paramCount); - int[] stdErrorIndices = ctx.Reader.ReadIntArray(_paramCount); - _coeffStdError = new VBuffer(length, _paramCount, stdErrorValues, stdErrorIndices); - } + //read the bias + _bias = ctx.Reader.ReadFloat(); - internal static LinearModelStatistics Create(IHostEnvironment env, ModelLoadContext ctx) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(GetVersionInfo()); - return new LinearModelStatistics(env, ctx); - } + //read the weights + bool isWeightsDense = ctx.Reader.ReadBoolByte(); + var weightsLength = ctx.Reader.ReadInt32(); + var weightsValues = ctx.Reader.ReadFloatArray(weightsLength); - void ICanSaveModel.Save(ModelSaveContext ctx) - { - Contracts.AssertValue(_env); - _env.CheckValue(ctx, nameof(ctx)); - SaveCore(ctx); - ctx.SetVersionInfo(GetVersionInfo()); + if (isWeightsDense) + { + _weights = new VBuffer(weightsLength, weightsValues); + } + else + { + int[] weightsIndices = ctx.Reader.ReadIntArray(weightsLength); + _weights = new VBuffer(weightsLength, weightsLength, weightsValues, weightsIndices); + } } - private void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { // *** Binary Format *** - // int: count of parameters - // long: count of training examples - // float: deviance - // float: null deviance - // 
bool: whether standard error is included + // // (Conditional) float[_paramCount]: values of std errors of coefficients + // (Conditional) int: length of std errors of coefficients + // (Conditional) int[_paramCount]: indices of std errors of coefficients - _env.Assert(_paramCount > 0); - ctx.Writer.Write(_paramCount); - - _env.Assert(_trainingExampleCount > 0); - ctx.Writer.Write(_trainingExampleCount); - - ctx.Writer.Write(_deviance); - ctx.Writer.Write(_nullDeviance); - - bool hasStdErrors = _coeffStdError.HasValue; - ctx.Writer.Write(hasStdErrors); - if (!hasStdErrors) - return; + base.SaveCore(ctx); - var coeffStdErrorValues = _coeffStdError.Value.GetValues(); - _env.Assert(coeffStdErrorValues.Length == _paramCount); + var coeffStdErrorValues = _coeffStdError.GetValues(); + Env.Assert(coeffStdErrorValues.Length == ParametersCount); ctx.Writer.WriteSinglesNoCount(coeffStdErrorValues); - ctx.Writer.Write(_coeffStdError.Value.Length); - if (_coeffStdError.Value.IsDense) - return; - - ctx.Writer.WriteIntsNoCount(_coeffStdError.Value.GetIndices()); + ctx.Writer.Write(_coeffStdError.Length); + if (!_coeffStdError.IsDense) + ctx.Writer.WriteIntsNoCount(_coeffStdError.GetIndices()); + + //save the bias + ctx.Writer.Write(_bias); + + //save the weights + ctx.Writer.WriteBoolByte(_weights.IsDense); + ctx.Writer.Write(_weights.Length); + ctx.Writer.WriteSinglesNoCount(_weights.GetValues()); + if (!_weights.IsDense) + ctx.Writer.WriteIntsNoCount(_weights.GetIndices()); } /// - /// Computes the standart deviation, Z-Score and p-Value. + /// Computes the standard deviation, Z-Score and p-Value for the value being passed as the bias. 
/// - public static bool TryGetBiasStatistics(LinearModelStatistics stats, float bias, out float stdError, out float zScore, out float pValue) + public CoefficientStatistics GetBiasStatisticsForValue(float bias) { - if (!stats._coeffStdError.HasValue) - { - stdError = 0; - zScore = 0; - pValue = 0; - return false; - } - - const Double sqrt2 = 1.41421356237; // Math.Sqrt(2); - stdError = stats._coeffStdError.Value.GetValues()[0]; - Contracts.Assert(stdError == stats._coeffStdError.Value.GetItemOrDefault(0)); - zScore = bias / stdError; - pValue = 1.0f - (float)ProbabilityFunctions.Erf(Math.Abs(zScore / sqrt2)); - return true; + const double sqrt2 = 1.41421356237; // Math.Sqrt(2); + var stdError = _coeffStdError.GetValues()[0]; + Contracts.Assert(stdError == _coeffStdError.GetItemOrDefault(0)); + var zScore = bias / stdError; + var pValue = 1.0f - (float)ProbabilityFunctions.Erf(Math.Abs(zScore / sqrt2)); + + //int feature index, float estimate, float stdError, float zScore, float pValue + return new CoefficientStatistics(0, bias, stdError, zScore, pValue); } - private static void GetUnorderedCoefficientStatistics(LinearModelStatistics stats, in VBuffer weights, in VBuffer> names, + /// + /// Computes the standard deviation, Z-Score and p-Value for the calculated bias. 
+ /// + public CoefficientStatistics GetBiasStatistics() => GetBiasStatisticsForValue(_bias); + + private void GetUnorderedCoefficientStatistics(in VBuffer> names, ref VBuffer estimate, ref VBuffer stdErr, ref VBuffer zScore, ref VBuffer pValue, out ValueGetter>> getSlotNames) { - if (!stats._coeffStdError.HasValue) - { - getSlotNames = null; - return; - } - - Contracts.Assert(stats._coeffStdError.Value.Length == weights.Length + 1); + Contracts.Assert(_coeffStdError.Length == _weights.Length + 1); - var statisticsCount = stats.ParametersCount - 1; + var statisticsCount = ParametersCount - 1; var estimateEditor = VBufferEditor.Create(ref estimate, statisticsCount); var stdErrorEditor = VBufferEditor.Create(ref stdErr, statisticsCount); var zScoreEditor = VBufferEditor.Create(ref zScore, statisticsCount); var pValueEditor = VBufferEditor.Create(ref pValue, statisticsCount); - const Double sqrt2 = 1.41421356237; // Math.Sqrt(2); + const double sqrt2 = 1.41421356237; // Math.Sqrt(2); - bool denseStdError = stats._coeffStdError.Value.IsDense; - ReadOnlySpan stdErrorIndices = stats._coeffStdError.Value.GetIndices(); - ReadOnlySpan coeffStdErrorValues = stats._coeffStdError.Value.GetValues(); - for (int i = 1; i < stats.ParametersCount; i++) + bool denseStdError = _coeffStdError.IsDense; + ReadOnlySpan stdErrorIndices = _coeffStdError.GetIndices(); + ReadOnlySpan coeffStdErrorValues = _coeffStdError.GetValues(); + for (int i = 1; i < ParametersCount; i++) { int wi = denseStdError ? 
i - 1 : stdErrorIndices[i] - 1; - Contracts.Assert(0 <= wi && wi < weights.Length); - var weight = estimateEditor.Values[i - 1] = weights.GetItemOrDefault(wi); + Contracts.Assert(0 <= wi && wi < _weights.Length); + var weight = estimateEditor.Values[i - 1] = _weights.GetItemOrDefault(wi); var stdError = stdErrorEditor.Values[wi] = coeffStdErrorValues[i]; zScoreEditor.Values[i - 1] = weight / stdError; pValueEditor.Values[i - 1] = 1 - (float)ProbabilityFunctions.Erf(Math.Abs(zScoreEditor.Values[i - 1] / sqrt2)); @@ -273,7 +391,7 @@ private static void GetUnorderedCoefficientStatistics(LinearModelStatistics stat (ref VBuffer> dst) => { var editor = VBufferEditor.Create(ref dst, statisticsCount); - ReadOnlySpan stdErrorIndices2 = stats._coeffStdError.Value.GetIndices(); + ReadOnlySpan stdErrorIndices2 = _coeffStdError.GetIndices(); for (int i = 1; i <= statisticsCount; i++) { int wi = denseStdError ? i - 1 : stdErrorIndices2[i] - 1; @@ -283,100 +401,124 @@ private static void GetUnorderedCoefficientStatistics(LinearModelStatistics stat }; } - private List GetUnorderedCoefficientStatistics(LinearBinaryModelParameters parent, DataViewSchema.Column featureColumn) + private List GetUnorderedCoefficientStatistics() { - Contracts.AssertValue(_env); - _env.CheckValue(parent, nameof(parent)); + Contracts.AssertValue(Env); - if (!_coeffStdError.HasValue) - return new List(); + Env.Assert(_coeffStdError.Length == _weights.Length + 1); - var weights = parent.Weights as IReadOnlyList; - _env.Assert(_paramCount == 1 || weights != null); - _env.Assert(_coeffStdError.Value.Length == weights.Count + 1); + ReadOnlySpan stdErrorValues = _coeffStdError.GetValues(); + const Double sqrt2 = 1.41421356237; // Math.Sqrt(2); + List result = new List(ParametersCount - 1); + bool denseStdError = _coeffStdError.IsDense; + ReadOnlySpan stdErrorIndices = _coeffStdError.GetIndices(); + float[] zScores = new float[ParametersCount - 1]; + for (int i = 1; i < ParametersCount; i++) //skip the bias term 
+ { + int wi = denseStdError ? i - 1 : stdErrorIndices[i] - 1; + Env.Assert(0 <= wi && wi < _weights.Length); + var weight = _weights.GetItemOrDefault(wi); + var stdError = stdErrorValues[i]; + var zScore = zScores[i - 1] = weight / stdError; + var pValue = 1 - (float)ProbabilityFunctions.Erf(Math.Abs(zScore / sqrt2)); + result.Add(new CoefficientStatistics(wi, weight, stdError, zScore, pValue)); + } + return result; + } + + private string[] GetFeatureNames(DataViewSchema.Column featureColumn) + { var names = default(VBuffer>); featureColumn.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref names); - _env.Assert(names.Length > 0, "FeatureColumnName has no metadata."); + Env.Assert(names.Length > 0, "FeatureColumnName has no metadata."); - ReadOnlySpan stdErrorValues = _coeffStdError.Value.GetValues(); - const Double sqrt2 = 1.41421356237; // Math.Sqrt(2); + bool denseStdError = _coeffStdError.IsDense; + ReadOnlySpan stdErrorIndices = _coeffStdError.GetIndices(); + + var featureNames = new List(); - List result = new List(_paramCount - 1); - bool denseStdError = _coeffStdError.Value.IsDense; - ReadOnlySpan stdErrorIndices = _coeffStdError.Value.GetIndices(); - float[] zScores = new float[_paramCount - 1]; - for (int i = 1; i < _paramCount; i++) + for (int i = 1; i < ParametersCount; i++) //skip the bias term { int wi = denseStdError ? 
i - 1 : stdErrorIndices[i] - 1; - _env.Assert(0 <= wi && wi < weights.Count); + Env.Assert(0 <= wi && wi < _weights.Length); var name = names.GetItemOrDefault(wi).ToString(); if (string.IsNullOrEmpty(name)) name = $"f{wi}"; - var weight = weights[wi]; - var stdError = stdErrorValues[i]; - var zScore = zScores[i - 1] = weight / stdError; - var pValue = 1 - (float)ProbabilityFunctions.Erf(Math.Abs(zScore / sqrt2)); - result.Add(new CoefficientStatistics(name, weight, stdError, zScore, pValue)); + + featureNames.Add(name); } - return result; + + return featureNames.ToArray(); + } /// /// Gets the coefficient statistics as an object. /// - public CoefficientStatistics[] GetCoefficientStatistics(LinearBinaryModelParameters parent, DataViewSchema.Column featureColumn, int paramCountCap) + public CoefficientStatistics[] GetWeightsCoefficientStatistics(int paramCountCap) { - Contracts.AssertValue(_env); - _env.CheckValue(parent, nameof(parent)); - _env.CheckParam(paramCountCap >= 0, nameof(paramCountCap)); - - if (paramCountCap > _paramCount) - paramCountCap = _paramCount; - - float stdError; - float zScore; - float pValue; - var bias = parent.Bias; - if (!TryGetBiasStatistics(parent.Statistics, bias, out stdError, out zScore, out pValue)) - return null; - - var order = GetUnorderedCoefficientStatistics(parent, featureColumn).OrderByDescending(stat => stat.ZScore).Take(paramCountCap - 1); - return order.Prepend(new[] { new CoefficientStatistics("(Bias)", bias, stdError, zScore, pValue) }).ToArray(); + Env.CheckParam(paramCountCap >= 0, nameof(paramCountCap)); + + if (paramCountCap > ParametersCount) + paramCountCap = ParametersCount; + + var order = GetUnorderedCoefficientStatistics().OrderByDescending(stat => stat.ZScore).Take(paramCountCap - 1); + return order.ToArray(); } - internal void SaveText(TextWriter writer, LinearBinaryModelParameters parent, DataViewSchema.Column featureColumn, int paramCountCap) + /// + /// Saves the statistics in Text format. 
+ /// + internal override void SaveText(TextWriter writer, DataViewSchema.Column featureColumn, int paramCountCap) { - Contracts.AssertValue(_env); - _env.CheckValue(writer, nameof(writer)); - _env.AssertValueOrNull(parent); - writer.WriteLine(); - writer.WriteLine("*** MODEL STATISTICS SUMMARY *** "); - writer.WriteLine("Count of training examples:\t{0}", _trainingExampleCount); - writer.WriteLine("Residual Deviance: \t{0}", _deviance); - writer.WriteLine("Null Deviance: \t{0}", _nullDeviance); - writer.WriteLine("AIC: \t{0}", 2 * _paramCount + _deviance); - - if (parent == null) - return; + base.SaveText(writer, featureColumn, paramCountCap); - var coeffStats = GetCoefficientStatistics(parent, featureColumn, paramCountCap); + var biasStats = GetBiasStatistics(); + var coeffStats = GetWeightsCoefficientStatistics(paramCountCap); if (coeffStats == null) return; + var featureNames = GetFeatureNames(featureColumn); + Env.Assert(featureNames.Length >= 1); + writer.WriteLine(); writer.WriteLine("Coefficients statistics:"); writer.WriteLine("Coefficient \tEstimate\tStd. 
Error\tz value \tPr(>|z|)"); + Func decorateProbabilityString = (float probZ) => + { + Contracts.AssertValue(Env); + Env.Assert(0 <= probZ && probZ <= 1); + if (probZ < 0.001) + return string.Format("{0} ***", probZ); + if (probZ < 0.01) + return string.Format("{0} **", probZ); + if (probZ < 0.05) + return string.Format("{0} *", probZ); + if (probZ < 0.1) + return string.Format("{0} .", probZ); + + return probZ.ToString(); + }; + + writer.WriteLine("(Bias)\t{0,-10:G7}\t{1,-10:G7}\t{2,-10:G7}\t{3}", + biasStats.Estimate, + biasStats.StandardError, + biasStats.ZScore, + decorateProbabilityString(biasStats.PValue)); + foreach (var coeffStat in coeffStats) { + Env.Assert(coeffStat.Index < featureNames.Length); + writer.WriteLine("{0,-15}\t{1,-10:G7}\t{2,-10:G7}\t{3,-10:G7}\t{4}", - coeffStat.Name, + featureNames[coeffStat.Index], coeffStat.Estimate, coeffStat.StandardError, coeffStat.ZScore, - DecorateProbabilityString(coeffStat.PValue)); + decorateProbabilityString(coeffStat.PValue)); } writer.WriteLine("---"); @@ -386,64 +528,52 @@ internal void SaveText(TextWriter writer, LinearBinaryModelParameters parent, Da /// /// Support method for linear models and . 
/// - internal void SaveSummaryInKeyValuePairs(LinearBinaryModelParameters parent, - DataViewSchema.Column featureColumn, int paramCountCap, List> resultCollection) + internal override void SaveSummaryInKeyValuePairs(DataViewSchema.Column featureColumn, int paramCountCap, List> resultCollection) { - Contracts.AssertValue(_env); - _env.AssertValue(resultCollection); - - resultCollection.Add(new KeyValuePair("Count of training examples", _trainingExampleCount)); - resultCollection.Add(new KeyValuePair("Residual Deviance", _deviance)); - resultCollection.Add(new KeyValuePair("Null Deviance", _nullDeviance)); - resultCollection.Add(new KeyValuePair("AIC", 2 * _paramCount + _deviance)); + Env.AssertValue(resultCollection); - if (parent == null) - return; + base.SaveSummaryInKeyValuePairs(featureColumn, paramCountCap, resultCollection); - var coeffStats = GetCoefficientStatistics(parent, featureColumn, paramCountCap); + var biasStats = GetBiasStatistics(); + var coeffStats = GetWeightsCoefficientStatistics(paramCountCap); if (coeffStats == null) return; + var featureNames = GetFeatureNames(featureColumn); + resultCollection.Add(new KeyValuePair( + "(Bias)", + new float[] { biasStats.Estimate, biasStats.StandardError, biasStats.ZScore, biasStats.PValue })); + foreach (var coeffStat in coeffStats) { + Env.Assert(coeffStat.Index < featureNames.Length); + resultCollection.Add(new KeyValuePair( - coeffStat.Name, + featureNames[coeffStat.Index], new float[] { coeffStat.Estimate, coeffStat.StandardError, coeffStat.ZScore, coeffStat.PValue })); } } - internal DataViewSchema.Annotations MakeStatisticsMetadata(LinearBinaryModelParameters parent, RoleMappedSchema schema, in VBuffer> names) + internal override DataViewSchema.Annotations MakeStatisticsMetadata(RoleMappedSchema schema, in VBuffer> names) { - _env.AssertValueOrNull(parent); - _env.AssertValue(schema); + Env.AssertValue(schema); var builder = new DataViewSchema.Annotations.Builder(); + 
builder.Add(base.MakeStatisticsMetadata(schema, names), c => true); - builder.AddPrimitiveValue("Count of training examples", NumberDataViewType.Int64, _trainingExampleCount); - builder.AddPrimitiveValue("Residual Deviance", NumberDataViewType.Single, _deviance); - builder.AddPrimitiveValue("Null Deviance", NumberDataViewType.Single, _nullDeviance); - builder.AddPrimitiveValue("AIC", NumberDataViewType.Single, 2 * _paramCount + _deviance); + //bias statistics + var biasStats = GetBiasStatistics(); + builder.AddPrimitiveValue("BiasEstimate", NumberDataViewType.Single, biasStats.Estimate); + builder.AddPrimitiveValue("BiasStandardError", NumberDataViewType.Single, biasStats.StandardError); + builder.AddPrimitiveValue("BiasZScore", NumberDataViewType.Single, biasStats.ZScore); + builder.AddPrimitiveValue("BiasPValue", NumberDataViewType.Single, biasStats.PValue); - if (parent == null) - return builder.ToAnnotations(); - - if (!TryGetBiasStatistics(parent.Statistics, parent.Bias, out float biasStdErr, out float biasZScore, out float biasPValue)) - return builder.ToAnnotations(); - - var biasEstimate = parent.Bias; - builder.AddPrimitiveValue("BiasEstimate", NumberDataViewType.Single, biasEstimate); - builder.AddPrimitiveValue("BiasStandardError", NumberDataViewType.Single, biasStdErr); - builder.AddPrimitiveValue("BiasZScore", NumberDataViewType.Single, biasZScore); - builder.AddPrimitiveValue("BiasPValue", NumberDataViewType.Single, biasPValue); - - var weights = default(VBuffer); - ((IHaveFeatureWeights)parent).GetFeatureWeights(ref weights); var estimate = default(VBuffer); var stdErr = default(VBuffer); var zScore = default(VBuffer); var pValue = default(VBuffer); ValueGetter>> getSlotNames; - GetUnorderedCoefficientStatistics(parent.Statistics, in weights, in names, ref estimate, ref stdErr, ref zScore, ref pValue, out getSlotNames); + GetUnorderedCoefficientStatistics(in names, ref estimate, ref stdErr, ref zScore, ref pValue, out getSlotNames); var 
subMetaBuilder = new DataViewSchema.Annotations.Builder(); subMetaBuilder.AddSlotNames(stdErr.Length, getSlotNames); @@ -457,21 +587,5 @@ internal DataViewSchema.Annotations MakeStatisticsMetadata(LinearBinaryModelPara return builder.ToAnnotations(); } - - private string DecorateProbabilityString(float probZ) - { - Contracts.AssertValue(_env); - _env.Assert(0 <= probZ && probZ <= 1); - if (probZ < 0.001) - return string.Format("{0} ***", probZ); - if (probZ < 0.01) - return string.Format("{0} **", probZ); - if (probZ < 0.05) - return string.Format("{0} *", probZ); - if (probZ < 0.1) - return string.Format("{0} .", probZ); - - return probZ.ToString(); - } } } diff --git a/test/BaselineOutput/Common/Command/CommandTrainingLrWithStats-out.txt b/test/BaselineOutput/Common/Command/CommandTrainingLrWithStats-out.txt index 9fc4371313..9dff492b77 100644 --- a/test/BaselineOutput/Common/Command/CommandTrainingLrWithStats-out.txt +++ b/test/BaselineOutput/Common/Command/CommandTrainingLrWithStats-out.txt @@ -8,6 +8,7 @@ Model trained with 500 training examples. Residual Deviance: 458.9709 (on 493 degrees of freedom) Null Deviance: 539.2764 (on 499 degrees of freedom) AIC: 472.9709 +Warning: The number of parameters is too large. Cannot hold the variance-covariance matrix in memory. Skipping computation of standard errors and z-statistics of coefficients. Consider choosing a larger L1 regularizerto reduce the number of parameters. Not training a calibrator because it is not needed. Physical memory usage(MB): %Number% Virtual memory usage(MB): %Number% diff --git a/test/BaselineOutput/Common/EntryPoints/ensemble-summary.txt b/test/BaselineOutput/Common/EntryPoints/ensemble-summary.txt index fadb2e27c8..12c0c05472 100644 --- a/test/BaselineOutput/Common/EntryPoints/ensemble-summary.txt +++ b/test/BaselineOutput/Common/EntryPoints/ensemble-summary.txt @@ -20,7 +20,7 @@ AIC: 118.2943 Coefficients statistics: Coefficient Estimate Std. 
Error z value Pr(>|z|) -(Bias) -5.120674 0.6998186 -7.317145 0 *** +(Bias) -5.120674 0.6998186 -7.317145 0 *** Features.thickness 2.353567 0.4267568 5.515007 5.960464E-08 *** Features.bare_nuclei 2.435889 0.451504 5.395056 5.960464E-08 *** Features.uniform_shape 1.944249 0.4137097 4.699549 2.622604E-06 *** @@ -68,7 +68,7 @@ AIC: 114.1969 Coefficients statistics: Coefficient Estimate Std. Error z value Pr(>|z|) -(Bias) -4.860323 0.7128119 -6.818521 0 *** +(Bias) -4.860323 0.7128119 -6.818521 0 *** Features.bare_nuclei 3.16846 0.4579377 6.918975 0 *** Features.thickness 2.143086 0.4306555 4.976335 6.556511E-07 *** Features.uniform_shape 1.711214 0.4222687 4.05243 5.072355E-05 *** diff --git a/test/Microsoft.ML.TestFramework/Attributes/LessThanNetCore30OrNotNetCoreAndX64FactAttribute.cs b/test/Microsoft.ML.TestFramework/Attributes/LessThanNetCore30OrNotNetCoreAndX64FactAttribute.cs index 3685d6de0c..c6bddec8ab 100644 --- a/test/Microsoft.ML.TestFramework/Attributes/LessThanNetCore30OrNotNetCoreAndX64FactAttribute.cs +++ b/test/Microsoft.ML.TestFramework/Attributes/LessThanNetCore30OrNotNetCoreAndX64FactAttribute.cs @@ -20,4 +20,20 @@ protected override bool IsEnvironmentSupported() return Environment.Is64BitProcess && AppDomain.CurrentDomain.GetData("FX_PRODUCT_VERSION") == null; } } + + /// + /// A fact for tests requiring x64 environment and either .NET Core version lower than 3.0 or framework other than .NET Core. 
+ /// + public sealed class LessThanNetCore30OrNotNetCore : EnvironmentSpecificFactAttribute + { + public LessThanNetCore30OrNotNetCore() : base("Skipping test on .net core version > 3.0 ") + { + } + + /// + protected override bool IsEnvironmentSupported() + { + return AppDomain.CurrentDomain.GetData("FX_PRODUCT_VERSION") == null; + } + } } \ No newline at end of file diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs index bccdc90ee3..e57d923d2c 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs @@ -2,8 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; +using System.IO; using Microsoft.ML.Calibrators; using Microsoft.ML.Data; +using Microsoft.ML.Model; +using Microsoft.ML.TestFramework.Attributes; using Microsoft.ML.Trainers; using Xunit; @@ -52,23 +56,27 @@ public void TestEstimatorPoissonRegression() } [Fact] - public void TestLogisticRegressionNoStats() + public void TestLRNoStats() { (IEstimator pipe, IDataView dataView) = GetBinaryClassificationPipeline(); pipe = pipe.Append(ML.BinaryClassification.Trainers.LogisticRegression(new LogisticRegressionBinaryTrainer.Options { ShowTrainingStatistics = true })); var transformerChain = pipe.Fit(dataView) as TransformerChain>>; - var linearModel = transformerChain.LastTransformer.Model.SubModel as LinearBinaryModelParameters; - var stats = linearModel.Statistics; - LinearModelStatistics.TryGetBiasStatistics(stats, 2, out float stdError, out float zScore, out float pValue); + var stats = linearModel.Statistics as ModelStatisticsBase; + + Assert.NotNull(stats); + + var stats2 = linearModel.Statistics as LinearModelParameterStatistics; - Assert.Equal(0f, stdError); - Assert.Equal(0f, zScore); + Assert.Null(stats2); + + Done(); } + [Fact] - public void 
TestLogisticRegressionWithStats() + public void TestLRWithStats() { (IEstimator pipe, IDataView dataView) = GetBinaryClassificationPipeline(); @@ -82,21 +90,157 @@ public void TestLogisticRegressionWithStats() var transformer = pipe.Fit(dataView) as TransformerChain>>; var linearModel = transformer.LastTransformer.Model.SubModel as LinearBinaryModelParameters; - var stats = linearModel.Statistics; - LinearModelStatistics.TryGetBiasStatistics(stats, 2, out float stdError, out float zScore, out float pValue); - CompareNumbersWithTolerance(stdError, 0.250672936); - CompareNumbersWithTolerance(zScore, 7.97852373); + Action validateStats = (modelParameters) => + { + var stats = linearModel.Statistics as LinearModelParameterStatistics; + var biasStats = stats?.GetBiasStatistics(); + Assert.NotNull(biasStats); + + biasStats = stats.GetBiasStatisticsForValue(2); + + Assert.NotNull(biasStats); + + CompareNumbersWithTolerance(biasStats.StandardError, 0.25, digitsOfPrecision: 2); + CompareNumbersWithTolerance(biasStats.ZScore, 7.97, digitsOfPrecision: 2); + + var scoredData = transformer.Transform(dataView); + + var coefficients = stats.GetWeightsCoefficientStatistics(100); + + Assert.Equal(18, coefficients.Length); + + foreach (var coefficient in coefficients) + Assert.True(coefficient.StandardError < 1.0); + }; + + validateStats(linearModel); + + var modelAndSchemaPath = GetOutputPath("TestLRWithStats.zip"); + + // Save model. 
+ ML.Model.Save(transformer, dataView.Schema, modelAndSchemaPath); + + ITransformer transformerChain; + using (var fs = File.OpenRead(modelAndSchemaPath)) + transformerChain = ML.Model.Load(fs, out var schema); + + var lastTransformer = ((TransformerChain)transformerChain).LastTransformer as BinaryPredictionTransformer>; + var model = lastTransformer.Model as ParameterMixingCalibratedModelParameters, ICalibrator>; + + linearModel = model.SubModel as LinearBinaryModelParameters; + + validateStats(linearModel); + + Done(); + + } + + + [Fact] + public void TestLRWithStatsBackCompatibility() + { + string dropModelPath = GetDataPath("backcompat/LrWithStats.zip"); + string trainData = GetDataPath("adult.tiny.with-schema.txt"); + + using (FileStream fs = File.OpenRead(dropModelPath)) + { + var result = ModelFileUtils.LoadPredictorOrNull(Env, fs) as ParameterMixingCalibratedModelParameters, ICalibrator>; + var subPredictor = result?.SubModel as LinearBinaryModelParameters; + var stats = subPredictor?.Statistics; + + CompareNumbersWithTolerance(stats.Deviance, 458.970917); + CompareNumbersWithTolerance(stats.NullDeviance, 539.276367); + Assert.Equal(7, stats.ParametersCount); + Assert.Equal(500, stats.TrainingExampleCount); + + } + + Done(); + } + + [Fact] + public void TestMLRNoStats() + { + (IEstimator pipe, IDataView dataView) = GetMulticlassPipeline(); + var trainer = ML.MulticlassClassification.Trainers.LogisticRegression(); + var pipeWithTrainer = pipe.Append(trainer); + + TestEstimatorCore(pipeWithTrainer, dataView); - var scoredData = transformer.Transform(dataView); + var transformer = pipeWithTrainer.Fit(dataView); + var model = transformer.LastTransformer.Model as MulticlassLogisticRegressionModelParameters; + var stats = model.Statistics; - var coeffcients = stats.GetCoefficientStatistics(linearModel, scoredData.Schema["Features"], 100); + Assert.Null(stats); - Assert.Equal(19, coeffcients.Length); + Done(); + } + + [LessThanNetCore30OrNotNetCore] + public void 
TestMLRWithStats() + { + (IEstimator pipe, IDataView dataView) = GetMulticlassPipeline(); + + var trainer = ML.MulticlassClassification.Trainers.LogisticRegression(new LogisticRegressionMulticlassClassificationTrainer.Options + { + ShowTrainingStatistics = true + }); + var pipeWithTrainer = pipe.Append(trainer); + + TestEstimatorCore(pipeWithTrainer, dataView); - foreach (var coefficient in coeffcients) - Assert.True(coefficient.StandardError < 1.0); + var transformer = pipeWithTrainer.Fit(dataView); + var model = transformer.LastTransformer.Model as MulticlassLogisticRegressionModelParameters; + Action validateStats = (modelParams) => + { + var stats = modelParams.Statistics; + Assert.NotNull(stats); + + CompareNumbersWithTolerance(stats.Deviance, 45.35, digitsOfPrecision: 2); + CompareNumbersWithTolerance(stats.NullDeviance, 329.58, digitsOfPrecision: 2); + //Assert.Equal(14, stats.ParametersCount); + Assert.Equal(150, stats.TrainingExampleCount); + }; + + validateStats(model); + + var modelAndSchemaPath = GetOutputPath("TestMLRWithStats.zip"); + // Save model. + ML.Model.Save(transformer, dataView.Schema, modelAndSchemaPath); + + // Load model. 
+ ITransformer transformerChain; + using (var fs = File.OpenRead(modelAndSchemaPath)) + transformerChain = ML.Model.Load(fs, out var schema); + + var lastTransformer = ((TransformerChain)transformerChain).LastTransformer as MulticlassPredictionTransformer>>; + model = lastTransformer.Model as MulticlassLogisticRegressionModelParameters; + + validateStats(model); + + Done(); + } + + [Fact] + public void TestMLRWithStatsBackCompatibility() + { + string dropModelPath = GetDataPath("backcompat/MlrWithStats.zip"); + string trainData = GetDataPath("iris.data"); + + using (FileStream fs = File.OpenRead(dropModelPath)) + { + var result = ModelFileUtils.LoadPredictorOrNull(Env, fs) as MulticlassLogisticRegressionModelParameters; + var stats = result?.Statistics; + + Assert.Equal(132.012238f, stats.Deviance); + Assert.Equal(329.583679f, stats.NullDeviance); + Assert.Equal(11, stats.ParametersCount); + Assert.Equal(150, stats.TrainingExampleCount); + } + + Done(); } } } diff --git a/test/data/backcompat/LrWithStats.zip b/test/data/backcompat/LrWithStats.zip new file mode 100644 index 0000000000000000000000000000000000000000..a570b4fdbe88383b65b05ae2d05ed7dfc1d4c1ba GIT binary patch literal 4713 zcmbVQ2{=^i8=o2bK9)oxyCE_5%5IRo5#fq5vZSmN*el@9%rR_x-&GZA?YY1%W{5A&!2cW^|`- zDboX248Y(4h9wr`=IQ3?YUt_WWoPA#J@4k_DdQX9tI)3i(LN~yL2rAylFa^i0C?&Z z6aqm2|7fhUliO)uFRUG9X6ohS>~7(U@jY+%eJXR#Ik3P?jZ>G*=vL>3@oeR+oPMOqZuV7j?csm3C-YZ z-;^y5-F701UsdJacqKAc!=GgG68&i7o*dclgbO>ZVjp{Hm!H~b>AvmMNM&+gTL36o z1gPQvf7ER8@p%a|zvS<;gmH9mf?(hLopt?&-JLXYMfK5u zuPvi#LdilIxv@ewa(QB(@fh(e@`Uh)ggYS#*&T#O1ej=C_)RvHRC+4i&A9LqKOQcg z)|?rWoKZE}6SSxt4XT5*#ojIm=MHG5C8K2;qb2mIJtKKHw7UyhN)P#z9PUm~ZNOEh zG7D)vtHwzaE_G_1VNFqSLekdBuyhz>kT6HMA`O1+e%%Vys}D?|Fgl~tIH(42bQl;S z+qgnu+}#|pn5~GSOt<2Zsl!lO(%is|)M+ARB;6@HKFQcyRHIZ@+-q5N9NDkZZ&r@z z_3kf0_gZ;BB~M~&29$u`YNfGB7-;dmSc_RAuxK-|Q35+H{v&eB&5T(iw3tH!V2NYU zc=pRaldX{bt}v$YFkORKPJ3+9in*AVq+~k5n??6lkr(;0iJVi^ZbXq*=F605nAwXGloIo 
zvC?@KF=)0vnK_}n(szt;c{vQr`U1giM1~K0l>bmreN9r)wAFrux#==5DDH*l32mUI zb5>Kv@0nc_l+1Q>=U!6&I-nW9e*H${grzu9hFr@lLrtRm@w=u@%0^nQ`Fob}ii}A&ncwwBJvEQAv`*Ut*;JkHfE3)(w zj9FWe%(^f1#}Ty*!Oz>%6-z$oenSTZ8zniYM(aMAVV zN6lN;r<|)&Qv)?_jBYpjZtA<-cYv{Q7z84?v(Z1csQ~@hh)Nq?76~U&EpdK=YF9@S zlIeVw$?jII!UC@(BHN^As20f|SJG@}7pF=H*URUM=Q2lcQN;-=5@+-2a3PNEG(3|p zxiv->`S}KVry5UoAHnK1y-zu8R>XM0Ci%vFQ>RG4|iIlZ&;n6K73 z%dc=lOa<0(H_d`@n%JT66-{{f^dGJv?i>LE>GDX$K_h+(mDkcf&#>pROIQrbWl`cN zc@#TJ10@G9luJQQLm!5O-(*4w6Vm9Jm1v@o(Uv)xjyc02(SMeX2I$_4mIoD5^K?Pc}%d0{)`T*jUI$>gWBoh%xBszEBQ!>>0{bzEpc(G+(YP5LJQ9{UQtjs zs)asQyO#^`e&YQ&n{S#dD2Gv3mK=Ek@YW0PC-ZOl`$1U2QB$@SEqJIOyh8o(DpXiQ z*o-yL?7|6lAMer<-4yY1ffc(8w$UYN$+^w!+`+mX z1Yy@1Z~ZLS==vA}LYK+344#~voHZWFUN__41#gc$%dkolD90-?N8_U(MShsh>1Fdc zSmA=K&9T{LBQB9O!dIEKjfV+m;sj73%@ztvuy2nwK!>A}_Jm~>z{pm1m-uI9wp7}d z6v_Upz;UoS#Gi@ZZ{6CJL5hUiC(>U^LQ;#i^W9#ICGPP(SCxHXZTLl^Rj)XHfEwCH zzMRyyH@Dmrri>T=c0K6u?EAxBVPC#HAij0|IR9nhNpn-4k9IVzz=85+mh)3wq0~pn z$Ol|>WO5BT$Z}BFDRe;zGt-+Z+9VB2BaOZ&6 z^DT62X}JTy9PzY#RtqCSknUEvXtB8=$Qd|QAe+(5y*=|zaw-$@%IFc z;)ghuXYRf=^pBrbn6fun!KX;Kr}t`1mM`iy`>*9m#RqR=fyZAh*Lrf;Y@u!6B7pmS0ww4P;3Oe_#nFSGvf7Snb0MqGx9 z0y8u)iV1HI+ubIBTRJE72Ggs~9Fe^Ij%14`r4HLgCL$f5#xvcJuv^d{oVNB0_Sai# z3Qg&LpcmhqQ~%u8M;E0b$Oluc4Hpoi+09xl7oX8xqKM?~!P_4!&Yv)JQ5C|CN9jp# zRQ8QU1Tq%ReYH@ZK(o4haK|5-`^_WxPN|$M1AU@gXx4j&(>m#Ig9p-6gZHcLu3*rG zw>Jf67Yu)SSE{^L#OGvqC(>Heli0Wl3J)V;j(rXQGdN)E1;(~+=a`qf^Hz-jx}Pm} zEH0qw4a*S=okvX!50p$z9=m`>qdN>}7HV=E#`|F-SK34ovWGMuDI+u`2Bbt=yC8<7 zLz;z(Dr}->>AA}%Ct;Ekk+Gk#wFTl=Q@9JTYpGM8u&bit1ykNI!|!zf3gd$8+=&6^ z{q3{O(!cd(XyffD;Es#}p#e-OT^VKe{YZfNGcYQ+nW3Qm2E-Tzi!%Fh(7;e&txti9 z0>lmlfFk!n${!2}*6$QR7}XAdI~oU8#1yQJZCE=3fYl}i;0^V^(4Ju5z&eluca>&G zxE+lH%Owh0Bkj*<|5ZGJeE=&Q%BtaX|83PB9R!O3%2Lh1I}Wh&|D}UFVuQ;m1$!8H zck&O|e-&3SH{fzb0qtS@8FWY6zzLrM)6WD-a(_$zU}M0^k%BVE3_<~=O0+Qz5I_(J P8}NM$#Ca^s*46(23u5D< literal 0 HcmV?d00001 diff --git a/test/data/backcompat/MlrWithStats.zip b/test/data/backcompat/MlrWithStats.zip new file mode 100644 index 
0000000000000000000000000000000000000000..4ddc9e2d08293ae03d31bdb6c495080a91b826c5 GIT binary patch literal 5203 zcmds5c{r49`=0D+j42G+UpXw3Wi*r}`%)QZhD2Gj?=uWpLrJ~UAXK(wOO#0=g|f!T zuDq6{D6&=7h!B2{e#h73dEbuu-ao&Y<38q@Ip#j^>prjhTFx7fVPxWlKp<=oYJfPF z)iG~B8}MyAF!+FBMzVJ&xD#BD5L}4X=FTKPcOpT~Kgb{5jfQC3%0cj3W{1E6gmwV4 zpupI~#FLzz+#UUiBx@sk2WKy=y|=TU<4NGfm75NOeQ^!z`j$lcCm))4FN`-p#7#%+8JuaL*#{KeM*#f^fg(@C>I%ghbIPjNz zH1Vrd)y(>!Rby-1>P8}L>iE0xJKpJqk~{dfGw_W_xhWVzPaIv4D~^NyQI%~$&X96c zoD|Ob(>s=ngqA@`5Em5Q7vcXHy-5k%sTO6UaNv5PhT8tNKpLr$Qo|Z_y;Sn@#ciiL zHWczIY6MVZ{U5gnC{qlP+W0O{uhv^`p%gYL~Ezuk*3Up9o-2z%h$%T$86e{75pq|lLz@nN8MAH+MmZdx>)gs z;abDhEgt2Ict#YC<=VsVAhxqhYV(&e=Uccz`dg0cT5h*uC1sBs00@;DE9o=IN^E z=zjhN_I3kDj(oeR*W63~)-x-YOcLL+Aezz~*fY;gRUe{eA}1Ev?_cukmHnJe2x{0> z`do?Nsn?6M?x?tLx0No&&}tT!WLsp-i;khjM)$p4kZFW>kb#GgRc}UuF6D&ZF(j->cA$S(`Vjg(siOn18(Cv;Iuec*E}T zsS1JGX+C9*p2cUj()qp77u<5qIW!1PZ<-`-So(S3-1@1#G|VLRBbhxNCxpvkjyW{K z?q%>$pYvg1g_39eh@RtzoQIZ;xwbER9k{Bv9wUbl+|5>v(9<}t)qDBXccP{en)t>m z{XRIR`!p~tWI1*Z?tUruEblc>HHg7x{O8jh0E5qf^dtAXuKiBiU0emM4h5~Oi;P#@ zSI;U+(|_t&iK4P_GMsQ& zN7n{c4h6N&<-L=Ld2lglUX_#G$c6Li)5d+8FVA(IE=fxZ3@n}+{`_aj25RO?rA97% zkM;$1MfRIT%Gu&1vx(xoa@KA$JzgQN+YimSy9E^*&aABq+(;Q}bLR=fOkJhM1nVP8 zA@XIu-+q)PjZu?czKc;})~mqOK*iEedWE$96FR#U%bzpm>2?&F}A7lpe6s z|02%(xdf$$X&Y@MHi0IXeb<;_KSg3aBSdxTgDNq6X{o+?+;wJl>0L){%?)4e3oOF2 zcWXKP#<;_n4vnDdxmia>9*l&T^@%x!%_-SWKF=1f*&{ERb7O`qfR%a!#hiAKSzRp9 zh$r_!jR1h^VUc&IUMI{0BqI_7>9qcX@&npIu?k0 zdoTA8XVSO$*LHDnCpgo8+zkBrz3rb{ zem}{%SYz0elkl+cOK%~Fq`$sgJnk-LDX?ayH<^zx@ix8Yv)j|vhw;eN^;-QP-IteB zmIvDopo!sEnVkjl67^Xag2i%^SCLZ$-muWyBrom<>crjUYbeV!wHM2SnKxJO%~a)U zu9t7vHQy>fO3UQjZKEc;@TxM)?y?^5rMg`(v8X<{aOv}u5;qPbK4ZMWPRH^QuO^P2 zS?7@-6FH+gwV|0`Qw}D_oZ0Up=o`7=1m(jmCx@wIQ#4a2@M#4WCgVK9IlDcg z+|He7RrDIn72fJvC9V_$F8_uazk1Fewj5#Y{kA`8X=W8rd^hP_bkleI?Vjpd_0Qq8WS#AwwbPn!K#dR6(K&8+F2C2MG3Vi!2=cy^a3^yH9&>#10gV+}sX zi|tH`CE1xee9ab5-iGP%dK<-+EOsaLkxbjd{m)7b`o4W#8=*gShI&0%biN>3Y)>XS zIc!frO;G(Vkv6B{$n~e;F&adZc-c7}a?O1Yt}CBxr$x2~4TU>z`#M04ClEA$(@45R 
zZ*O~o(`J4@gjra$QvA5SRZ`5R|=9*p*Q?yQ@4l@ly6UPU}+KD5wD-I!#tG?Eg+cNAB0K| z4A*^bP+uCn&9AEHylOj5Lrm`s!Wxd+Ta3M|A}i;=ItH!HqMBNSYk9}&_99leR212s zBnXt~pbdJhZslp25He=b8)`2Sx;UrrgoG!?^|j{^ZgXm^>}kzQ@-`h{E&B4l{!c7> z6(?JlzvPapo(z`^7py=5^MSAZJE%{i4h(l;7~tqMEb3EQ-U1hl3U5^?W?PJ_3kd_O z8K3>wgC_yb$iNT;#ui;^O7wExY!CsR`K=)gz_T5>D>c`QGddzm9UCu zBA2A|-oaKblXkOP5Q}PVWd@>r8w2+@efrOu7clRK*A`2EP{3dmu=b@xXaJA@q=dl; zU;#`=2;GA4Ed)Vw7>os0k8~^waBAt+d=FHl8z>8;yO`#JEc;9iWrYVE&V^{1?X`xyFCGr&#OZ*B0W*gy9O6u9g91@_N80xq+3 zgSP>fXJF8u+6FE~bQp0CP?G(jFo7KfPR?|c<1i2kC{^Pznn+a literal 0 HcmV?d00001 From de5d48a959a789553c337a7d096e58ef0a2410bb Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Tue, 19 Mar 2019 14:54:08 -0700 Subject: [PATCH 08/18] Refactor cancellation mechanism and make it internal, accessible via experimental nuget. (#2797) * Clean up of cancellation mechanism. * changes. * Make cancellation mechanism internal and accessible via experiemtnal nuget. * merge conflicts. * PR feedback and fix analyzer test. * Add test for public API in experimental nuget. * PR feedback. * fix tests. * PR feedback. * PR feedback. * PR feedback. * PR feedback. * PR feedback. * PR feedback. * PR feedback. * PR feedback. 
--- Microsoft.ML.sln | 15 +++++ .../Microsoft.ML.Experimental.nupkgproj | 12 ++++ ...icrosoft.ML.Experimental.symbols.nupkgproj | 5 ++ .../Data/IHostEnvironment.cs | 22 ++++--- .../Environment/ConsoleEnvironment.cs | 2 +- .../Environment/HostEnvironmentBase.cs | 57 +++++++++++-------- src/Microsoft.ML.Core/Utilities/Contracts.cs | 4 +- src/Microsoft.ML.Data/MLContext.cs | 5 +- .../Properties/AssemblyInfo.cs | 1 + .../Utilities/LocalEnvironment.cs | 2 +- .../MLContextExtensions.cs | 15 +++++ .../Microsoft.ML.Experimental.csproj | 12 ++++ .../Code/ContractsCheckTest.cs | 10 +++- .../Resources/ContractsCheckResource.cs | 1 - .../UnitTests/TestHosts.cs | 36 +++++++++++- 15 files changed, 154 insertions(+), 45 deletions(-) create mode 100644 pkg/Microsoft.ML.Experimental/Microsoft.ML.Experimental.nupkgproj create mode 100644 pkg/Microsoft.ML.Experimental/Microsoft.ML.Experimental.symbols.nupkgproj create mode 100644 src/Microsoft.ML.Experimental/MLContextExtensions.cs create mode 100644 src/Microsoft.ML.Experimental/Microsoft.ML.Experimental.csproj diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index 09c41cd191..242482a6bf 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -266,6 +266,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.Ensemble", "Mi pkg\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.symbols.nupkgproj = pkg\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.symbols.nupkgproj EndProjectSection EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Experimental", "src\Microsoft.ML.Experimental\Microsoft.ML.Experimental.csproj", "{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -948,6 +950,18 @@ Global {5E920CAC-5A28-42FB-936E-49C472130953}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU {5E920CAC-5A28-42FB-936E-49C472130953}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU 
{5E920CAC-5A28-42FB-936E-49C472130953}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release|Any CPU.Build.0 = Release|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -1033,6 +1047,7 @@ Global {31D38B21-102B-41C0-9E0A-2FE0BF68D123} = {D3D38B03-B557-484D-8348-8BADEE4DF592} {5E920CAC-5A28-42FB-936E-49C472130953} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} {AD7058C9-5608-49A8-BE23-58C33A74EE91} = {D3D38B03-B557-484D-8348-8BADEE4DF592} + {E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC} = {09EADF06-BE25-4228-AB53-95AE3E15B530} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D} diff --git a/pkg/Microsoft.ML.Experimental/Microsoft.ML.Experimental.nupkgproj b/pkg/Microsoft.ML.Experimental/Microsoft.ML.Experimental.nupkgproj new file mode 100644 index 0000000000..edf80ad475 --- /dev/null +++ 
b/pkg/Microsoft.ML.Experimental/Microsoft.ML.Experimental.nupkgproj @@ -0,0 +1,12 @@ + + + + netstandard2.0 + Microsoft.ML.Experimental contains experimental work such extension methods to access internal methods. + + + + + + + diff --git a/pkg/Microsoft.ML.Experimental/Microsoft.ML.Experimental.symbols.nupkgproj b/pkg/Microsoft.ML.Experimental/Microsoft.ML.Experimental.symbols.nupkgproj new file mode 100644 index 0000000000..c869da5d2b --- /dev/null +++ b/pkg/Microsoft.ML.Experimental/Microsoft.ML.Experimental.symbols.nupkgproj @@ -0,0 +1,5 @@ + + + + + diff --git a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs index f81a1aa293..959b1e940b 100644 --- a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs +++ b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs @@ -63,14 +63,23 @@ public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider IHost Register(string name, int? seed = null, bool? verbose = null); /// - /// Flag which indicate should we stop any code execution in this host. + /// The catalog of loadable components () that are available in this host. /// - bool IsCancelled { get; } + ComponentCatalog ComponentCatalog { get; } + } + [BestFriend] + internal interface ICancelable + { /// - /// The catalog of loadable components () that are available in this host. + /// Signal to stop exection in all the hosts. /// - ComponentCatalog ComponentCatalog { get; } + void CancelExecution(); + + /// + /// Flag which indicates host execution has been stopped. + /// + bool IsCanceled { get; } } /// @@ -85,11 +94,6 @@ public interface IHost : IHostEnvironment /// generators are NOT thread safe. /// Random Rand { get; } - - /// - /// Signal to stop exection in this host and all its children. 
- /// - void StopExecution(); } /// diff --git a/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs b/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs index 011efea28f..f4fa53d6c6 100644 --- a/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs +++ b/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs @@ -465,7 +465,7 @@ private sealed class Host : HostBase public Host(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) : base(source, shortName, parentFullName, rand, verbose) { - IsCancelled = source.IsCancelled; + IsCanceled = source.IsCanceled; } protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name) diff --git a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs index 3b46334a9e..da0e2e711c 100644 --- a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs +++ b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs @@ -93,9 +93,23 @@ internal interface IMessageSource /// query progress. /// [BestFriend] - internal abstract class HostEnvironmentBase : ChannelProviderBase, IHostEnvironment, IChannelProvider + internal abstract class HostEnvironmentBase : ChannelProviderBase, IHostEnvironment, IChannelProvider, ICancelable where TEnv : HostEnvironmentBase { + void ICancelable.CancelExecution() + { + lock (_cancelLock) + { + foreach (var child in _children) + if (child.TryGetTarget(out IHost host)) + if (host is ICancelable cancelableHost) + cancelableHost.CancelExecution(); + + _children.Clear(); + IsCanceled = true; + } + } + /// /// Base class for hosts. Classes derived from may choose /// to provide their own host class that derives from this class. @@ -107,28 +121,10 @@ public abstract class HostBase : HostEnvironmentBase, IHost public Random Rand => _rand; - // We don't have dispose mechanism for hosts, so to let GC collect children hosts we make them WeakReference. 
- private readonly List> _children; - public HostBase(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) : base(source, rand, verbose, shortName, parentFullName) { Depth = source.Depth + 1; - _children = new List>(); - } - - public void StopExecution() - { - lock (_cancelLock) - { - IsCancelled = true; - foreach (var child in _children) - { - if (child.TryGetTarget(out IHost host)) - host.StopExecution(); - } - _children.Clear(); - } } public new IHost Register(string name, int? seed = null, bool? verbose = null) @@ -139,7 +135,7 @@ public void StopExecution() { Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand); host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose); - if (!IsCancelled) + if (!IsCanceled) _children.Add(new WeakReference(host)); } return host; @@ -175,7 +171,7 @@ protected PipeBase(ChannelProviderBase parent, string shortName, public void Dispose() { - if(!_disposed) + if (!_disposed) { Dispose(true); _disposed = true; @@ -339,12 +335,15 @@ public void RemoveListener(Action listenerFunc) protected readonly ProgressReporting.ProgressTracker ProgressTracker; - public bool IsCancelled { get; protected set; } - public ComponentCatalog ComponentCatalog { get; } public override int Depth => 0; + public bool IsCanceled { get; protected set; } + + // We don't have dispose mechanism for hosts, so to let GC collect children hosts we make them WeakReference. + private readonly List> _children; + /// /// The main constructor. 
/// @@ -359,6 +358,7 @@ protected HostEnvironmentBase(Random rand, bool verbose, _cancelLock = new object(); Root = this as TEnv; ComponentCatalog = new ComponentCatalog(); + _children = new List>(); } /// @@ -379,13 +379,20 @@ protected HostEnvironmentBase(HostEnvironmentBase source, Random rand, boo ListenerDict = source.ListenerDict; ProgressTracker = source.ProgressTracker; ComponentCatalog = source.ComponentCatalog; + _children = new List>(); } public IHost Register(string name, int? seed = null, bool? verbose = null) { Contracts.CheckNonEmpty(name, nameof(name)); - Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand); - return RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose); + IHost host; + lock (_cancelLock) + { + Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand); + host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose); + _children.Add(new WeakReference(host)); + } + return host; } protected abstract IHost RegisterCore(HostEnvironmentBase source, string shortName, diff --git a/src/Microsoft.ML.Core/Utilities/Contracts.cs b/src/Microsoft.ML.Core/Utilities/Contracts.cs index 9770197f58..95ef3d0d00 100644 --- a/src/Microsoft.ML.Core/Utilities/Contracts.cs +++ b/src/Microsoft.ML.Core/Utilities/Contracts.cs @@ -737,6 +737,7 @@ public static void CheckIO(this IExceptionContext ctx, bool f, string msg) if (!f) throw ExceptIO(ctx, msg); } + public static void CheckIO(this IExceptionContext ctx, bool f, string msg, params object[] args) { if (!f) @@ -748,11 +749,10 @@ public static void CheckIO(this IExceptionContext ctx, bool f, string msg, param /// public static void CheckAlive(this IHostEnvironment env) { - if (env.IsCancelled) + if (env is ICancelable cancelableEnv && cancelableEnv.IsCanceled) throw Process(new OperationCanceledException("Operation was cancelled."), env); } #endif - /// /// This documents that the parameter can legally be 
null. /// diff --git a/src/Microsoft.ML.Data/MLContext.cs b/src/Microsoft.ML.Data/MLContext.cs index 95f494c16e..7e8bc535fe 100644 --- a/src/Microsoft.ML.Data/MLContext.cs +++ b/src/Microsoft.ML.Data/MLContext.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; using Microsoft.ML.Data; using Microsoft.ML.Runtime; @@ -104,12 +105,14 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message) log(this, new LoggingEventArgs(msg)); } - bool IHostEnvironment.IsCancelled => _env.IsCancelled; string IExceptionContext.ContextDescription => _env.ContextDescription; TException IExceptionContext.Process(TException ex) => _env.Process(ex); IHost IHostEnvironment.Register(string name, int? seed, bool? verbose) => _env.Register(name, seed, verbose); IChannel IChannelProvider.Start(string name) => _env.Start(name); IPipe IChannelProvider.StartPipe(string name) => _env.StartPipe(name); IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name); + + [BestFriend] + internal void CancelExecution() => ((ICancelable)_env).CancelExecution(); } } diff --git a/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs index a37c10210d..2375a5a05a 100644 --- a/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs +++ b/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs @@ -45,6 +45,7 @@ [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.DnnImageFeaturizer.ResNet50" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StaticPipe" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Experimental" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Internal.MetaLinearLearner" + InternalPublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "TMSNlearnPrediction" + InternalPublicKey.Value)] diff --git 
a/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs b/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs index 2423b43a42..f2ca816e70 100644 --- a/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs +++ b/src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs @@ -93,7 +93,7 @@ private sealed class Host : HostBase public Host(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) : base(source, shortName, parentFullName, rand, verbose) { - IsCancelled = source.IsCancelled; + IsCanceled = source.IsCanceled; } protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name) diff --git a/src/Microsoft.ML.Experimental/MLContextExtensions.cs b/src/Microsoft.ML.Experimental/MLContextExtensions.cs new file mode 100644 index 0000000000..cc5255fbd9 --- /dev/null +++ b/src/Microsoft.ML.Experimental/MLContextExtensions.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.ML.Experimental +{ + public static class MLContextExtensions + { + /// + /// Stop the execution of pipeline in + /// + /// reference. 
+ public static void CancelExecution(this MLContext ctx) => ctx.CancelExecution(); + } +} diff --git a/src/Microsoft.ML.Experimental/Microsoft.ML.Experimental.csproj b/src/Microsoft.ML.Experimental/Microsoft.ML.Experimental.csproj new file mode 100644 index 0000000000..4c1b189a5b --- /dev/null +++ b/src/Microsoft.ML.Experimental/Microsoft.ML.Experimental.csproj @@ -0,0 +1,12 @@ + + + + netstandard2.0 + Microsoft.ML.Experimental + + + + + + + diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs index 8ef28ab554..4b0e1921de 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs @@ -39,6 +39,8 @@ public async Task ContractsCheck() VerifyCS.Diagnostic(ContractsCheckAnalyzer.SimpleMessageDiagnostic.Rule).WithLocation(basis + 32, 35).WithArguments("Check", "\"Less fine: \" + env.GetType().Name"), VerifyCS.Diagnostic(ContractsCheckAnalyzer.NameofDiagnostic.Rule).WithLocation(basis + 34, 17).WithArguments("CheckUserArg", "name", "\"p\""), VerifyCS.Diagnostic(ContractsCheckAnalyzer.DecodeMessageWithLoadContextDiagnostic.Rule).WithLocation(basis + 39, 41).WithArguments("CheckDecode", "\"This message is suspicious\""), + new DiagnosticResult("CS0122", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 24).WithMessage("'ICancelable' is inaccessible due to its protection level"), + new DiagnosticResult("CS0122", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 67).WithMessage("'ICancelable.IsCanceled' is inaccessible due to its protection level"), }; var test = new VerifyCS.Test @@ -125,7 +127,9 @@ public async Task ContractsCheckFix() VerifyCS.Diagnostic(ContractsCheckAnalyzer.NameofDiagnostic.Rule).WithLocation(23, 39).WithArguments("CheckValue", "paramName", "\"noMatch\""), VerifyCS.Diagnostic(ContractsCheckAnalyzer.NameofDiagnostic.Rule).WithLocation(24, 
53).WithArguments("CheckUserArg", "name", "\"chumble\""), VerifyCS.Diagnostic(ContractsCheckAnalyzer.NameofDiagnostic.Rule).WithLocation(25, 53).WithArguments("CheckUserArg", "name", "\"sp\""), - new DiagnosticResult("CS1503", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 91).WithMessage("Argument 2: cannot convert from 'Microsoft.ML.Runtime.IHostEnvironment' to 'Microsoft.ML.Runtime.IExceptionContext'"), + new DiagnosticResult("CS0122", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 24).WithMessage("'ICancelable' is inaccessible due to its protection level"), + new DiagnosticResult("CS0122", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 67).WithMessage("'ICancelable.IsCanceled' is inaccessible due to its protection level"), + new DiagnosticResult("CS1503", DiagnosticSeverity.Error).WithLocation("Test1.cs", 753, 91).WithMessage("Argument 2: cannot convert from 'Microsoft.ML.Runtime.IHostEnvironment' to 'Microsoft.ML.Runtime.IExceptionContext'"), }, AdditionalReferences = { AdditionalMetadataReferences.RefFromType>() }, }, @@ -144,7 +148,9 @@ public async Task ContractsCheckFix() { VerifyCS.Diagnostic(ContractsCheckAnalyzer.ExceptionDiagnostic.Rule).WithLocation(9, 43).WithArguments("ExceptParam"), VerifyCS.Diagnostic(ContractsCheckAnalyzer.NameofDiagnostic.Rule).WithLocation(23, 39).WithArguments("CheckValue", "paramName", "\"noMatch\""), - new DiagnosticResult("CS1503", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 91).WithMessage("Argument 2: cannot convert from 'Microsoft.ML.Runtime.IHostEnvironment' to 'Microsoft.ML.Runtime.IExceptionContext'"), + new DiagnosticResult("CS0122", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 24).WithMessage("'ICancelable' is inaccessible due to its protection level"), + new DiagnosticResult("CS0122", DiagnosticSeverity.Error).WithLocation("Test1.cs", 752, 67).WithMessage("'ICancelable.IsCanceled' is inaccessible due to its protection level"), + new DiagnosticResult("CS1503", 
DiagnosticSeverity.Error).WithLocation("Test1.cs", 753, 91).WithMessage("Argument 2: cannot convert from 'Microsoft.ML.Runtime.IHostEnvironment' to 'Microsoft.ML.Runtime.IExceptionContext'"), }, }, }; diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/ContractsCheckResource.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/ContractsCheckResource.cs index f4ab3d02d2..d8f5edb205 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/ContractsCheckResource.cs +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/ContractsCheckResource.cs @@ -72,7 +72,6 @@ internal enum MessageSensitivity } internal interface IHostEnvironment : IExceptionContext { - bool IsCancelled { get; } } } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs index 186c0b7621..1c283ff4e9 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs @@ -55,8 +55,8 @@ public void TestCancellation() do { index = rand.Next(hosts.Count); - } while (hosts.ElementAt(index).Item1.IsCancelled || hosts.ElementAt(index).Item2 < 3); - hosts.ElementAt(index).Item1.StopExecution(); + } while ((hosts.ElementAt(index).Item1 as ICancelable).IsCanceled || hosts.ElementAt(index).Item2 < 3); + (hosts.ElementAt(index).Item1 as ICancelable).CancelExecution(); rootHost = hosts.ElementAt(index).Item1; queue.Enqueue(rootHost); } @@ -64,7 +64,7 @@ public void TestCancellation() while (queue.Count > 0) { var currentHost = queue.Dequeue(); - Assert.True(currentHost.IsCancelled); + Assert.True((currentHost as ICancelable).IsCanceled); if (children.ContainsKey(currentHost)) children[currentHost].ForEach(x => queue.Enqueue(x)); @@ -72,6 +72,36 @@ public void TestCancellation() } } + [Fact] + public void TestCancellationApi() + { + IHostEnvironment env = new MLContext(seed: 42); + var mainHost = env.Register("Main"); + var children = new ConcurrentDictionary>(); + var hosts = new 
BlockingCollection>(); + hosts.Add(new Tuple(mainHost.Register("1"), 1)); + hosts.Add(new Tuple(mainHost.Register("2"), 1)); + hosts.Add(new Tuple(mainHost.Register("3"), 1)); + hosts.Add(new Tuple(mainHost.Register("4"), 1)); + hosts.Add(new Tuple(mainHost.Register("5"), 1)); + + for (int i = 0; i < 5; i++) + { + var tupple = hosts.ElementAt(i); + var newHost = tupple.Item1.Register((tupple.Item2 + 1).ToString()); + hosts.Add(new Tuple(newHost, tupple.Item2 + 1)); + } + + ((MLContext)env).CancelExecution(); + + //Ensure all created hosts are cancelled. + //5 parent and one child for each. + Assert.Equal(10, hosts.Count); + + foreach (var host in hosts) + Assert.True((host.Item1 as ICancelable).IsCanceled); + } + /// /// Tests that MLContext's Log event intercepts messages properly. /// From c38f81b3957fed6aa88ea0e6b295522d5bf3f9ec Mon Sep 17 00:00:00 2001 From: Rogan Carr Date: Tue, 19 Mar 2019 15:44:32 -0700 Subject: [PATCH 09/18] Add functional tests for ONNX scenarios (#2984) * Adding functional tests for ONNX scenarios --- test/Microsoft.ML.Functional.Tests/Common.cs | 4 +- .../Datasets/CommonColumns.cs | 38 ++++ .../Datasets/FeatureColumn.cs | 14 -- .../Datasets/FeatureContributionOutput.cs | 15 -- test/Microsoft.ML.Functional.Tests/ONNX.cs | 179 ++++++++++++++++++ 5 files changed, 219 insertions(+), 31 deletions(-) create mode 100644 test/Microsoft.ML.Functional.Tests/Datasets/CommonColumns.cs delete mode 100644 test/Microsoft.ML.Functional.Tests/Datasets/FeatureColumn.cs delete mode 100644 test/Microsoft.ML.Functional.Tests/Datasets/FeatureContributionOutput.cs create mode 100644 test/Microsoft.ML.Functional.Tests/ONNX.cs diff --git a/test/Microsoft.ML.Functional.Tests/Common.cs b/test/Microsoft.ML.Functional.Tests/Common.cs index 58bbb730b2..6b6722cf5d 100644 --- a/test/Microsoft.ML.Functional.Tests/Common.cs +++ b/test/Microsoft.ML.Functional.Tests/Common.cs @@ -83,14 +83,14 @@ public static void AssertTestTypeDatasetsAreEqual(MLContext mlContext, 
IDataView /// /// An array of floats. /// An array of floats. - public static void AssertEqual(float[] array1, float[] array2) + public static void AssertEqual(float[] array1, float[] array2, int precision = 6) { Assert.NotNull(array1); Assert.NotNull(array2); Assert.Equal(array1.Length, array2.Length); for (int i = 0; i < array1.Length; i++) - Assert.Equal(array1[i], array2[i]); + Assert.Equal(array1[i], array2[i], precision: precision); } /// diff --git a/test/Microsoft.ML.Functional.Tests/Datasets/CommonColumns.cs b/test/Microsoft.ML.Functional.Tests/Datasets/CommonColumns.cs new file mode 100644 index 0000000000..348d2563f9 --- /dev/null +++ b/test/Microsoft.ML.Functional.Tests/Datasets/CommonColumns.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.ML.Functional.Tests.Datasets +{ + /// + /// A class to hold a feature column. + /// + internal sealed class FeatureColumn + { + public float[] Features { get; set; } + } + + /// + /// A class to hold the output of FeatureContributionCalculator. + /// + internal sealed class FeatureContributionOutput + { + public float[] FeatureContributions { get; set; } + } + + /// + /// A class to hold the Score column. + /// + internal sealed class ScoreColumn + { + public float Score { get; set; } + } + + /// + /// A class to hold a vector Score column. + /// + internal sealed class VectorScoreColumn + { + public float[] Score { get; set; } + } +} diff --git a/test/Microsoft.ML.Functional.Tests/Datasets/FeatureColumn.cs b/test/Microsoft.ML.Functional.Tests/Datasets/FeatureColumn.cs deleted file mode 100644 index 090ad23646..0000000000 --- a/test/Microsoft.ML.Functional.Tests/Datasets/FeatureColumn.cs +++ /dev/null @@ -1,14 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. 
-// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -namespace Microsoft.ML.Functional.Tests.Datasets -{ - /// - /// A class to hold a feature column. - /// - internal sealed class FeatureColumn - { - public float[] Features { get; set; } - } -} diff --git a/test/Microsoft.ML.Functional.Tests/Datasets/FeatureContributionOutput.cs b/test/Microsoft.ML.Functional.Tests/Datasets/FeatureContributionOutput.cs deleted file mode 100644 index 6aa8dcbb11..0000000000 --- a/test/Microsoft.ML.Functional.Tests/Datasets/FeatureContributionOutput.cs +++ /dev/null @@ -1,15 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - - -namespace Microsoft.ML.Functional.Tests.Datasets -{ - /// - /// A class to hold the output of FeatureContributionCalculator - /// - internal sealed class FeatureContributionOutput - { - public float[] FeatureContributions { get; set; } - } -} diff --git a/test/Microsoft.ML.Functional.Tests/ONNX.cs b/test/Microsoft.ML.Functional.Tests/ONNX.cs new file mode 100644 index 0000000000..3ece5658b8 --- /dev/null +++ b/test/Microsoft.ML.Functional.Tests/ONNX.cs @@ -0,0 +1,179 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System.IO; +using Microsoft.ML.Functional.Tests.Datasets; +using Microsoft.ML.RunTests; +using Microsoft.ML.TestFramework; +using Microsoft.ML.TestFramework.Attributes; +using Microsoft.ML.Trainers; +using Microsoft.ML.Trainers.FastTree; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Functional.Tests +{ + public class ONNX : BaseTestClass + { + public ONNX(ITestOutputHelper output) : base(output) + { + } + + /// + /// ONNX: Models can be serialized to ONNX, deserialized back to ML.NET, and used a pipeline. + /// + [OnnxFactAttribute] + public void SaveOnnxModelLoadAndScoreFastTree() + { + var mlContext = new MLContext(seed: 1); + + // Get the dataset. + var data = mlContext.Data.LoadFromTextFile(GetDataPath(TestDatasets.housing.trainFilename), hasHeader: true); + + // Create a pipeline to train on the housing data. + var pipeline = mlContext.Transforms.Concatenate("Features", HousingRegression.Features) + .Append(mlContext.Transforms.Normalize("Features")) + .AppendCacheCheckpoint(mlContext) + .Append(mlContext.Regression.Trainers.FastTree( + new FastTreeRegressionTrainer.Options { NumberOfThreads = 1, NumberOfTrees = 10 })); + + // Fit the pipeline. + var model = pipeline.Fit(data); + + // Serialize the pipeline to a file. + var modelFileName = "SaveOnnxLoadAndScoreFastTreeModel.onnx"; + var modelPath = DeleteOutputPath(modelFileName); + using (var file = File.Create(modelPath)) + mlContext.Model.ConvertToOnnx(model, data, file); + + // Load the model as a transform. + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(modelPath); + var onnxModel = onnxEstimator.Fit(data); + + // TODO #2980: ONNX outputs don't match the outputs of the model, so we must hand-correct this for now. + // TODO #2981: ONNX models cannot be fit as part of a pipeline, so we must use a workaround like this. 
+ var onnxWorkaroundPipeline = onnxModel.Append( + mlContext.Transforms.CopyColumns("Score", "Score0").Fit(onnxModel.Transform(data))); + + // Create prediction engine and test predictions. + var originalPredictionEngine = mlContext.Model.CreatePredictionEngine(model); + // TODO #2982: ONNX produces vector types and not the original output type. + var onnxPredictionEngine = mlContext.Model.CreatePredictionEngine(onnxWorkaroundPipeline); + + // Take a handful of examples out of the dataset and compute predictions. + var dataEnumerator = mlContext.Data.CreateEnumerable(mlContext.Data.TakeRows(data, 5), false); + foreach (var row in dataEnumerator) + { + var originalPrediction = originalPredictionEngine.Predict(row); + var onnxPrediction = onnxPredictionEngine.Predict(row); + // Check that the predictions are identical. + Assert.Equal(originalPrediction.Score, onnxPrediction.Score[0], precision: 4); // Note the low-precision equality! + } + } + + /// + /// ONNX: Models can be serialized to ONNX, deserialized back to ML.NET, and used a pipeline. + /// + [OnnxFactAttribute] + public void SaveOnnxModelLoadAndScoreKMeans() + { + var mlContext = new MLContext(seed: 1); + + // Get the dataset. + var data = mlContext.Data.LoadFromTextFile(GetDataPath(TestDatasets.housing.trainFilename), hasHeader: true); + + // Create a pipeline to train on the housing data. + var pipeline = mlContext.Transforms.Concatenate("Features", HousingRegression.Features) + .Append(mlContext.Transforms.Normalize("Features")) + .AppendCacheCheckpoint(mlContext) + .Append(mlContext.Clustering.Trainers.KMeans( + new KMeansTrainer.Options { NumberOfThreads = 1, MaximumNumberOfIterations = 10 })); + + // Fit the pipeline. + var model = pipeline.Fit(data); + + // Serialize the pipeline to a file. 
+ var modelFileName = "SaveOnnxLoadAndScoreKMeansModel.onnx"; + var modelPath = DeleteOutputPath(modelFileName); + using (var file = File.Create(modelPath)) + mlContext.Model.ConvertToOnnx(model, data, file); + + // Load the model as a transform. + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(modelPath); + var onnxModel = onnxEstimator.Fit(data); + + // TODO #2980: ONNX outputs don't match the outputs of the model, so we must hand-correct this for now. + // TODO #2981: ONNX models cannot be fit as part of a pipeline, so we must use a workaround like this. + var onnxWorkaroundPipeline = onnxModel.Append( + mlContext.Transforms.CopyColumns("Score", "Score0").Fit(onnxModel.Transform(data))); + + // Create prediction engine and test predictions. + var originalPredictionEngine = mlContext.Model.CreatePredictionEngine(model); + // TODO #2982: ONNX produces vector types and not the original output type. + var onnxPredictionEngine = mlContext.Model.CreatePredictionEngine(onnxWorkaroundPipeline); + + // Take a handful of examples out of the dataset and compute predictions. + var dataEnumerator = mlContext.Data.CreateEnumerable(mlContext.Data.TakeRows(data, 5), false); + foreach (var row in dataEnumerator) + { + var originalPrediction = originalPredictionEngine.Predict(row); + var onnxPrediction = onnxPredictionEngine.Predict(row); + // Check that the predictions are identical. + Common.AssertEqual(originalPrediction.Score, onnxPrediction.Score, precision: 4); // Note the low precision! + } + } + + /// + /// ONNX: Models can be serialized to ONNX, deserialized back to ML.NET, and used a pipeline. + /// + [OnnxFactAttribute] + public void SaveOnnxModelLoadAndScoreSDCA() + { + var mlContext = new MLContext(seed: 1); + + // Get the dataset. + var data = mlContext.Data.LoadFromTextFile(GetDataPath(TestDatasets.housing.trainFilename), hasHeader: true); + + // Create a pipeline to train on the housing data. 
+ var pipeline = mlContext.Transforms.Concatenate("Features", HousingRegression.Features) + .Append(mlContext.Transforms.Normalize("Features")) + .AppendCacheCheckpoint(mlContext) + .Append(mlContext.Regression.Trainers.Sdca( + new SdcaRegressionTrainer.Options { NumberOfThreads = 1, MaximumNumberOfIterations = 10 })); + + // Fit the pipeline. + var model = pipeline.Fit(data); + + // Serialize the pipeline to a file. + var modelFileName = "SaveOnnxLoadAndScoreSdcaModel.onnx"; + var modelPath = DeleteOutputPath(modelFileName); + using (var file = File.Create(modelPath)) + mlContext.Model.ConvertToOnnx(model, data, file); + + // Load the model as a transform. + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(modelPath); + var onnxModel = onnxEstimator.Fit(data); + + // TODO #2980: ONNX outputs don't match the outputs of the model, so we must hand-correct this for now. + // TODO #2981: ONNX models cannot be fit as part of a pipeline, so we must use a workaround like this. + var onnxWorkaroundPipeline = onnxModel.Append( + mlContext.Transforms.CopyColumns("Score", "Score0").Fit(onnxModel.Transform(data))); + + // Create prediction engine and test predictions. + var originalPredictionEngine = mlContext.Model.CreatePredictionEngine(model); + // TODO #2982: ONNX produces vector types and not the original output type. + var onnxPredictionEngine = mlContext.Model.CreatePredictionEngine(onnxWorkaroundPipeline); + + // Take a handful of examples out of the dataset and compute predictions. + var dataEnumerator = mlContext.Data.CreateEnumerable(mlContext.Data.TakeRows(data, 5), false); + foreach (var row in dataEnumerator) + { + var originalPrediction = originalPredictionEngine.Predict(row); + var onnxPrediction = onnxPredictionEngine.Predict(row); + // Check that the predictions are identical. + Assert.Equal(originalPrediction.Score, onnxPrediction.Score[0], precision: 4); // Note the low-precision equality! 
+ } + } + } +} From 3af9a5d96ade88e888894af23baef8fe4598f826 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 19 Mar 2019 17:19:49 -0700 Subject: [PATCH 10/18] Make Multiclass Linear Trainers Typed Based on Output Model Types. (#2976) * Step 1: create two multi-class linear models Step 2: Make SDCA trainers typed Finish version 0.1 Delete commented lines * Add some doc strings More document * Handle static extensions * Rename several maximum entropy models and trainers * Fix EP test Fix two tests and address a comment Add missing piece * Address comments * Improve option of MCSDCA * Address comments * Update code sample * Refactorize saving family * Rename a class following binary SDCA trainer --- docs/code/MlNetCookBook.md | 10 +- .../StochasticDualCoordinateAscent.cs | 2 +- ...ochasticDualCoordinateAscentWithOptions.cs | 9 +- .../MulticlassDataPartitionEnsembleTrainer.cs | 6 +- .../MulticlassLogisticRegression.cs | 615 +++++++++++------- .../Standard/SdcaMulticlass.cs | 171 ++++- .../StandardTrainersCatalog.cs | 82 ++- src/Microsoft.ML.StaticPipe/LbfgsStatic.cs | 18 +- .../SdcaStaticExtensions.cs | 107 ++- .../CommandTrainMlrWithStats-summary.txt | 2 +- .../Common/EntryPoints/core_ep-list.tsv | 4 +- .../Common/EntryPoints/core_manifest.json | 2 +- ...nLogisticRegressionSaveModelToOnnxTest.txt | 2 +- .../PredictionEngineBench.cs | 4 +- ...sticDualCoordinateAscentClassifierBench.cs | 10 +- .../Text/MultiClassClassification.cs | 2 +- .../UnitTests/TestEntryPoints.cs | 6 +- .../Evaluation.cs | 4 +- .../IntrospectiveTraining.cs | 8 +- .../Microsoft.ML.Functional.Tests/Training.cs | 4 +- .../TestPredictors.cs | 4 +- .../Training.cs | 61 +- test/Microsoft.ML.Tests/OnnxConversionTest.cs | 2 +- .../PermutationFeatureImportanceTests.cs | 6 +- .../Api/CookbookSamples/CookbookSamples.cs | 2 +- .../CookbookSamplesDynamicApi.cs | 6 +- .../Estimators/DecomposableTrainAndPredict.cs | 4 +- .../Scenarios/Api/Estimators/Extensibility.cs | 4 +- 
.../Api/Estimators/PredictAndMetadata.cs | 4 +- .../Scenarios/IrisPlantClassificationTests.cs | 4 +- ...PlantClassificationWithStringLabelTests.cs | 4 +- .../Scenarios/TensorflowTests.cs | 2 +- .../IrisPlantClassificationTests.cs | 4 +- .../TrainerEstimators/LbfgsTests.cs | 16 +- .../TrainerEstimators/SdcaTests.cs | 73 ++- 35 files changed, 895 insertions(+), 369 deletions(-) diff --git a/docs/code/MlNetCookBook.md b/docs/code/MlNetCookBook.md index 87c5a580ff..716e8fac12 100644 --- a/docs/code/MlNetCookBook.md +++ b/docs/code/MlNetCookBook.md @@ -244,7 +244,7 @@ We tried to make `Preview` debugger-friendly: our expectation is that, if you en Here is the code sample: ```csharp var estimator = mlContext.Transforms.Categorical.MapValueToKey("Label") - .Append(mlContext.MulticlassClassification.Trainers.Sdca()) + .Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated()) .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel")); var data = mlContext.Data.LoadFromTextFile(new TextLoader.Column[] { @@ -423,7 +423,7 @@ var pipeline = // Cache data in memory for steps after the cache check point stage. .AppendCacheCheckpoint(mlContext) // Use the multi-class SDCA model to predict the label using features. - .Append(mlContext.MulticlassClassification.Trainers.Sdca()) + .Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated()) // Apply the inverse conversion from 'PredictedLabel' column back to string value. .Append(mlContext.Transforms.Conversion.MapKeyToValue(("PredictedLabel", "Data"))); @@ -547,13 +547,13 @@ var pipeline = // Cache data in memory for steps after the cache check point stage. .AppendCacheCheckpoint(mlContext) // Use the multi-class SDCA model to predict the label using features. - .Append(mlContext.MulticlassClassification.Trainers.Sdca()); + .Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated()); // Train the model. var trainedModel = pipeline.Fit(trainData); // Inspect the model parameters. 
-var modelParameters = trainedModel.LastTransformer.Model as MulticlassLogisticRegressionModelParameters; +var modelParameters = trainedModel.LastTransformer.Model as MaximumEntropyModelParameters; // Now we can use 'modelParameters' to look at the weights. // 'weights' will be an array of weight vectors, one vector per class. @@ -822,7 +822,7 @@ var pipeline = // Notice that unused part in the data may not be cached. .AppendCacheCheckpoint(mlContext) // Use the multi-class SDCA model to predict the label using features. - .Append(mlContext.MulticlassClassification.Trainers.Sdca()); + .Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated()); // Split the data 90:10 into train and test sets, train and evaluate. var split = mlContext.Data.TrainTestSplit(data, testFraction: 0.1); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscent.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscent.cs index a318477344..af632b5808 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscent.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscent.cs @@ -30,7 +30,7 @@ public static void Example() // Convert the string labels into key types. mlContext.Transforms.Conversion.MapValueToKey("Label") // Apply StochasticDualCoordinateAscent multiclass trainer. - .Append(mlContext.MulticlassClassification.Trainers.Sdca()); + .Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated()); // Split the data into training and test sets. Only training set is used in fitting // the created pipeline. Metrics are computed on the test. 
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscentWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscentWithOptions.cs index 2d6440dc22..10bc9c7918 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscentWithOptions.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscentWithOptions.cs @@ -1,5 +1,4 @@ -using Microsoft.ML.Data; -using Microsoft.ML.SamplesUtils; +using Microsoft.ML.SamplesUtils; using Microsoft.ML.Trainers; namespace Microsoft.ML.Samples.Dynamic.Trainers.MulticlassClassification @@ -26,10 +25,10 @@ public static void Example() // CC 1.216908,1.248052,1.391902,0.4326252,1.099942,0.9262842,1.334019,1.08762,0.9468155,0.4811099 // DD 0.7871246,1.053327,0.8971719,1.588544,1.242697,1.362964,0.6303943,0.9810045,0.9431419,1.557455 - var options = new SdcaMulticlassTrainer.Options + var options = new SdcaNonCalibratedMulticlassTrainer.Options { // Add custom loss - LossFunction = new HingeLoss(), + Loss = new HingeLoss(), // Make the convergence tolerance tighter. ConvergenceTolerance = 0.05f, // Increase the maximum number of passes over training data. @@ -41,7 +40,7 @@ public static void Example() // Convert the string labels into key types. mlContext.Transforms.Conversion.MapValueToKey("Label") // Apply StochasticDualCoordinateAscent multiclass trainer. - .Append(mlContext.MulticlassClassification.Trainers.Sdca(options)); + .Append(mlContext.MulticlassClassification.Trainers.SdcaNonCalibrated(options)); // Split the data into training and test sets. Only training set is used in fitting // the created pipeline. Metrics are computed on the test. 
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs index 9946cc66e5..808586f435 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs @@ -64,9 +64,9 @@ public Arguments() // non-default column names. Unfortuantely no method of resolving this temporary strikes me as being any // less laborious than the proper fix, which is that this "meta" component should itself be a trainer // estimator, as opposed to a regular trainer. - var trainerEstimator = new LogisticRegressionMulticlassClassificationTrainer(env, LabelColumnName, FeatureColumnName); - return TrainerUtils.MapTrainerEstimatorToTrainer(env, trainerEstimator); + var trainerEstimator = new LbfgsMaximumEntropyTrainer(env, LabelColumnName, FeatureColumnName); + return TrainerUtils.MapTrainerEstimatorToTrainer(env, trainerEstimator); }) }; } diff --git a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs index f56bc329d2..4c36daa5ce 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -19,25 +19,28 @@ using Microsoft.ML.Trainers; using Newtonsoft.Json.Linq; -[assembly: LoadableClass(typeof(LogisticRegressionMulticlassClassificationTrainer), typeof(LogisticRegressionMulticlassClassificationTrainer.Options), +[assembly: LoadableClass(typeof(LbfgsMaximumEntropyTrainer), typeof(LbfgsMaximumEntropyTrainer.Options), new[] { typeof(SignatureMulticlassClassifierTrainer), typeof(SignatureTrainer) }, - 
LogisticRegressionMulticlassClassificationTrainer.UserNameValue, - LogisticRegressionMulticlassClassificationTrainer.LoadNameValue, + LbfgsMaximumEntropyTrainer.UserNameValue, + LbfgsMaximumEntropyTrainer.LoadNameValue, "MulticlassLogisticRegressionPredictorNew", - LogisticRegressionMulticlassClassificationTrainer.ShortName, + LbfgsMaximumEntropyTrainer.ShortName, "multilr")] -[assembly: LoadableClass(typeof(MulticlassLogisticRegressionModelParameters), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(MaximumEntropyModelParameters), null, typeof(SignatureLoadModel), "Multiclass LR Executor", - MulticlassLogisticRegressionModelParameters.LoaderSignature)] + MaximumEntropyModelParameters.LoaderSignature)] + +[assembly: LoadableClass(typeof(void), typeof(LbfgsMaximumEntropyTrainer), null, typeof(SignatureEntryPointModule), LbfgsMaximumEntropyTrainer.LoadNameValue)] namespace Microsoft.ML.Trainers { /// /// - public sealed class LogisticRegressionMulticlassClassificationTrainer : LbfgsTrainerBase, MulticlassLogisticRegressionModelParameters> + public sealed class LbfgsMaximumEntropyTrainer : LbfgsTrainerBase, MaximumEntropyModelParameters> { + internal const string Summary = "Maximum entropy classification is a method in statistics used to predict the probabilities of parallel events. The model predicts the probabilities of parallel events by fitting data to a softmax function."; internal const string LoadNameValue = "MultiClassLogisticRegression"; internal const string UserNameValue = "Multi-class Logistic Regression"; internal const string ShortName = "mlr"; @@ -71,7 +74,7 @@ public sealed class Options : OptionsBase private protected override int ClassCount => _numClasses; /// - /// Initializes a new instance of + /// Initializes a new instance of . /// /// The environment to use. /// The name of the label column. @@ -82,7 +85,7 @@ public sealed class Options : OptionsBase /// Weight of L2 regularizer term. /// Memory size for .
Low=faster, less accurate. /// Threshold for optimizer convergence. - internal LogisticRegressionMulticlassClassificationTrainer(IHostEnvironment env, + internal LbfgsMaximumEntropyTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, string weights = null, @@ -100,9 +103,9 @@ internal LogisticRegressionMulticlassClassificationTrainer(IHostEnvironment env, } /// - /// Initializes a new instance of + /// Initializes a new instance of . /// - internal LogisticRegressionMulticlassClassificationTrainer(IHostEnvironment env, Options options) + internal LbfgsMaximumEntropyTrainer(IHostEnvironment env, Options options) : base(env, options, TrainerUtils.MakeU4ScalarColumn(options.LabelColumnName)) { ShowTrainingStats = LbfgsTrainerOptions.ShowTrainingStatistics; @@ -223,7 +226,7 @@ private protected override float AccumulateOneGradient(in VBuffer feat, f private protected override VBuffer InitializeWeightsFromPredictor(IPredictor srcPredictor) { - var pred = srcPredictor as MulticlassLogisticRegressionModelParameters; + var pred = srcPredictor as MaximumEntropyModelParameters; Contracts.AssertValue(pred); Contracts.Assert(pred.InputType.GetVectorSize() > 0); @@ -234,7 +237,7 @@ private protected override VBuffer InitializeWeightsFromPredictor(IPredic return InitializeWeights(pred.DenseWeightsEnumerable(), pred.GetBiases()); } - private protected override MulticlassLogisticRegressionModelParameters CreatePredictor() + private protected override MaximumEntropyModelParameters CreatePredictor() { if (_numClasses < 1) throw Contracts.Except("Cannot create a multiclass predictor with {0} classes", _numClasses); @@ -246,7 +249,7 @@ private protected override MulticlassLogisticRegressionModelParameters CreatePre } } - return new MulticlassLogisticRegressionModelParameters(Host, in CurrentWeights, _numClasses, NumFeatures, _labelNames, _stats); + return new MaximumEntropyModelParameters(Host, in 
CurrentWeights, _numClasses, NumFeatures, _labelNames, _stats); } private protected override void ComputeTrainingStatistics(IChannel ch, FloatLabelCursor.Factory cursorFactory, float loss, int numParams) @@ -324,21 +327,39 @@ private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape }; } - private protected override MulticlassPredictionTransformer MakeTransformer(MulticlassLogisticRegressionModelParameters model, DataViewSchema trainSchema) - => new MulticlassPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name); + private protected override MulticlassPredictionTransformer MakeTransformer(MaximumEntropyModelParameters model, DataViewSchema trainSchema) + => new MulticlassPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name); /// - /// Continues the training of a using an already trained and returns + /// Continues the training of a using an already trained and returns /// a . /// - public MulticlassPredictionTransformer Fit(IDataView trainData, MulticlassLogisticRegressionModelParameters modelParameters) + public MulticlassPredictionTransformer Fit(IDataView trainData, MaximumEntropyModelParameters modelParameters) => TrainTransformer(trainData, initPredictor: modelParameters); + + [TlcModule.EntryPoint(Name = "Trainers.LogisticRegressionClassifier", + Desc = LbfgsMaximumEntropyTrainer.Summary, + UserName = LbfgsMaximumEntropyTrainer.UserNameValue, + ShortName = LbfgsMaximumEntropyTrainer.ShortName)] + internal static CommonOutputs.MulticlassClassificationOutput TrainMulticlass(IHostEnvironment env, LbfgsMaximumEntropyTrainer.Options input) + { + Contracts.CheckValue(env, nameof(env)); + var host = env.Register("TrainLRMultiClass"); + host.CheckValue(input, nameof(input)); + EntryPointUtils.CheckInputArgs(host, input); + + return TrainerEntryPointsUtils.Train(host, input, + () => new LbfgsMaximumEntropyTrainer(host, input), + () => TrainerEntryPointsUtils.FindColumn(host, 
input.TrainingData.Schema, input.LabelColumnName), + () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName)); + } } /// - /// The model parameters class for Multiclass Logistic Regression. + /// Common linear model of multiclass classifiers. contains a single + /// linear model per class. /// - public sealed class MulticlassLogisticRegressionModelParameters : + public abstract class LinearMulticlassModelParametersBase : ModelParametersBase>, IValueMapper, ICanSaveInTextFormat, @@ -350,32 +371,16 @@ public sealed class MulticlassLogisticRegressionModelParameters : ISingleCanSavePfa, ISingleCanSaveOnnx { - internal const string LoaderSignature = "MultiClassLRExec"; - internal const string RegistrationName = "MulticlassLogisticRegressionPredictor"; - - private static VersionInfo GetVersionInfo() - { - return new VersionInfo( - modelSignature: "MULTI LR", - // verWrittenCur: 0x00010001, // Initial - // verWrittenCur: 0x00010002, // Added class names - verWrittenCur: 0x00010003, // Added model stats - verReadableCur: 0x00010001, - verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(MulticlassLogisticRegressionModelParameters).Assembly.FullName); - } - private const string ModelStatsSubModelFilename = "ModelStats"; private const string LabelNamesSubModelFilename = "LabelNames"; - private readonly int _numClasses; - private readonly int _numFeatures; + private protected readonly int NumberOfClasses; + private protected readonly int NumberOfFeatures; // The label names used to write model summary. Either null or of length _numClasses. private readonly string[] _labelNames; - private readonly float[] _biases; - private readonly VBuffer[] _weights; + private protected readonly float[] Biases; + private protected readonly VBuffer[] Weights; public readonly ModelStatisticsBase Statistics; // This stores the _weights matrix in dense format for performance. 
@@ -395,30 +400,30 @@ private static VersionInfo GetVersionInfo() bool ICanSavePfa.CanSavePfa => true; bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; - internal MulticlassLogisticRegressionModelParameters(IHostEnvironment env, in VBuffer weights, int numClasses, int numFeatures, string[] labelNames, ModelStatisticsBase stats = null) - : base(env, RegistrationName) + internal LinearMulticlassModelParametersBase(IHostEnvironment env, string name, in VBuffer weights, int numClasses, int numFeatures, string[] labelNames, ModelStatisticsBase stats = null) + : base(env, name) { Contracts.Assert(weights.Length == numClasses + numClasses * numFeatures); - _numClasses = numClasses; - _numFeatures = numFeatures; + NumberOfClasses = numClasses; + NumberOfFeatures = numFeatures; // weights contains both bias and feature weights in a flat vector // Biases are stored in the first _numClass elements // followed by one weight vector for each class, in turn, all concatenated // (i.e.: in "row major", if we encode each weight vector as a row of a matrix) - Contracts.Assert(weights.Length == _numClasses + _numClasses * _numFeatures); + Contracts.Assert(weights.Length == NumberOfClasses + NumberOfClasses * NumberOfFeatures); - _biases = new float[_numClasses]; - for (int i = 0; i < _biases.Length; i++) - weights.GetItemOrDefault(i, ref _biases[i]); - _weights = new VBuffer[_numClasses]; - for (int i = 0; i < _weights.Length; i++) - weights.CopyTo(ref _weights[i], _numClasses + i * _numFeatures, _numFeatures); - if (_weights.All(v => v.IsDense)) - _weightsDense = _weights; + Biases = new float[NumberOfClasses]; + for (int i = 0; i < Biases.Length; i++) + weights.GetItemOrDefault(i, ref Biases[i]); + Weights = new VBuffer[NumberOfClasses]; + for (int i = 0; i < Weights.Length; i++) + weights.CopyTo(ref Weights[i], NumberOfClasses + i * NumberOfFeatures, NumberOfFeatures); + if (Weights.All(v => v.IsDense)) + _weightsDense = Weights; - InputType = new 
VectorType(NumberDataViewType.Single, _numFeatures); - OutputType = new VectorType(NumberDataViewType.Single, _numClasses); + InputType = new VectorType(NumberDataViewType.Single, NumberOfFeatures); + OutputType = new VectorType(NumberDataViewType.Single, NumberOfClasses); Contracts.Assert(labelNames == null || labelNames.Length == numClasses); _labelNames = labelNames; @@ -428,41 +433,42 @@ internal MulticlassLogisticRegressionModelParameters(IHostEnvironment env, in VB } /// - /// Initializes a new instance of the class. - /// This constructor is called by to create the predictor. + /// Initializes a new instance of the class. + /// This constructor is called by to create the predictor. /// /// The host environment. + /// Registration name of this model's actual type. /// The array of weights vectors. It should contain weights. /// The array of biases. It should contain contain weights. /// The number of classes for multi-class classification. Must be at least 2. /// The length of the feature vector. /// The optional label names. If specified not null, it should have the same length as . /// The model statistics. 
- internal MulticlassLogisticRegressionModelParameters(IHostEnvironment env, VBuffer[] weights, float[] bias, int numClasses, int numFeatures, string[] labelNames, ModelStatisticsBase stats = null) - : base(env, RegistrationName) + internal LinearMulticlassModelParametersBase(IHostEnvironment env, string name, VBuffer[] weights, float[] bias, int numClasses, int numFeatures, string[] labelNames, ModelStatisticsBase stats = null) + : base(env, name) { Contracts.CheckValue(weights, nameof(weights)); Contracts.CheckValue(bias, nameof(bias)); Contracts.CheckParam(numClasses >= 2, nameof(numClasses), "Must be at least 2."); - _numClasses = numClasses; + NumberOfClasses = numClasses; Contracts.CheckParam(numFeatures >= 1, nameof(numFeatures), "Must be positive."); - _numFeatures = numFeatures; - Contracts.Check(Utils.Size(weights) == _numClasses); - Contracts.Check(Utils.Size(bias) == _numClasses); - _weights = new VBuffer[_numClasses]; - _biases = new float[_numClasses]; - for (int iClass = 0; iClass < _numClasses; iClass++) + NumberOfFeatures = numFeatures; + Contracts.Check(Utils.Size(weights) == NumberOfClasses); + Contracts.Check(Utils.Size(bias) == NumberOfClasses); + Weights = new VBuffer[NumberOfClasses]; + Biases = new float[NumberOfClasses]; + for (int iClass = 0; iClass < NumberOfClasses; iClass++) { - Contracts.Assert(weights[iClass].Length == _numFeatures); - weights[iClass].CopyTo(ref _weights[iClass]); - _biases[iClass] = bias[iClass]; + Contracts.Assert(weights[iClass].Length == NumberOfFeatures); + weights[iClass].CopyTo(ref Weights[iClass]); + Biases[iClass] = bias[iClass]; } - if (_weights.All(v => v.IsDense)) - _weightsDense = _weights; + if (Weights.All(v => v.IsDense)) + _weightsDense = Weights; - InputType = new VectorType(NumberDataViewType.Single, _numFeatures); - OutputType = new VectorType(NumberDataViewType.Single, _numClasses); + InputType = new VectorType(NumberDataViewType.Single, NumberOfFeatures); + OutputType = new 
VectorType(NumberDataViewType.Single, NumberOfClasses); Contracts.Assert(labelNames == null || labelNames.Length == numClasses); _labelNames = labelNames; @@ -471,8 +477,8 @@ internal MulticlassLogisticRegressionModelParameters(IHostEnvironment env, VBuff Statistics = stats; } - private MulticlassLogisticRegressionModelParameters(IHostEnvironment env, ModelLoadContext ctx) - : base(env, RegistrationName, ctx) + private protected LinearMulticlassModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx) + : base(env, name, ctx) { // *** Binary format *** // int: number of features @@ -489,13 +495,13 @@ private MulticlassLogisticRegressionModelParameters(IHostEnvironment env, ModelL // int[]: Id of label names (optional, in a separate stream) // ModelStatisticsBase: model statistics (optional, in a separate stream) - _numFeatures = ctx.Reader.ReadInt32(); - Host.CheckDecode(_numFeatures >= 1); + NumberOfFeatures = ctx.Reader.ReadInt32(); + Host.CheckDecode(NumberOfFeatures >= 1); - _numClasses = ctx.Reader.ReadInt32(); - Host.CheckDecode(_numClasses >= 1); + NumberOfClasses = ctx.Reader.ReadInt32(); + Host.CheckDecode(NumberOfClasses >= 1); - _biases = ctx.Reader.ReadFloatArray(_numClasses); + Biases = ctx.Reader.ReadFloatArray(NumberOfClasses); int numStarts = ctx.Reader.ReadInt32(); @@ -505,19 +511,19 @@ private MulticlassLogisticRegressionModelParameters(IHostEnvironment env, ModelL int numIndices = ctx.Reader.ReadInt32(); Host.CheckDecode(numIndices == 0); int numWeights = ctx.Reader.ReadInt32(); - Host.CheckDecode(numWeights == _numClasses * _numFeatures); - _weights = new VBuffer[_numClasses]; - for (int i = 0; i < _weights.Length; i++) + Host.CheckDecode(numWeights == NumberOfClasses * NumberOfFeatures); + Weights = new VBuffer[NumberOfClasses]; + for (int i = 0; i < Weights.Length; i++) { - var w = ctx.Reader.ReadFloatArray(_numFeatures); - _weights[i] = new VBuffer(_numFeatures, w); + var w = ctx.Reader.ReadFloatArray(NumberOfFeatures); + 
Weights[i] = new VBuffer(NumberOfFeatures, w); } - _weightsDense = _weights; + _weightsDense = Weights; } else { // Read weight matrix as CSR. - Host.CheckDecode(numStarts == _numClasses + 1); + Host.CheckDecode(numStarts == NumberOfClasses + 1); int[] starts = ctx.Reader.ReadIntArray(numStarts); Host.CheckDecode(starts[0] == 0); Host.CheckDecode(Utils.IsMonotonicallyIncreasing(starts)); @@ -525,26 +531,26 @@ private MulticlassLogisticRegressionModelParameters(IHostEnvironment env, ModelL int numIndices = ctx.Reader.ReadInt32(); Host.CheckDecode(numIndices == starts[starts.Length - 1]); - var indices = new int[_numClasses][]; + var indices = new int[NumberOfClasses][]; for (int i = 0; i < indices.Length; i++) { indices[i] = ctx.Reader.ReadIntArray(starts[i + 1] - starts[i]); - Host.CheckDecode(Utils.IsIncreasing(0, indices[i], _numFeatures)); + Host.CheckDecode(Utils.IsIncreasing(0, indices[i], NumberOfFeatures)); } int numValues = ctx.Reader.ReadInt32(); Host.CheckDecode(numValues == numIndices); - _weights = new VBuffer[_numClasses]; - for (int i = 0; i < _weights.Length; i++) + Weights = new VBuffer[NumberOfClasses]; + for (int i = 0; i < Weights.Length; i++) { float[] values = ctx.Reader.ReadFloatArray(starts[i + 1] - starts[i]); - _weights[i] = new VBuffer(_numFeatures, Utils.Size(values), values, indices[i]); + Weights[i] = new VBuffer(NumberOfFeatures, Utils.Size(values), values, indices[i]); } } WarnOnOldNormalizer(ctx, GetType(), Host); - InputType = new VectorType(NumberDataViewType.Single, _numFeatures); - OutputType = new VectorType(NumberDataViewType.Single, _numClasses); + InputType = new VectorType(NumberDataViewType.Single, NumberOfFeatures); + OutputType = new VectorType(NumberDataViewType.Single, NumberOfClasses); // REVIEW: Should not save the label names duplicately with the predictor again. // Get it from the label column schema metadata instead. 
@@ -560,24 +566,18 @@ private MulticlassLogisticRegressionModelParameters(IHostEnvironment env, ModelL Statistics = stats; } - private static MulticlassLogisticRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(GetVersionInfo()); - return new MulticlassLogisticRegressionModelParameters(env, ctx); - } + private protected abstract VersionInfo GetVersionInfo(); private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); - Host.Assert(_biases.Length == _numClasses); - Host.Assert(_biases.Length == _weights.Length); + Host.Assert(Biases.Length == NumberOfClasses); + Host.Assert(Biases.Length == Weights.Length); #if DEBUG - foreach (var fw in _weights) - Host.Assert(fw.Length == _numFeatures); + foreach (var fw in Weights) + Host.Assert(fw.Length == NumberOfFeatures); #endif // *** Binary format *** // int: number of features @@ -595,32 +595,32 @@ private protected override void SaveCore(ModelSaveContext ctx) // int[]: Id of label names (optional, in a separate stream) // LinearModelParameterStatistics: model statistics (optional, in a separate stream) - ctx.Writer.Write(_numFeatures); - ctx.Writer.Write(_numClasses); - ctx.Writer.WriteSinglesNoCount(_biases.AsSpan(0, _numClasses)); + ctx.Writer.Write(NumberOfFeatures); + ctx.Writer.Write(NumberOfClasses); + ctx.Writer.WriteSinglesNoCount(Biases.AsSpan(0, NumberOfClasses)); // _weights == _weighsDense means we checked that all vectors in _weights // are actually dense, and so we assigned the same object, or it came dense // from deserialization. - if (_weights == _weightsDense) + if (Weights == _weightsDense) { ctx.Writer.Write(0); // Number of starts. ctx.Writer.Write(0); // Number of indices. 
- ctx.Writer.Write(_numFeatures * _weights.Length); - foreach (var fv in _weights) + ctx.Writer.Write(NumberOfFeatures * Weights.Length); + foreach (var fv in Weights) { - Host.Assert(fv.Length == _numFeatures); + Host.Assert(fv.Length == NumberOfFeatures); ctx.Writer.WriteSinglesNoCount(fv.GetValues()); } } else { // Number of starts. - ctx.Writer.Write(_numClasses + 1); + ctx.Writer.Write(NumberOfClasses + 1); // Starts always starts with 0. int numIndices = 0; ctx.Writer.Write(numIndices); - for (int i = 0; i < _weights.Length; i++) + for (int i = 0; i < Weights.Length; i++) { // REVIEW: Assuming the presence of *any* zero justifies // writing in sparse format seems stupid, but might be difficult @@ -629,7 +629,7 @@ private protected override void SaveCore(ModelSaveContext ctx) // This is actually a bug waiting to happen: sparse/dense vectors // can have different dot products even if they are logically the // same vector. - numIndices += NonZeroCount(in _weights[i]); + numIndices += NonZeroCount(in Weights[i]); ctx.Writer.Write(numIndices); } @@ -637,7 +637,7 @@ private protected override void SaveCore(ModelSaveContext ctx) { // just scoping the count so we can use another further down int count = 0; - foreach (var fw in _weights) + foreach (var fw in Weights) { var fwValues = fw.GetValues(); if (fw.IsDense) @@ -665,7 +665,7 @@ private protected override void SaveCore(ModelSaveContext ctx) { int count = 0; - foreach (var fw in _weights) + foreach (var fw in Weights) { var fwValues = fw.GetValues(); if (fw.IsDense) @@ -719,7 +719,7 @@ ValueMapper IValueMapper.GetMapper() ValueMapper, VBuffer> del = (in VBuffer src, ref VBuffer dst) => { - Host.Check(src.Length == _numFeatures); + Host.Check(src.Length == NumberOfFeatures); PredictCore(in src, ref dst); }; @@ -728,14 +728,14 @@ ValueMapper IValueMapper.GetMapper() private void PredictCore(in VBuffer src, ref VBuffer dst) { - Host.Check(src.Length == _numFeatures, "src length should equal the number of features"); 
- var weights = _weights; + Host.Check(src.Length == NumberOfFeatures, "src length should equal the number of features"); + var weights = Weights; if (!src.IsDense) weights = DensifyWeights(); - var editor = VBufferEditor.Create(ref dst, _numClasses); - for (int i = 0; i < _biases.Length; i++) - editor.Values[i] = _biases[i] + VectorUtils.DotProduct(in weights[i], in src); + var editor = VBufferEditor.Create(ref dst, NumberOfClasses); + for (int i = 0; i < Biases.Length; i++) + editor.Values[i] = Biases[i] + VectorUtils.DotProduct(in weights[i], in src); Calibrate(editor.Values); dst = editor.Commit(); @@ -745,16 +745,16 @@ private VBuffer[] DensifyWeights() { if (_weightsDense == null) { - lock (_weights) + lock (Weights) { if (_weightsDense == null) { - var weightsDense = new VBuffer[_numClasses]; - for (int i = 0; i < _weights.Length; i++) + var weightsDense = new VBuffer[NumberOfClasses]; + for (int i = 0; i < Weights.Length; i++) { // Haven't yet created dense version of the weights. // REVIEW: Should we always expand to full weights or should this be subject to an option? - var w = _weights[i]; + var w = Weights[i]; if (w.IsDense) weightsDense[i] = w; else @@ -768,35 +768,14 @@ private VBuffer[] DensifyWeights() return _weightsDense; } - private void Calibrate(Span dst) - { - Host.Assert(dst.Length >= _numClasses); - - // scores are in log-space; convert and fix underflow/overflow - // TODO: re-normalize probabilities to account for underflow/overflow? - float softmax = MathUtils.SoftMax(dst.Slice(0, _numClasses)); - for (int i = 0; i < _numClasses; ++i) - dst[i] = MathUtils.ExpSlow(dst[i] - softmax); - } - /// - /// Output the text model to a given writer + /// Post-processing function applied to scores of each class' linear model output. + /// In we compute the i-th class' score + /// by using inner product of the i-th linear coefficient vector [i] and the input feature vector (plus bias). + /// Then, will be called to adjust those raw scores. 
/// - void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) - { - writer.WriteLine(nameof(LogisticRegressionMulticlassClassificationTrainer) + " bias and non-zero weights"); + private protected abstract void Calibrate(Span dst); - foreach (var namedValues in ((ICanGetSummaryInKeyValuePairs)this).GetSummaryInKeyValuePairs(schema)) - { - Host.Assert(namedValues.Value is float); - writer.WriteLine("\t{0}\t{1}", namedValues.Key, (float)namedValues.Value); - } - - if (Statistics != null) - Statistics.SaveText(writer, schema.Feature.Value, 20); - } - - /// IList> ICanGetSummaryInKeyValuePairs.GetSummaryInKeyValuePairs(RoleMappedSchema schema) { Host.CheckValueOrNull(schema); @@ -804,18 +783,18 @@ IList> ICanGetSummaryInKeyValuePairs.GetSummaryInKe List> results = new List>(); var names = default(VBuffer>); - AnnotationUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, _numFeatures, ref names); - for (int classNumber = 0; classNumber < _biases.Length; classNumber++) + AnnotationUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, NumberOfFeatures, ref names); + for (int classNumber = 0; classNumber < Biases.Length; classNumber++) { results.Add(new KeyValuePair( string.Format("{0}+(Bias)", GetLabelName(classNumber)), - _biases[classNumber] + Biases[classNumber] )); } - for (int classNumber = 0; classNumber < _weights.Length; classNumber++) + for (int classNumber = 0; classNumber < Weights.Length; classNumber++) { - var orderedWeights = _weights[classNumber].Items().OrderByDescending(kv => Math.Abs(kv.Value)); + var orderedWeights = Weights[classNumber].Items().OrderByDescending(kv => Math.Abs(kv.Value)); foreach (var weight in orderedWeights) { var value = weight.Value; @@ -835,33 +814,72 @@ IList> ICanGetSummaryInKeyValuePairs.GetSummaryInKe } /// - /// Output the text model to a given writer + /// Actual implementation of should happen in derived classes. 
/// - void ICanSaveInSourceCode.SaveAsCode(TextWriter writer, RoleMappedSchema schema) + private void SaveAsTextCore(TextWriter writer, RoleMappedSchema schema) { - Host.CheckValue(writer, nameof(writer)); - Host.CheckValueOrNull(schema); + writer.WriteLine(GetTrainerName() + " bias and non-zero weights"); - for (int i = 0; i < _biases.Length; i++) + foreach (var namedValues in ((ICanGetSummaryInKeyValuePairs)this).GetSummaryInKeyValuePairs(schema)) { - LinearPredictorUtils.SaveAsCode(writer, - in _weights[i], - _biases[i], - schema, - "score[" + i.ToString() + "]"); + Host.Assert(namedValues.Value is float); + writer.WriteLine("\t{0}\t{1}", namedValues.Key, (float)namedValues.Value); } - writer.WriteLine(string.Format("var softmax = MathUtils.SoftMax(scores.AsSpan(0, {0}));", _numClasses)); - for (int c = 0; c < _biases.Length; c++) - writer.WriteLine("output[{0}] = Math.Exp(scores[{0}] - softmax);", c); + if (Statistics != null) + Statistics.SaveText(writer, schema.Feature.Value, 20); } + private protected abstract string GetTrainerName(); + + /// + /// Redirect call to the right function. + /// + void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) => SaveAsTextCore(writer, schema); + + /// + /// Summary is equivalent to its information in text format. + /// void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema) { ((ICanSaveInTextFormat)this).SaveAsText(writer, schema); } - JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) + /// + /// Actual implementation of should happen in derived classes. 
+ /// + private void SaveAsCodeCore(TextWriter writer, RoleMappedSchema schema) + { + Host.CheckValue(writer, nameof(writer)); + Host.CheckValueOrNull(schema); + + writer.WriteLine(string.Format("var scores = new float[{0}];", NumberOfClasses)); + + for (int i = 0; i < Biases.Length; i++) + { + LinearPredictorUtils.SaveAsCode(writer, + in Weights[i], + Biases[i], + schema, + "scores[" + i.ToString() + "]"); + } + } + + /// + /// The raw scores of all linear classifiers are stored in [] . + /// Derived classes can use this function to add C# code for post-transformation. + /// + private protected abstract void SavePostTransformAsCode(TextWriter writer, string scoresName); + + /// + /// Redirect call to the right function. + /// + void ICanSaveInSourceCode.SaveAsCode(TextWriter writer, RoleMappedSchema schema) => SaveAsCodeCore(writer, schema); + + /// + /// Actual implementation of should happen in derived classes. + /// + private JToken SaveAsPfaCore(BoundPfaContext ctx, JToken input) { Host.CheckValue(ctx, nameof(ctx)); Host.CheckValue(input, nameof(input)); @@ -880,29 +898,55 @@ JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) type["fields"] = fields; typeDecl = type; } + JObject predictor = new JObject(); - predictor["coeff"] = new JArray(_weights.Select(w => new JArray(w.DenseValues()))); - predictor["const"] = new JArray(_biases); + predictor["coeff"] = new JArray(Weights.Select(w => new JArray(w.DenseValues()))); + predictor["const"] = new JArray(Biases); var cell = ctx.DeclareCell("MCLinearPredictor", typeDecl, predictor); var cellRef = PfaUtils.Cell(cell); - return PfaUtils.Call("m.link.softmax", PfaUtils.Call("model.reg.linear", input, cellRef)); + return ApplyPfaPostTransform(PfaUtils.Call("model.reg.linear", input, cellRef)); } - bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn) + /// + /// This is called at the end of to adjust the final outputs of all linear models.
+ /// + private protected abstract JToken ApplyPfaPostTransform(JToken input); + + /// + /// Redirect call to the right function. + /// + JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) => SaveAsPfaCore(ctx, input); + + /// + /// Actual implementation of should happen in derived classes. + /// It's ok to make a method in the future + /// if any derived class wants to override. + /// + private bool SaveAsOnnxCore(OnnxContext ctx, string[] outputs, string featureColumn) { Host.CheckValue(ctx, nameof(ctx)); string opType = "LinearClassifier"; var node = ctx.CreateNode(opType, new[] { featureColumn }, outputs, ctx.GetNodeName(opType)); - // Selection of logit or probit output transform. enum {'NONE', 'SOFTMAX', 'LOGISTIC', 'SOFTMAX_ZERO', 'PROBIT} - node.AddAttribute("post_transform", "NONE"); + node.AddAttribute("post_transform", GetOnnxPostTransform()); node.AddAttribute("multi_class", true); - node.AddAttribute("coefficients", _weights.SelectMany(w => w.DenseValues())); - node.AddAttribute("intercepts", _biases); - node.AddAttribute("classlabels_ints", Enumerable.Range(0, _numClasses).Select(x => (long)x)); + node.AddAttribute("coefficients", Weights.SelectMany(w => w.DenseValues())); + node.AddAttribute("intercepts", Biases); + node.AddAttribute("classlabels_ints", Enumerable.Range(0, NumberOfClasses).Select(x => (long)x)); return true; } + /// + /// Post-transform applied to the raw scores produced by those linear models of all classes. For maximum entropy classification, it should be + /// a softmax function. This function is used only in . + /// + private protected abstract string GetOnnxPostTransform(); + + /// + /// Redirect call to the right function. + /// + bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn) => SaveAsOnnxCore(ctx, outputs, featureColumn); + /// /// Copies the weight vector for each class into a set of buffers. 
/// @@ -912,37 +956,37 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputs, string fea /// of . public void GetWeights(ref VBuffer[] weights, out int numClasses) { - numClasses = _numClasses; - Utils.EnsureSize(ref weights, _numClasses, _numClasses); - for (int i = 0; i < _numClasses; i++) - _weights[i].CopyTo(ref weights[i]); + numClasses = NumberOfClasses; + Utils.EnsureSize(ref weights, NumberOfClasses, NumberOfClasses); + for (int i = 0; i < NumberOfClasses; i++) + Weights[i].CopyTo(ref weights[i]); + } + + /// + /// Gets the biases for the logistic regression predictor. + /// + public IEnumerable GetBiases() + { + return Biases; } internal IEnumerable DenseWeightsEnumerable() { - Contracts.Assert(_weights.Length == _biases.Length); + Contracts.Assert(Weights.Length == Biases.Length); - int featuresCount = _weights[0].Length; - for (var i = 0; i < _weights.Length; i++) + int featuresCount = Weights[0].Length; + for (var i = 0; i < Weights.Length; i++) { - Host.Assert(featuresCount == _weights[i].Length); - foreach (var weight in _weights[i].Items(all: true)) + Host.Assert(featuresCount == Weights[i].Length); + foreach (var weight in Weights[i].Items(all: true)) yield return weight.Value; } } - /// - /// Gets the biases for the logistic regression predictor. - /// - public IEnumerable GetBiases() - { - return _biases; - } - internal string GetLabelName(int classNumber) { const string classNumberFormat = "Class_{0}"; - Contracts.Assert(0 <= classNumber && classNumber < _numClasses); + Contracts.Assert(0 <= classNumber && classNumber < NumberOfClasses); return _labelNames == null ? 
string.Format(classNumberFormat, classNumber) : _labelNames[classNumber]; } @@ -950,8 +994,8 @@ private string[] LoadLabelNames(ModelLoadContext ctx, BinaryReader reader) { Contracts.AssertValue(ctx); Contracts.AssertValue(reader); - string[] labelNames = new string[_numClasses]; - for (int i = 0; i < _numClasses; i++) + string[] labelNames = new string[NumberOfClasses]; + for (int i = 0; i < NumberOfClasses; i++) { int id = reader.ReadInt32(); Host.CheckDecode(0 <= id && id < Utils.Size(ctx.Strings)); @@ -967,8 +1011,8 @@ private void SaveLabelNames(ModelSaveContext ctx, BinaryWriter writer) { Contracts.AssertValue(ctx); Contracts.AssertValue(writer); - Contracts.Assert(Utils.Size(_labelNames) == _numClasses); - for (int i = 0; i < _numClasses; i++) + Contracts.Assert(Utils.Size(_labelNames) == NumberOfClasses); + for (int i = 0; i < NumberOfClasses; i++) { Host.AssertValue(_labelNames[i]); writer.Write(ctx.Strings.Add(_labelNames[i]).Id); @@ -981,12 +1025,12 @@ IDataView ICanGetSummaryAsIDataView.GetSummaryDataView(RoleMappedSchema schema) ValueGetter>> getSlotNames = (ref VBuffer> dst) => - AnnotationUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, _numFeatures, ref dst); + AnnotationUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, NumberOfFeatures, ref dst); // Add the bias and the weight columns. 
- bldr.AddColumn("Bias", NumberDataViewType.Single, _biases); - bldr.AddColumn("Weights", getSlotNames, NumberDataViewType.Single, _weights); - bldr.AddColumn("ClassNames", Enumerable.Range(0, _numClasses).Select(i => GetLabelName(i)).ToArray()); + bldr.AddColumn("Bias", NumberDataViewType.Single, Biases); + bldr.AddColumn("Weights", getSlotNames, NumberDataViewType.Single, Weights); + bldr.AddColumn("ClassNames", Enumerable.Range(0, NumberOfClasses).Select(i => GetLabelName(i)).ToArray()); return bldr.GetDataView(); } @@ -1001,32 +1045,157 @@ DataViewRow ICanGetSummaryAsIRow.GetStatsIRowOrNull(RoleMappedSchema schema) return null; var names = default(VBuffer>); - AnnotationUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, _weights.Length, ref names); + AnnotationUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, Weights.Length, ref names); var meta = Statistics.MakeStatisticsMetadata(schema, in names); return AnnotationUtils.AnnotationsAsRow(meta); } } /// - /// A component to train a logistic regression model. + /// Linear model of multiclass classifiers. It outputs raw scores of all its linear models, and no probabilistic output is provided.
/// - public partial class LogisticRegressionBinaryTrainer + public sealed class LinearMulticlassModelParameters : LinearMulticlassModelParametersBase { - [TlcModule.EntryPoint(Name = "Trainers.LogisticRegressionClassifier", - Desc = Summary, - UserName = LogisticRegressionMulticlassClassificationTrainer.UserNameValue, - ShortName = LogisticRegressionMulticlassClassificationTrainer.ShortName)] - internal static CommonOutputs.MulticlassClassificationOutput TrainMulticlass(IHostEnvironment env, LogisticRegressionMulticlassClassificationTrainer.Options input) + internal const string LoaderSignature = "MulticlassLinear"; + internal const string RegistrationName = "MulticlassLinearPredictor"; + + private static VersionInfo VersionInfo => + new VersionInfo( + modelSignature: "MCLINEAR", + verWrittenCur: 0x00010001, + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(LinearMulticlassModelParameters).Assembly.FullName); + + /// + /// Function used to pass into parent class. It may be used when saving the model. + /// + private protected override VersionInfo GetVersionInfo() => VersionInfo; + + internal LinearMulticlassModelParameters(IHostEnvironment env, in VBuffer weights, int numClasses, int numFeatures, string[] labelNames, ModelStatisticsBase stats = null) + : base(env, RegistrationName, weights, numClasses, numFeatures, labelNames, stats) + { + } + + internal LinearMulticlassModelParameters(IHostEnvironment env, VBuffer[] weights, float[] bias, int numClasses, int numFeatures, string[] labelNames, ModelStatisticsBase stats = null) + : base(env, RegistrationName, weights, bias, numClasses, numFeatures, labelNames, stats) + { + } + + private LinearMulticlassModelParameters(IHostEnvironment env, ModelLoadContext ctx) + : base(env, RegistrationName, ctx) + { + } + + /// + /// This function does not do any calibration. 
It's common in multi-class support vector machines where probabilistic outputs are not provided. + /// + /// Score vector should be calibrated. + private protected override void Calibrate(Span dst) + { + } + + private static LinearMulticlassModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); - var host = env.Register("TrainLRMultiClass"); - host.CheckValue(input, nameof(input)); - EntryPointUtils.CheckInputArgs(host, input); + env.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(VersionInfo); + return new LinearMulticlassModelParameters(env, ctx); + } - return TrainerEntryPointsUtils.Train(host, input, - () => new LogisticRegressionMulticlassClassificationTrainer(host, input), - () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName), - () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName)); + private protected override void SavePostTransformAsCode(TextWriter writer, string scoresName) { } + + /// + /// No post-transform is needed for non-calibrated classifier. + /// + private protected override string GetOnnxPostTransform() => "NONE"; + + /// + /// No post-transform is needed for non-calibrated classifier. + /// + private protected override JToken ApplyPfaPostTransform(JToken input) => input; + + private protected override string GetTrainerName() => nameof(LinearMulticlassModelParameters); + } + + /// + /// Linear maximum entropy model of multiclass classifiers. It outputs classes probabilities. + /// This model is also known as multinomial logistic regression. + /// Please see https://en.wikipedia.org/wiki/Multinomial_logistic_regression for details. 
+ /// + public sealed class MaximumEntropyModelParameters : LinearMulticlassModelParametersBase + { + internal const string LoaderSignature = "MultiClassLRExec"; + internal const string RegistrationName = "MulticlassLogisticRegressionPredictor"; + + private static VersionInfo VersionInfo => + new VersionInfo( + modelSignature: "MULTI LR", + // verWrittenCur: 0x00010001, // Initial + // verWrittenCur: 0x00010002, // Added class names + verWrittenCur: 0x00010003, // Added model stats + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(MaximumEntropyModelParameters).Assembly.FullName); + + /// + /// Function used to pass into parent class. It may be used when saving the model. + /// + private protected override VersionInfo GetVersionInfo() => VersionInfo; + + internal MaximumEntropyModelParameters(IHostEnvironment env, in VBuffer weights, int numClasses, int numFeatures, string[] labelNames, ModelStatisticsBase stats = null) + : base(env, RegistrationName, weights, numClasses, numFeatures, labelNames, stats) + { + } + + internal MaximumEntropyModelParameters(IHostEnvironment env, VBuffer[] weights, float[] bias, int numClasses, int numFeatures, string[] labelNames, ModelStatisticsBase stats = null) + : base(env, RegistrationName, weights, bias, numClasses, numFeatures, labelNames, stats) + { + } + + private MaximumEntropyModelParameters(IHostEnvironment env, ModelLoadContext ctx) + : base(env, RegistrationName, ctx) + { } + + /// + /// This function applies softmax to . For details about softmax, see https://en.wikipedia.org/wiki/Softmax_function. + /// + /// Score vector should be calibrated. + private protected override void Calibrate(Span dst) + { + Host.Assert(dst.Length == NumberOfClasses); + + // scores are in log-space; convert and fix underflow/overflow + // TODO: re-normalize probabilities to account for underflow/overflow? 
+ float softmax = MathUtils.SoftMax(dst.Slice(0, NumberOfClasses)); + for (int i = 0; i < NumberOfClasses; ++i) + dst[i] = MathUtils.ExpSlow(dst[i] - softmax); + } + + /// + /// Apply softmax function to , which contains raw scores from all linear models. + /// + private protected override void SavePostTransformAsCode(TextWriter writer, string scoresName) + { + writer.WriteLine(string.Format("var softmax = MathUtils.SoftMax({0}.AsSpan(0, {1}));", scoresName, NumberOfClasses)); + + for (int c = 0; c < Biases.Length; c++) + writer.WriteLine("{1}[{0}] = Math.Exp({1}[{0}] - softmax);", c, scoresName); + } + + /// + /// Apply softmax to the raw scores produced by the linear models of all classes. + /// + private protected override string GetOnnxPostTransform() => "SOFTMAX"; + + /// + /// Apply softmax to the raw scores produced by the linear models of all classes. + /// + private protected override JToken ApplyPfaPostTransform(JToken input) => PfaUtils.Call("m.link.softmax", input); + + private protected override string GetTrainerName() => nameof(LbfgsMaximumEntropyTrainer); } } diff --git a/src/Microsoft.ML.StandardTrainers/Standard/SdcaMulticlass.cs b/src/Microsoft.ML.StandardTrainers/Standard/SdcaMulticlass.cs index 452f35d0fc..54c7748067 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/SdcaMulticlass.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/SdcaMulticlass.cs @@ -16,19 +16,20 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Trainers; -[assembly: LoadableClass(SdcaMulticlassTrainer.Summary, typeof(SdcaMulticlassTrainer), typeof(SdcaMulticlassTrainer.Options), +[assembly: LoadableClass(SdcaCalibratedMulticlassTrainer.Summary, typeof(SdcaCalibratedMulticlassTrainer), typeof(SdcaCalibratedMulticlassTrainer.Options), new[] { typeof(SignatureMulticlassClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, - SdcaMulticlassTrainer.UserNameValue, - SdcaMulticlassTrainer.LoadNameValue, - SdcaMulticlassTrainer.ShortName)] + 
SdcaCalibratedMulticlassTrainer.UserNameValue, + SdcaCalibratedMulticlassTrainer.LoadNameValue, + SdcaCalibratedMulticlassTrainer.ShortName)] namespace Microsoft.ML.Trainers { /// - /// The for training a multiclass logistic regression classification model using the stochastic dual coordinate ascent method. + /// The for training a multiclass linear classification model using the stochastic dual coordinate ascent method. /// /// - public sealed class SdcaMulticlassTrainer : SdcaTrainerBase, MulticlassLogisticRegressionModelParameters> + public abstract class SdcaMulticlassClassificationTrainerBase : SdcaTrainerBase.MulticlassOptions, MulticlassPredictionTransformer, TModel> + where TModel : class { internal const string LoadNameValue = "SDCAMC"; internal const string UserNameValue = "Fast Linear Multi-class Classification (SA-SDCA)"; @@ -36,9 +37,9 @@ public sealed class SdcaMulticlassTrainer : SdcaTrainerBase - /// Options for the . + /// Options for the . /// - public sealed class Options : OptionsBase + public class MulticlassOptions : OptionsBase { /// /// The custom loss. @@ -50,12 +51,12 @@ public sealed class Options : OptionsBase internal ISupportSdcaClassificationLossFactory LossFunctionFactory = new LogLossFactory(); /// - /// The custom loss. + /// Internal state of or storage of + /// a customized loss passed in. cannot set this field because its + /// loss function is always . In addition, and are + /// the two fields used to determine the actual loss function inside the training framework of . /// - /// - /// If unspecified, will be used. 
- /// - public ISupportSdcaClassificationLoss LossFunction { get; set; } + internal ISupportSdcaClassificationLoss InternalLoss; } private readonly ISupportSdcaClassificationLoss _loss; @@ -63,7 +64,7 @@ public sealed class Options : OptionsBase private protected override PredictionKind PredictionKind => PredictionKind.MulticlassClassification; /// - /// Initializes a new instance of + /// Initializes a new instance of . /// /// The environment to use. /// The label, or dependent variable. @@ -73,7 +74,7 @@ public sealed class Options : OptionsBase /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - internal SdcaMulticlassTrainer(IHostEnvironment env, + internal SdcaMulticlassClassificationTrainerBase(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, string weights = null, @@ -86,22 +87,22 @@ internal SdcaMulticlassTrainer(IHostEnvironment env, { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - _loss = loss ?? SdcaTrainerOptions.LossFunction ?? SdcaTrainerOptions.LossFunctionFactory.CreateComponent(env); + _loss = loss ?? SdcaTrainerOptions.InternalLoss ?? SdcaTrainerOptions.LossFunctionFactory.CreateComponent(env); Loss = _loss; } - internal SdcaMulticlassTrainer(IHostEnvironment env, Options options, + internal SdcaMulticlassClassificationTrainerBase(IHostEnvironment env, MulticlassOptions options, string featureColumn, string labelColumn, string weightColumn = null) : base(env, options, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { Host.CheckValue(labelColumn, nameof(labelColumn)); Host.CheckValue(featureColumn, nameof(featureColumn)); - _loss = options.LossFunction ?? 
options.LossFunctionFactory.CreateComponent(env); + _loss = options.InternalLoss ?? options.LossFunctionFactory.CreateComponent(env); Loss = _loss; } - internal SdcaMulticlassTrainer(IHostEnvironment env, Options options) + internal SdcaMulticlassClassificationTrainerBase(IHostEnvironment env, MulticlassOptions options) : this(env, options, options.FeatureColumnName, options.LabelColumnName) { } @@ -412,34 +413,132 @@ private protected override bool CheckConvergence( return converged; } - private protected override MulticlassLogisticRegressionModelParameters CreatePredictor(VBuffer[] weights, float[] bias) + private protected override void CheckLabel(RoleMappedData examples, out int weightSetCount) + { + examples.CheckMulticlassLabel(out weightSetCount); + } + + private protected override float[] InitializeFeatureNormSquared(int length) + { + Contracts.Assert(0 < length & length <= Utils.ArrayMaxSize); + return new float[length]; + } + + private protected override float GetInstanceWeight(FloatLabelCursor cursor) + { + return cursor.Weight; + } + } + + /// + /// The for training a maximum entropy classification model using the stochastic dual coordinate ascent method. + /// The trained model produces probabilities of classes. + /// + /// + public sealed class SdcaCalibratedMulticlassTrainer : SdcaMulticlassClassificationTrainerBase + { + public sealed class Options : MulticlassOptions + { + } + + internal SdcaCalibratedMulticlassTrainer(IHostEnvironment env, + string labelColumn = DefaultColumnNames.Label, + string featureColumn = DefaultColumnNames.Features, + string weights = null, + float? l2Const = null, + float? l1Threshold = null, + int? 
maxIterations = null) + : base(env, labelColumn: labelColumn, featureColumn: featureColumn, weights: weights, loss: new LogLoss(), + l2Const: l2Const, l1Threshold: l1Threshold, maxIterations: maxIterations) + { + } + + internal SdcaCalibratedMulticlassTrainer(IHostEnvironment env, Options options, + string featureColumn, string labelColumn, string weightColumn = null) + : base(env, options: options, featureColumn: featureColumn, labelColumn: labelColumn, weightColumn: weightColumn) + { + } + + internal SdcaCalibratedMulticlassTrainer(IHostEnvironment env, Options options) + : base(env, options) + { + } + + private protected override MaximumEntropyModelParameters CreatePredictor(VBuffer[] weights, float[] bias) { Host.CheckValue(weights, nameof(weights)); Host.CheckValue(bias, nameof(bias)); Host.CheckParam(weights.Length > 0, nameof(weights)); Host.CheckParam(weights.Length == bias.Length, nameof(weights)); - return new MulticlassLogisticRegressionModelParameters(Host, weights, bias, bias.Length, weights[0].Length, null, stats: null); + return new MaximumEntropyModelParameters(Host, weights, bias, bias.Length, weights[0].Length, null, stats: null); } - private protected override void CheckLabel(RoleMappedData examples, out int weightSetCount) + private protected override MulticlassPredictionTransformer MakeTransformer( + MaximumEntropyModelParameters model, DataViewSchema trainSchema) => + new MulticlassPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name); + } + + /// + /// The for training a multiclass linear model using the stochastic dual coordinate ascent method. + /// The trained model does not produce probabilities of classes, but we can still make decisions + /// by choosing the class associated with the largest score. 
+ /// + /// + public sealed class SdcaNonCalibratedMulticlassTrainer : SdcaMulticlassClassificationTrainerBase + { + public sealed class Options : MulticlassOptions { - examples.CheckMulticlassLabel(out weightSetCount); + /// + /// Loss function minimized by this trainer. + /// + /// + /// If unspecified, will be used. + /// + public ISupportSdcaClassificationLoss Loss + { + get { return InternalLoss; } + set { InternalLoss = value; } + } } - private protected override float[] InitializeFeatureNormSquared(int length) + internal SdcaNonCalibratedMulticlassTrainer(IHostEnvironment env, + string labelColumn = DefaultColumnNames.Label, + string featureColumn = DefaultColumnNames.Features, + string weights = null, + ISupportSdcaClassificationLoss loss = null, + float? l2Const = null, + float? l1Threshold = null, + int? maxIterations = null) + : base(env, labelColumn: labelColumn, featureColumn: featureColumn, weights: weights, loss: loss, + l2Const: l2Const, l1Threshold: l1Threshold, maxIterations: maxIterations) { - Contracts.Assert(0 < length & length <= Utils.ArrayMaxSize); - return new float[length]; } - private protected override float GetInstanceWeight(FloatLabelCursor cursor) + internal SdcaNonCalibratedMulticlassTrainer(IHostEnvironment env, Options options, + string featureColumn, string labelColumn, string weightColumn = null) + : base(env, options: options, featureColumn: featureColumn, labelColumn: labelColumn, weightColumn: weightColumn) { - return cursor.Weight; } - private protected override MulticlassPredictionTransformer MakeTransformer(MulticlassLogisticRegressionModelParameters model, DataViewSchema trainSchema) - => new MulticlassPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name); + internal SdcaNonCalibratedMulticlassTrainer(IHostEnvironment env, Options options) + : base(env, options) + { + } + + private protected override LinearMulticlassModelParameters CreatePredictor(VBuffer[] weights, float[] bias) + { + 
Host.CheckValue(weights, nameof(weights)); + Host.CheckValue(bias, nameof(bias)); + Host.CheckParam(weights.Length > 0, nameof(weights)); + Host.CheckParam(weights.Length == bias.Length, nameof(weights)); + + return new LinearMulticlassModelParameters(Host, weights, bias, bias.Length, weights[0].Length, null, stats: null); + } + + private protected override MulticlassPredictionTransformer MakeTransformer( + LinearMulticlassModelParameters model, DataViewSchema trainSchema) => + new MulticlassPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name); } /// @@ -448,18 +547,18 @@ private protected override MulticlassPredictionTransformer(host, input, - () => new SdcaMulticlassTrainer(host, input), + return TrainerEntryPointsUtils.Train(host, input, + () => new SdcaCalibratedMulticlassTrainer(host, input), () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName)); } } diff --git a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs index 4a87d8e23c..09ac97c694 100644 --- a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs +++ b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs @@ -281,13 +281,12 @@ public static SdcaNonCalibratedBinaryTrainer SdcaNonCalibrated( } /// - /// Predict a target using a linear multiclass classification model trained with . + /// Predict a target using a maximum entropy classification model trained with . /// /// The multiclass classification catalog trainer object. /// The name of the label column. /// The name of the feature column. /// The name of the example weight column (optional). - /// The custom loss. Defaults to if not specified. /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. 
@@ -297,7 +296,58 @@ public static SdcaNonCalibratedBinaryTrainer SdcaNonCalibrated( /// [!code-csharp[SDCA](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscent.cs)] /// ]]> /// - public static SdcaMulticlassTrainer Sdca(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + public static SdcaCalibratedMulticlassTrainer SdcaCalibrated(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + string labelColumnName = DefaultColumnNames.Label, + string featureColumnName = DefaultColumnNames.Features, + string exampleWeightColumnName = null, + float? l2Regularization = null, + float? l1Threshold = null, + int? maximumNumberOfIterations = null) + { + Contracts.CheckValue(catalog, nameof(catalog)); + var env = CatalogUtils.GetEnvironment(catalog); + return new SdcaCalibratedMulticlassTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, l2Regularization, l1Threshold, maximumNumberOfIterations); + } + + /// + /// Predict a target using a maximum entropy classification model trained with and advanced options. + /// + /// The multiclass classification catalog trainer object. + /// Trainer options. + /// + /// + /// + /// + public static SdcaCalibratedMulticlassTrainer SdcaCalibrated(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + SdcaCalibratedMulticlassTrainer.Options options) + { + Contracts.CheckValue(catalog, nameof(catalog)); + Contracts.CheckValue(options, nameof(options)); + + var env = CatalogUtils.GetEnvironment(catalog); + return new SdcaCalibratedMulticlassTrainer(env, options); + } + + /// + /// Predict a target using a linear multiclass classification model trained with . + /// + /// The multiclass classification catalog trainer object. + /// The name of the label column. + /// The name of the feature column. + /// The name of the example weight column (optional). 
+ /// Loss function to be minimized. Defaults to if not specified. + /// The L2 regularization hyperparameter. + /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. + /// The maximum number of passes to perform over the data. + /// + /// + /// + /// + public static SdcaNonCalibratedMulticlassTrainer SdcaNonCalibrated(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, @@ -308,11 +358,11 @@ public static SdcaMulticlassTrainer Sdca(this MulticlassClassificationCatalog.Mu { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); - return new SdcaMulticlassTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, loss, l2Regularization, l1Threshold, maximumNumberOfIterations); + return new SdcaNonCalibratedMulticlassTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, loss, l2Regularization, l1Threshold, maximumNumberOfIterations); } /// - /// Predict a target using a linear multiclass classification model trained with and advanced options. + /// Predict a target using linear multiclass classification model trained with and advanced options. /// /// The multiclass classification catalog trainer object. /// Trainer options. 
@@ -322,14 +372,14 @@ public static SdcaMulticlassTrainer Sdca(this MulticlassClassificationCatalog.Mu /// [!code-csharp[SDCA](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscentWithOptions.cs)] /// ]]> /// - public static SdcaMulticlassTrainer Sdca(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, - SdcaMulticlassTrainer.Options options) + public static SdcaNonCalibratedMulticlassTrainer SdcaNonCalibrated(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + SdcaNonCalibratedMulticlassTrainer.Options options) { Contracts.CheckValue(catalog, nameof(catalog)); Contracts.CheckValue(options, nameof(options)); var env = CatalogUtils.GetEnvironment(catalog); - return new SdcaMulticlassTrainer(env, options); + return new SdcaNonCalibratedMulticlassTrainer(env, options); } /// @@ -537,7 +587,7 @@ public static PoissonRegressionTrainer PoissonRegression(this RegressionCatalog. } /// - /// Predict a target using a linear multiclass classification model trained with the trainer. + /// Predict a target using a maximum entropy classification model trained with the L-BFGS method implemented in . /// /// The . /// The name of the label column. @@ -546,9 +596,9 @@ public static PoissonRegressionTrainer PoissonRegression(this RegressionCatalog. /// Enforce non-negative weights. /// Weight of L1 regularization term. /// Weight of L2 regularization term. - /// Memory size for . Low=faster, less accurate. + /// Memory size for . Low=faster, less accurate. /// Threshold for optimizer convergence. 
- public static LogisticRegressionMulticlassClassificationTrainer LogisticRegression(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + public static LbfgsMaximumEntropyTrainer LbfgsMaximumEntropy(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, @@ -560,22 +610,22 @@ public static LogisticRegressionMulticlassClassificationTrainer LogisticRegressi { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); - return new LogisticRegressionMulticlassClassificationTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, l1Regularization, l2Regularization, optimizationTolerance, historySize, enforceNonNegativity); + return new LbfgsMaximumEntropyTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, l1Regularization, l2Regularization, optimizationTolerance, historySize, enforceNonNegativity); } /// - /// Predict a target using a linear multiclass classification model trained with the trainer. + /// Predict a target using a maximum entropy classification model trained with the L-BFGS method implemented in . /// /// The . /// Advanced arguments to the algorithm. 
- public static LogisticRegressionMulticlassClassificationTrainer LogisticRegression(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, - LogisticRegressionMulticlassClassificationTrainer.Options options) + public static LbfgsMaximumEntropyTrainer LbfgsMaximumEntropy(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + LbfgsMaximumEntropyTrainer.Options options) { Contracts.CheckValue(catalog, nameof(catalog)); Contracts.CheckValue(options, nameof(options)); var env = CatalogUtils.GetEnvironment(catalog); - return new LogisticRegressionMulticlassClassificationTrainer(env, options); + return new LbfgsMaximumEntropyTrainer(env, options); } /// diff --git a/src/Microsoft.ML.StaticPipe/LbfgsStatic.cs b/src/Microsoft.ML.StaticPipe/LbfgsStatic.cs index 13c2554079..1110fd56ee 100644 --- a/src/Microsoft.ML.StaticPipe/LbfgsStatic.cs +++ b/src/Microsoft.ML.StaticPipe/LbfgsStatic.cs @@ -209,7 +209,7 @@ public static Scalar PoissonRegression(this RegressionCatalog.RegressionT public static class LbfgsMulticlassExtensions { /// - /// Predict a target using a linear multiclass classification model trained with the trainer. + /// Predict a target using a maximum entropy classification model trained with the L-BFGS method implemented in . /// /// The multiclass classification catalog trainer object. /// The label, or dependent variable. @@ -227,7 +227,7 @@ public static class LbfgsMulticlassExtensions /// result in any way; it is only a way for the caller to be informed about what was learnt. /// The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label. 
public static (Vector score, Key predictedLabel) - MulticlassLogisticRegression(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + LbfgsMaximumEntropy(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, Key label, Vector features, Scalar weights = null, @@ -236,14 +236,14 @@ public static (Vector score, Key predictedLabel) float optimizationTolerance = Options.Defaults.OptimizationTolerance, int historySize = Options.Defaults.HistorySize, bool enforceNonNegativity = Options.Defaults.EnforceNonNegativity, - Action onFit = null) + Action onFit = null) { LbfgsStaticUtils.ValidateParams(label, features, weights, l1Regularization, l2Regularization, optimizationTolerance, historySize, enforceNonNegativity, onFit); var rec = new TrainerEstimatorReconciler.MulticlassClassificationReconciler( (env, labelName, featuresName, weightsName) => { - var trainer = new LogisticRegressionMulticlassClassificationTrainer(env, labelName, featuresName, weightsName, + var trainer = new LbfgsMaximumEntropyTrainer(env, labelName, featuresName, weightsName, l1Regularization, l2Regularization, optimizationTolerance, historySize, enforceNonNegativity); if (onFit != null) @@ -255,7 +255,7 @@ public static (Vector score, Key predictedLabel) } /// - /// Predict a target using a linear multiclass classification model trained with the trainer. + /// Predict a target using a maximum entropy classification model trained with the L-BFGS method implemented in . /// /// The multiclass classification catalog trainer object. /// The label, or dependent variable. @@ -269,12 +269,12 @@ public static (Vector score, Key predictedLabel) /// result in any way; it is only a way for the caller to be informed about what was learnt. /// The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label. 
public static (Vector score, Key predictedLabel) - MulticlassLogisticRegression(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + LbfgsMaximumEntropy(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, Key label, Vector features, Scalar weights, - LogisticRegressionMulticlassClassificationTrainer.Options options, - Action onFit = null) + LbfgsMaximumEntropyTrainer.Options options, + Action onFit = null) { Contracts.CheckValue(label, nameof(label)); Contracts.CheckValue(features, nameof(features)); @@ -288,7 +288,7 @@ public static (Vector score, Key predictedLabel) options.FeatureColumnName = featuresName; options.ExampleWeightColumnName = weightsName; - var trainer = new LogisticRegressionMulticlassClassificationTrainer(env, options); + var trainer = new LbfgsMaximumEntropyTrainer(env, options); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); diff --git a/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs index 08468f93f3..88f8dcda91 100644 --- a/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs @@ -329,12 +329,11 @@ public static (Scalar score, Scalar predictedLabel) SdcaNonCalibrat } /// - /// Predict a target using a linear multiclass classification model trained with the SDCA trainer. + /// Predict a target using a maximum entropy classification model trained with the SDCA trainer. /// /// The multiclass classification catalog trainer object. /// The label, or dependent variable. /// The features, or independent variables. - /// The custom loss. /// The optional example weights. /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. 
@@ -349,12 +348,102 @@ public static (Vector score, Key predictedLabel) Sdca( this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, Key label, Vector features, - ISupportSdcaClassificationLoss loss = null, Scalar weights = null, float? l2Regularization = null, float? l1Threshold = null, int? numberOfIterations = null, - Action onFit = null) + Action onFit = null) + { + Contracts.CheckValue(label, nameof(label)); + Contracts.CheckValue(features, nameof(features)); + Contracts.CheckValueOrNull(weights); + Contracts.CheckParam(!(l2Regularization < 0), nameof(l2Regularization), "Must not be negative, if specified."); + Contracts.CheckParam(!(l1Threshold < 0), nameof(l1Threshold), "Must not be negative, if specified."); + Contracts.CheckParam(!(numberOfIterations < 1), nameof(numberOfIterations), "Must be positive if specified"); + Contracts.CheckValueOrNull(onFit); + + var rec = new TrainerEstimatorReconciler.MulticlassClassificationReconciler( + (env, labelName, featuresName, weightsName) => + { + var trainer = new SdcaCalibratedMulticlassTrainer(env, labelName, featuresName, weightsName, l2Regularization, l1Threshold, numberOfIterations); + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + }, label, features, weights); + + return rec.Output; + } + + /// + /// Predict a target using a maximum entropy classification model trained with the SDCA trainer. + /// + /// The multiclass classification catalog trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The optional example weights. + /// Advanced arguments to the algorithm. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the + /// result in any way; it is only a way for the caller to be informed about what was learnt. 
+ /// The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label. + public static (Vector score, Key predictedLabel) Sdca( + this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + Key label, + Vector features, + Scalar weights, + SdcaCalibratedMulticlassTrainer.Options options, + Action onFit = null) + { + Contracts.CheckValue(label, nameof(label)); + Contracts.CheckValue(features, nameof(features)); + Contracts.CheckValueOrNull(weights); + Contracts.CheckValueOrNull(options); + Contracts.CheckValueOrNull(onFit); + + var rec = new TrainerEstimatorReconciler.MulticlassClassificationReconciler( + (env, labelName, featuresName, weightsName) => + { + options.LabelColumnName = labelName; + options.FeatureColumnName = featuresName; + + var trainer = new SdcaCalibratedMulticlassTrainer(env, options); + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + }, label, features, weights); + + return rec.Output; + } + + /// + /// Predict a target using a linear multiclass classification model trained with the SDCA trainer. + /// + /// The multiclass classification catalog trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The custom loss, for example, . + /// The optional example weights. + /// The L2 regularization hyperparameter. + /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. + /// The maximum number of passes to perform over the data. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the + /// result in any way; it is only a way for the caller to be informed about what was learnt. 
+ /// The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label. + public static (Vector score, Key predictedLabel) SdcaNonCalibrated( + this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + Key label, + Vector features, + ISupportSdcaClassificationLoss loss, + Scalar weights = null, + float? l2Regularization = null, + float? l1Threshold = null, + int? numberOfIterations = null, + Action onFit = null) { Contracts.CheckValue(label, nameof(label)); Contracts.CheckValue(features, nameof(features)); @@ -368,7 +457,7 @@ public static (Vector score, Key predictedLabel) Sdca( var rec = new TrainerEstimatorReconciler.MulticlassClassificationReconciler( (env, labelName, featuresName, weightsName) => { - var trainer = new SdcaMulticlassTrainer(env, labelName, featuresName, weightsName, loss, l2Regularization, l1Threshold, numberOfIterations); + var trainer = new SdcaNonCalibratedMulticlassTrainer(env, labelName, featuresName, weightsName, loss, l2Regularization, l1Threshold, numberOfIterations); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); return trainer; @@ -391,13 +480,13 @@ public static (Vector score, Key predictedLabel) Sdca( /// the linear model that was trained. Note that this action cannot change the /// result in any way; it is only a way for the caller to be informed about what was learnt. /// The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label. 
- public static (Vector score, Key predictedLabel) Sdca( + public static (Vector score, Key predictedLabel) SdcaNonCalibrated( this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, Key label, Vector features, Scalar weights, - SdcaMulticlassTrainer.Options options, - Action onFit = null) + SdcaNonCalibratedMulticlassTrainer.Options options, + Action onFit = null) { Contracts.CheckValue(label, nameof(label)); Contracts.CheckValue(features, nameof(features)); @@ -411,7 +500,7 @@ public static (Vector score, Key predictedLabel) Sdca( options.LabelColumnName = labelName; options.FeatureColumnName = featuresName; - var trainer = new SdcaMulticlassTrainer(env, options); + var trainer = new SdcaNonCalibratedMulticlassTrainer(env, options); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); return trainer; diff --git a/test/BaselineOutput/Common/Command/CommandTrainMlrWithStats-summary.txt b/test/BaselineOutput/Common/Command/CommandTrainMlrWithStats-summary.txt index 18e66a6f8a..51540ac0ea 100644 --- a/test/BaselineOutput/Common/Command/CommandTrainMlrWithStats-summary.txt +++ b/test/BaselineOutput/Common/Command/CommandTrainMlrWithStats-summary.txt @@ -1,4 +1,4 @@ -LogisticRegressionMulticlassClassificationTrainer bias and non-zero weights +LbfgsMaximumEntropyTrainer bias and non-zero weights Iris-setosa+(Bias) 2.265129 Iris-versicolor+(Bias) 0.7695086 Iris-virginica+(Bias) -3.034663 diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index a23c9b29ae..8ebeb23b72 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -59,14 +59,14 @@ Trainers.LightGbmRanker Train a LightGBM ranking model. 
Microsoft.ML.Trainers.Li Trainers.LightGbmRegressor LightGBM Regression Microsoft.ML.Trainers.LightGbm.LightGbm TrainRegression Microsoft.ML.Trainers.LightGbm.LightGbmRegressionTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput Trainers.LinearSvmBinaryClassifier Train a linear SVM. Microsoft.ML.Trainers.LinearSvmTrainer TrainLinearSvm Microsoft.ML.Trainers.LinearSvmTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput Trainers.LogisticRegressionBinaryClassifier Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can be used as a classification algorithm. The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function. Microsoft.ML.Trainers.LogisticRegressionBinaryTrainer TrainBinary Microsoft.ML.Trainers.LogisticRegressionBinaryTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput -Trainers.LogisticRegressionClassifier Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can be used as a classification algorithm. The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function. Microsoft.ML.Trainers.LogisticRegressionBinaryTrainer TrainMulticlass Microsoft.ML.Trainers.LogisticRegressionMulticlassClassificationTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput +Trainers.LogisticRegressionClassifier Maximum entropy classification is a method in statistics used to predict the probabilities of parallel events. The model predicts the probabilities of parallel events by fitting data to a softmax function. Microsoft.ML.Trainers.LbfgsMaximumEntropyTrainer TrainMulticlass Microsoft.ML.Trainers.LbfgsMaximumEntropyTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput Trainers.NaiveBayesClassifier Train a MulticlassNaiveBayesTrainer. 
Microsoft.ML.Trainers.NaiveBayesMulticlassTrainer TrainMulticlassNaiveBayesTrainer Microsoft.ML.Trainers.NaiveBayesMulticlassTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput Trainers.OnlineGradientDescentRegressor Train a Online gradient descent perceptron. Microsoft.ML.Trainers.OnlineGradientDescentTrainer TrainRegression Microsoft.ML.Trainers.OnlineGradientDescentTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput Trainers.OrdinaryLeastSquaresRegressor Train an OLS regression model. Microsoft.ML.Trainers.OlsTrainer TrainRegression Microsoft.ML.Trainers.OlsTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput Trainers.PcaAnomalyDetector Train an PCA Anomaly model. Microsoft.ML.Trainers.RandomizedPcaTrainer TrainPcaAnomaly Microsoft.ML.Trainers.RandomizedPcaTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+AnomalyDetectionOutput Trainers.PoissonRegressor Train an Poisson regression model. Microsoft.ML.Trainers.PoissonRegressionTrainer TrainRegression Microsoft.ML.Trainers.PoissonRegressionTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput Trainers.StochasticDualCoordinateAscentBinaryClassifier Train an SDCA binary model. Microsoft.ML.Trainers.Sdca TrainBinary Microsoft.ML.Trainers.LegacySdcaBinaryTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput -Trainers.StochasticDualCoordinateAscentClassifier The SDCA linear multi-class classification trainer. Microsoft.ML.Trainers.Sdca TrainMulticlass Microsoft.ML.Trainers.SdcaMulticlassTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput +Trainers.StochasticDualCoordinateAscentClassifier The SDCA linear multi-class classification trainer. 
Microsoft.ML.Trainers.Sdca TrainMulticlass Microsoft.ML.Trainers.SdcaCalibratedMulticlassTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput Trainers.StochasticDualCoordinateAscentRegressor The SDCA linear regression trainer. Microsoft.ML.Trainers.Sdca TrainRegression Microsoft.ML.Trainers.SdcaRegressionTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput Trainers.StochasticGradientDescentBinaryClassifier Train an Hogwild SGD binary model. Microsoft.ML.Trainers.LegacySgdBinaryTrainer TrainBinary Microsoft.ML.Trainers.LegacySgdBinaryTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput Trainers.SymSgdBinaryClassifier Train a symbolic SGD. Microsoft.ML.Trainers.SymbolicSgdTrainer TrainSymSgd Microsoft.ML.Trainers.SymbolicSgdTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index c9251c758f..d874dce5b6 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -13569,7 +13569,7 @@ }, { "Name": "Trainers.LogisticRegressionClassifier", - "Desc": "Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can be used as a classification algorithm. The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function.", + "Desc": "Maximum entropy classification is a method in statistics used to predict the probabilities of parallel events. 
The model predicts the probabilities of parallel events by fitting data to a softmax function.", "FriendlyName": "Multi-class Logistic Regression", "ShortName": "mlr", "Inputs": [ diff --git a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt index c614821800..107f833949 100644 --- a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt +++ b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt @@ -98,7 +98,7 @@ "attribute": [ { "name": "post_transform", - "s": "Tk9ORQ==", + "s": "U09GVE1BWA==", "type": "STRING" }, { diff --git a/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs b/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs index da4f343404..918332dcb4 100644 --- a/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs +++ b/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs @@ -57,8 +57,8 @@ public void SetupIrisPipeline() var pipeline = new ColumnConcatenatingEstimator(env, "Features", new[] { "SepalLength", "SepalWidth", "PetalLength", "PetalWidth" }) .Append(env.Transforms.Conversion.MapValueToKey("Label")) - .Append(env.MulticlassClassification.Trainers.Sdca( - new SdcaMulticlassTrainer.Options { NumberOfThreads = 1, ConvergenceTolerance = 1e-2f, })); + .Append(env.MulticlassClassification.Trainers.SdcaCalibrated( + new SdcaCalibratedMulticlassTrainer.Options { NumberOfThreads = 1, ConvergenceTolerance = 1e-2f, })); var model = pipeline.Fit(data); diff --git a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs index 4be1ddba1e..b55eb3078d 100644 --- 
a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs +++ b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs @@ -34,7 +34,7 @@ public class StochasticDualCoordinateAscentClassifierBench : WithExtraMetrics PetalWidth = 5.1f, }; - private TransformerChain> _trainedModel; + private TransformerChain> _trainedModel; private PredictionEngine _predictionEngine; private IrisData[][] _batches; private MulticlassClassificationMetrics _metrics; @@ -53,9 +53,9 @@ protected override IEnumerable GetMetrics() } [Benchmark] - public TransformerChain> TrainIris() => Train(_dataPath); + public TransformerChain> TrainIris() => Train(_dataPath); - private TransformerChain> Train(string dataPath) + private TransformerChain> Train(string dataPath) { // Create text loader. var options = new TextLoader.Options() @@ -76,7 +76,7 @@ private TransformerChain(); + var environment = EnvironmentFactory.CreateClassificationEnvironment(); cmd.ExecuteMamlCommand(environment); } } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 0a36b2af62..5ab5c972e6 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -1174,7 +1174,7 @@ public void EntryPointMulticlassPipelineEnsemble() }).Fit(data).Transform(data); data = new ColumnConcatenatingTransformer(Env, "Features", new[] { "Features1", "Features2" }).Transform(data); - var mlr = ML.MulticlassClassification.Trainers.LogisticRegression(); + var mlr = ML.MulticlassClassification.Trainers.LbfgsMaximumEntropy(); var rmd = new RoleMappedData(data, "Label", "Features"); predictorModels[i] = new PredictorModelImpl(Env, rmd, data, mlr.Train(rmd)); @@ -3357,14 +3357,14 @@ public void EntryPointLinearPredictorSummary() }; var model = LogisticRegressionBinaryTrainer.TrainBinary(Env, lrInput).PredictorModel; - var mcLrInput = new 
LogisticRegressionMulticlassClassificationTrainer.Options + var mcLrInput = new LbfgsMaximumEntropyTrainer.Options { TrainingData = dataView, NormalizeFeatures = NormalizeOption.Yes, NumberOfThreads = 1, ShowTrainingStatistics = true }; - var mcModel = LogisticRegressionBinaryTrainer.TrainMulticlass(Env, mcLrInput).PredictorModel; + var mcModel = LbfgsMaximumEntropyTrainer.TrainMulticlass(Env, mcLrInput).PredictorModel; var output = SummarizePredictor.Summarize(Env, new SummarizePredictor.Input() { PredictorModel = model }); diff --git a/test/Microsoft.ML.Functional.Tests/Evaluation.cs b/test/Microsoft.ML.Functional.Tests/Evaluation.cs index 4cfdce9ded..3845361146 100644 --- a/test/Microsoft.ML.Functional.Tests/Evaluation.cs +++ b/test/Microsoft.ML.Functional.Tests/Evaluation.cs @@ -151,8 +151,8 @@ public void TrainAndEvaluateMulticlassClassification() var pipeline = mlContext.Transforms.Concatenate("Features", Iris.Features) .Append(mlContext.Transforms.Conversion.MapValueToKey("Label")) .AppendCacheCheckpoint(mlContext) - .Append(mlContext.MulticlassClassification.Trainers.Sdca( - new SdcaMulticlassTrainer.Options { NumberOfThreads = 1})); + .Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated( + new SdcaCalibratedMulticlassTrainer.Options { NumberOfThreads = 1})); // Train the model. var model = pipeline.Fit(data); diff --git a/test/Microsoft.ML.Functional.Tests/IntrospectiveTraining.cs b/test/Microsoft.ML.Functional.Tests/IntrospectiveTraining.cs index 5eb002c993..8353141014 100644 --- a/test/Microsoft.ML.Functional.Tests/IntrospectiveTraining.cs +++ b/test/Microsoft.ML.Functional.Tests/IntrospectiveTraining.cs @@ -393,7 +393,7 @@ public void InspectNestedPipeline() // Extract the trained models. 
var modelComponents = model.ToList(); var kMeansModel = (modelComponents[1] as TransformerChain>).LastTransformer; - var mcLrModel = (modelComponents[2] as TransformerChain>).LastTransformer; + var mcLrModel = (modelComponents[2] as TransformerChain>).LastTransformer; // Validate the k-means model. VBuffer[] centroids = default; @@ -419,11 +419,11 @@ private IEstimator>> StepTwo(MLContext mlContext) + private IEstimator>> StepTwo(MLContext mlContext) { return mlContext.Transforms.Conversion.MapValueToKey("Label") - .Append(mlContext.MulticlassClassification.Trainers.Sdca( - new SdcaMulticlassTrainer.Options { + .Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated( + new SdcaCalibratedMulticlassTrainer.Options { MaximumNumberOfIterations = 10, NumberOfThreads = 1 })); } diff --git a/test/Microsoft.ML.Functional.Tests/Training.cs b/test/Microsoft.ML.Functional.Tests/Training.cs index b00739b699..16d6880983 100644 --- a/test/Microsoft.ML.Functional.Tests/Training.cs +++ b/test/Microsoft.ML.Functional.Tests/Training.cs @@ -268,8 +268,8 @@ public void ContinueTrainingLogisticRegressionMulticlass() .Append(mlContext.Transforms.Conversion.MapValueToKey("Label")) .AppendCacheCheckpoint(mlContext); - var trainer = mlContext.MulticlassClassification.Trainers.LogisticRegression( - new LogisticRegressionMulticlassClassificationTrainer.Options { NumberOfThreads = 1, MaximumNumberOfIterations = 10 }); + var trainer = mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy( + new LbfgsMaximumEntropyTrainer.Options { NumberOfThreads = 1, MaximumNumberOfIterations = 10 }); // Fit the data transformation pipeline. 
var featurization = featurizationPipeline.Fit(data); diff --git a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs index 70346af906..160c2c3357 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs @@ -776,7 +776,7 @@ public void TestMulticlassEnsembleCombiner() LabelColumnName = DefaultColumnNames.Label, TrainingData = dataView }).PredictorModel, - LogisticRegressionBinaryTrainer.TrainMulticlass(Env, new LogisticRegressionMulticlassClassificationTrainer.Options() + LbfgsMaximumEntropyTrainer.TrainMulticlass(Env, new LbfgsMaximumEntropyTrainer.Options() { FeatureColumnName = "Features", LabelColumnName = DefaultColumnNames.Label, @@ -784,7 +784,7 @@ public void TestMulticlassEnsembleCombiner() TrainingData = dataView, NormalizeFeatures = NormalizeOption.No }).PredictorModel, - LogisticRegressionBinaryTrainer.TrainMulticlass(Env, new LogisticRegressionMulticlassClassificationTrainer.Options() + LbfgsMaximumEntropyTrainer.TrainMulticlass(Env, new LbfgsMaximumEntropyTrainer.Options() { FeatureColumnName = "Features", LabelColumnName = DefaultColumnNames.Label, diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs index 4a220ffb23..25898a5702 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs @@ -374,7 +374,7 @@ public void SdcaMulticlass() var reader = TextLoaderStatic.CreateLoader(env, c => (label: c.LoadText(0), features: c.LoadFloat(1, 4))); - MulticlassLogisticRegressionModelParameters pred = null; + MaximumEntropyModelParameters pred = null; var loss = new HingeLoss(1); @@ -385,7 +385,7 @@ public void SdcaMulticlass() r.label, r.features, numberOfIterations: 2, - loss: loss, onFit: p => pred = p))); + onFit: p => pred = p))); var pipe = reader.Append(est); @@ -413,6 +413,57 @@ public void SdcaMulticlass() 
Assert.True(metrics.TopKAccuracy > 0); } + [Fact] + public void SdcaMulticlassSvm() + { + var env = new MLContext(seed: 0); + var dataPath = GetDataPath(TestDatasets.iris.trainFilename); + var dataSource = new MultiFileSource(dataPath); + + var catalog = new MulticlassClassificationCatalog(env); + var reader = TextLoaderStatic.CreateLoader(env, + c => (label: c.LoadText(0), features: c.LoadFloat(1, 4))); + + LinearMulticlassModelParameters pred = null; + + var loss = new HingeLoss(1); + + // With a custom loss function we no longer get calibrated predictions. + var est = reader.MakeNewEstimator() + .Append(r => (label: r.label.ToKey(), r.features)) + .Append(r => (r.label, preds: catalog.Trainers.SdcaNonCalibrated( + r.label, + r.features, + loss: new HingeLoss(), + numberOfIterations: 2, + onFit: p => pred = p))); + + var pipe = reader.Append(est); + + Assert.Null(pred); + var model = pipe.Fit(dataSource); + Assert.NotNull(pred); + VBuffer[] weights = default; + pred.GetWeights(ref weights, out int n); + Assert.True(n == 3 && n == weights.Length); + foreach (var w in weights) + Assert.True(w.Length == 4); + + var biases = pred.GetBiases(); + Assert.True(biases.Count() == 3); + + var data = model.Load(dataSource); + + // Just output some data on the schema for fun. + var schema = data.AsDynamic.Schema; + for (int c = 0; c < schema.Count; ++c) + Console.WriteLine($"{schema[c].Name}, {schema[c].Type}"); + + var metrics = catalog.Evaluate(data, r => r.label, r => r.preds, 2); + Assert.InRange(metrics.MacroAccuracy, 0.6, 1); + Assert.InRange(metrics.TopKAccuracy, 0.8, 1); + } + [Fact] public void CrossValidate() { @@ -685,16 +736,16 @@ public void MulticlassLogisticRegression() var reader = TextLoaderStatic.CreateLoader(env, c => (label: c.LoadText(0), features: c.LoadFloat(1, 4))); - MulticlassLogisticRegressionModelParameters pred = null; + MaximumEntropyModelParameters pred = null; // With a custom loss function we no longer get calibrated predictions. 
var est = reader.MakeNewEstimator() .Append(r => (label: r.label.ToKey(), r.features)) - .Append(r => (r.label, preds: catalog.Trainers.MulticlassLogisticRegression( + .Append(r => (r.label, preds: catalog.Trainers.LbfgsMaximumEntropy( r.label, r.features, null, - new LogisticRegressionMulticlassClassificationTrainer.Options { NumberOfThreads = 1 }, + new LbfgsMaximumEntropyTrainer.Options { NumberOfThreads = 1 }, onFit: p => pred = p))); var pipe = reader.Append(est); diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index a685971161..1c957c264d 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -385,7 +385,7 @@ public void MulticlassLogisticRegressionOnnxConversionTest() var pipeline = mlContext.Transforms.Normalize("Features"). Append(mlContext.Transforms.Conversion.MapValueToKey("Label")). - Append(mlContext.MulticlassClassification.Trainers.LogisticRegression(new LogisticRegressionMulticlassClassificationTrainer.Options() { NumberOfThreads = 1 })); + Append(mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(new LbfgsMaximumEntropyTrainer.Options() { NumberOfThreads = 1 })); var model = pipeline.Fit(data); var transformedData = model.Transform(data); diff --git a/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs b/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs index 96993ad53d..93b3ea146f 100644 --- a/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs +++ b/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs @@ -231,7 +231,7 @@ public void TestPfiBinaryClassificationOnSparseFeatures() public void TestPfiMulticlassClassificationOnDenseFeatures() { var data = GetDenseDataset(TaskType.MulticlassClassification); - var model = ML.MulticlassClassification.Trainers.LogisticRegression().Fit(data); + var model = ML.MulticlassClassification.Trainers.LbfgsMaximumEntropy().Fit(data); var pfi = 
ML.MulticlassClassification.PermutationFeatureImportance(model, data); // Pfi Indices: @@ -268,8 +268,8 @@ public void TestPfiMulticlassClassificationOnDenseFeatures() public void TestPfiMulticlassClassificationOnSparseFeatures() { var data = GetSparseDataset(TaskType.MulticlassClassification); - var model = ML.MulticlassClassification.Trainers.LogisticRegression( - new LogisticRegressionMulticlassClassificationTrainer.Options { MaximumNumberOfIterations = 1000 }).Fit(data); + var model = ML.MulticlassClassification.Trainers.LbfgsMaximumEntropy( + new LbfgsMaximumEntropyTrainer.Options { MaximumNumberOfIterations = 1000 }).Fit(data); var pfi = ML.MulticlassClassification.PermutationFeatureImportance(model, data); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs index fabcb41205..30be894425 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs @@ -264,7 +264,7 @@ private void TrainAndInspectWeights(string dataPath) var trainData = loader.Load(dataPath); // This is the predictor ('weights collection') that we will train. - MulticlassLogisticRegressionModelParameters predictor = null; + MaximumEntropyModelParameters predictor = null; // And these are the normalizer scales that we will learn. ImmutableArray normScales; // Build the training pipeline. 
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs index 63f25dcbce..614d6baf91 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs @@ -161,13 +161,13 @@ private ITransformer TrainOnIris(string irisDataPath) // Cache data in memory for steps after the cache check point stage. .AppendCacheCheckpoint(mlContext) // Use the multi-class SDCA model to predict the label using features. - .Append(mlContext.MulticlassClassification.Trainers.Sdca()); + .Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated()); // Train the model. var trainedModel = pipeline.Fit(trainData); // Inspect the model parameters. - var modelParameters = trainedModel.LastTransformer.Model as MulticlassLogisticRegressionModelParameters; + var modelParameters = trainedModel.LastTransformer.Model as MaximumEntropyModelParameters; // Get the weights and the numbers of classes VBuffer[] weights = default; @@ -419,7 +419,7 @@ private void CrossValidationOn(string dataPath) // Notice that unused part in the data may not be cached. .AppendCacheCheckpoint(mlContext) // Use the multi-class SDCA model to predict the label using features. - .Append(mlContext.MulticlassClassification.Trainers.Sdca()); + .Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated()); // Split the data 90:10 into train and test sets, train and evaluate. 
var split = mlContext.Data.TrainTestSplit(data, testFraction: 0.1); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs index 151cb3792f..016aca6d3d 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs @@ -31,8 +31,8 @@ void DecomposableTrainAndPredict() var pipeline = new ColumnConcatenatingEstimator (ml, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(new ValueToKeyMappingEstimator(ml, "Label"), TransformerScope.TrainTest) - .Append(ml.MulticlassClassification.Trainers.Sdca( - new SdcaMulticlassTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1, })) + .Append(ml.MulticlassClassification.Trainers.SdcaCalibrated( + new SdcaCalibratedMulticlassTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1, })) .Append(new KeyToValueMappingEstimator(ml, "PredictedLabel")); var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs index c5098f6b30..71088fb4c5 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs @@ -40,8 +40,8 @@ void Extensibility() var pipeline = new ColumnConcatenatingEstimator (ml, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(new CustomMappingEstimator(ml, action, null), TransformerScope.TrainTest) .Append(new ValueToKeyMappingEstimator(ml, "Label"), TransformerScope.TrainTest) - .Append(ml.MulticlassClassification.Trainers.Sdca( - new SdcaMulticlassTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads 
= 1 })) + .Append(ml.MulticlassClassification.Trainers.SdcaCalibrated( + new SdcaCalibratedMulticlassTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1 })) .Append(new KeyToValueMappingEstimator(ml, "PredictedLabel")); var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs index 8c023bb340..311bbb0927 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs @@ -29,8 +29,8 @@ void PredictAndMetadata() var pipeline = ml.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(ml.Transforms.Conversion.MapValueToKey("Label"), TransformerScope.TrainTest) - .Append(ml.MulticlassClassification.Trainers.Sdca( - new SdcaMulticlassTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1, })); + .Append(ml.MulticlassClassification.Trainers.SdcaCalibrated( + new SdcaCalibratedMulticlassTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1, })); var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring); var engine = ml.Model.CreatePredictionEngine(model); diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs index c3a33f1492..e934aa694f 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs @@ -32,8 +32,8 @@ public void TrainAndPredictIrisModelTest() .Append(mlContext.Transforms.Normalize("Features")) .Append(mlContext.Transforms.Conversion.MapValueToKey("Label")) .AppendCacheCheckpoint(mlContext) - .Append(mlContext.MulticlassClassification.Trainers.Sdca( - new 
SdcaMulticlassTrainer.Options { NumberOfThreads = 1 })); + .Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated( + new SdcaCalibratedMulticlassTrainer.Options { NumberOfThreads = 1 })); // Read training and test data sets string dataPath = GetDataPath(TestDatasets.iris.trainFilename); diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs index eb383d9fdd..13a4d45db6 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs @@ -37,8 +37,8 @@ public void TrainAndPredictIrisModelWithStringLabelTest() .Append(mlContext.Transforms.Normalize("Features")) .Append(mlContext.Transforms.Conversion.MapValueToKey("Label", "IrisPlantType"), TransformerScope.TrainTest) .AppendCacheCheckpoint(mlContext) - .Append(mlContext.MulticlassClassification.Trainers.Sdca( - new SdcaMulticlassTrainer.Options { NumberOfThreads = 1 })) + .Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated( + new SdcaCalibratedMulticlassTrainer.Options { NumberOfThreads = 1 })) .Append(mlContext.Transforms.Conversion.MapKeyToValue(("Plant", "PredictedLabel"))); // Train the pipeline diff --git a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs index db6f57ec99..2e5c2397a1 100644 --- a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs @@ -39,7 +39,7 @@ public void TensorFlowTransforCifarEndToEndTest() .Append(new ColumnConcatenatingEstimator(mlContext, "Features", "Output")) .Append(new ValueToKeyMappingEstimator(mlContext, "Label")) .AppendCacheCheckpoint(mlContext) - .Append(mlContext.MulticlassClassification.Trainers.Sdca()); + .Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated()); var transformer = 
pipeEstimator.Fit(data); diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs index db47181620..8b123c4498 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs @@ -30,8 +30,8 @@ public void TrainAndPredictIrisModelUsingDirectInstantiationTest() .Append(mlContext.Transforms.Normalize("Features")) .Append(mlContext.Transforms.Conversion.MapValueToKey("Label")) .AppendCacheCheckpoint(mlContext) - .Append(mlContext.MulticlassClassification.Trainers.Sdca( - new SdcaMulticlassTrainer.Options { NumberOfThreads = 1 })); + .Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated( + new SdcaCalibratedMulticlassTrainer.Options { NumberOfThreads = 1 })); // Read training and test data sets string dataPath = GetDataPath(TestDatasets.iris.trainFilename); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs index e57d923d2c..06c2c288c2 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs @@ -33,7 +33,7 @@ public void TestEstimatorLogisticRegression() public void TestEstimatorMulticlassLogisticRegression() { (IEstimator pipe, IDataView dataView) = GetMulticlassPipeline(); - var trainer = ML.MulticlassClassification.Trainers.LogisticRegression(); + var trainer = ML.MulticlassClassification.Trainers.LbfgsMaximumEntropy(); var pipeWithTrainer = pipe.Append(trainer); TestEstimatorCore(pipeWithTrainer, dataView); @@ -163,13 +163,13 @@ public void TestLRWithStatsBackCompatibility() public void TestMLRNoStats() { (IEstimator pipe, IDataView dataView) = GetMulticlassPipeline(); - var trainer = ML.MulticlassClassification.Trainers.LogisticRegression(); 
+ var trainer = ML.MulticlassClassification.Trainers.LbfgsMaximumEntropy(); var pipeWithTrainer = pipe.Append(trainer); TestEstimatorCore(pipeWithTrainer, dataView); var transformer = pipeWithTrainer.Fit(dataView); - var model = transformer.LastTransformer.Model as MulticlassLogisticRegressionModelParameters; + var model = transformer.LastTransformer.Model as MaximumEntropyModelParameters; var stats = model.Statistics; Assert.Null(stats); @@ -182,7 +182,7 @@ public void TestMLRWithStats() { (IEstimator pipe, IDataView dataView) = GetMulticlassPipeline(); - var trainer = ML.MulticlassClassification.Trainers.LogisticRegression(new LogisticRegressionMulticlassClassificationTrainer.Options + var trainer = ML.MulticlassClassification.Trainers.LbfgsMaximumEntropy(new LbfgsMaximumEntropyTrainer.Options { ShowTrainingStatistics = true }); @@ -191,9 +191,9 @@ public void TestMLRWithStats() TestEstimatorCore(pipeWithTrainer, dataView); var transformer = pipeWithTrainer.Fit(dataView); - var model = transformer.LastTransformer.Model as MulticlassLogisticRegressionModelParameters; + var model = transformer.LastTransformer.Model as MaximumEntropyModelParameters; - Action validateStats = (modelParams) => + Action validateStats = (modelParams) => { var stats = modelParams.Statistics; Assert.NotNull(stats); @@ -216,7 +216,7 @@ public void TestMLRWithStats() transformerChain = ML.Model.Load(fs, out var schema); var lastTransformer = ((TransformerChain)transformerChain).LastTransformer as MulticlassPredictionTransformer>>; - model = lastTransformer.Model as MulticlassLogisticRegressionModelParameters; + model = lastTransformer.Model as MaximumEntropyModelParameters; validateStats(model); @@ -231,7 +231,7 @@ public void TestMLRWithStatsBackCompatibility() using (FileStream fs = File.OpenRead(dropModelPath)) { - var result = ModelFileUtils.LoadPredictorOrNull(Env, fs) as MulticlassLogisticRegressionModelParameters; + var result = ModelFileUtils.LoadPredictorOrNull(Env, fs) as 
MaximumEntropyModelParameters; var stats = result?.Statistics; Assert.Equal(132.012238f, stats.Deviance); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs index 7734dedb49..9d4ce78811 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs @@ -37,8 +37,8 @@ public void SdcaWorkout() TestEstimatorCore(regressionTrainer, data.AsDynamic); var mcData = ML.Transforms.Conversion.MapValueToKey("Label").Fit(data.AsDynamic).Transform(data.AsDynamic); - var mcTrainer = ML.MulticlassClassification.Trainers.Sdca( - new SdcaMulticlassTrainer.Options { ConvergenceTolerance = 1e-2f, MaximumNumberOfIterations = 10 }); + var mcTrainer = ML.MulticlassClassification.Trainers.SdcaCalibrated( + new SdcaCalibratedMulticlassTrainer.Options { ConvergenceTolerance = 1e-2f, MaximumNumberOfIterations = 10 }); TestEstimatorCore(mcTrainer, mcData); Done(); @@ -130,5 +130,74 @@ public void SdcaSupportVectorMachine() Assert.True(first.Score > 0); } + [Fact] + public void SdcaMulticlassLogisticRegression() + { + // Generate C# objects as training examples. + var rawData = SamplesUtils.DatasetUtils.GenerateFloatLabelFloatFeatureVectorSamples(512); + + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + var mlContext = new MLContext(); + + // Step 1: Read the data as an IDataView. + var data = mlContext.Data.LoadFromEnumerable(rawData); + + // ML.NET doesn't cache data set by default. Caching is very helpful when working with iterative + // algorithms which needs many data passes. Since SDCA is the case, we cache. + data = mlContext.Data.Cache(data); + + // Step 2: Create a binary classifier. + // We set the "Label" column as the label of the dataset, and the "Features" column as the features column. 
+ + var pipeline = mlContext.Transforms.Conversion.MapValueToKey("LabelIndex", "Label"). + Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated(labelColumnName: "LabelIndex", featureColumnName: "Features", l2Regularization: 0.001f)); + + // Step 3: Train the pipeline created. + var model = pipeline.Fit(data); + + // Step 4: Make prediction and evaluate its quality (on training set). + var prediction = model.Transform(data); + var metrics = mlContext.MulticlassClassification.Evaluate(prediction, labelColumnName: "LabelIndex", topK: 1); + + // Check a few metrics to make sure the trained model is ok. + Assert.InRange(metrics.TopKAccuracy, 0.8, 1); + Assert.InRange(metrics.LogLoss, 0, 0.5); + } + + [Fact] + public void SdcaMulticlassSupportVectorMachine() + { + // Generate C# objects as training examples. + var rawData = SamplesUtils.DatasetUtils.GenerateFloatLabelFloatFeatureVectorSamples(512); + + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + var mlContext = new MLContext(); + + // Step 1: Read the data as an IDataView. + var data = mlContext.Data.LoadFromEnumerable(rawData); + + // ML.NET doesn't cache data set by default. Caching is very helpful when working with iterative + // algorithms which needs many data passes. Since SDCA is the case, we cache. + data = mlContext.Data.Cache(data); + + // Step 2: Create a binary classifier. + // We set the "Label" column as the label of the dataset, and the "Features" column as the features column. + var pipeline = mlContext.Transforms.Conversion.MapValueToKey("LabelIndex", "Label"). + Append(mlContext.MulticlassClassification.Trainers.SdcaNonCalibrated(labelColumnName: "LabelIndex", featureColumnName: "Features", loss: new HingeLoss(), l2Regularization: 0.001f)); + + // Step 3: Train the pipeline created. 
+ var model = pipeline.Fit(data); + + // Step 4: Make prediction and evaluate its quality (on training set). + var prediction = model.Transform(data); + var metrics = mlContext.MulticlassClassification.Evaluate(prediction, labelColumnName: "LabelIndex", topK: 1); + + // Check a few metrics to make sure the trained model is ok. + Assert.InRange(metrics.TopKAccuracy, 0.8, 1); + Assert.InRange(metrics.MacroAccuracy, 0.8, 1); + } + } } From 807d813050796129a76837a6e2ea17a4939435b0 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Wed, 20 Mar 2019 10:01:21 -0700 Subject: [PATCH 11/18] Clean up the SchemaDefinition class (#2995) * Internalize some members of SchemaDefinition, and add tests * Code review comments * Fix build after rebase * Fix failing test * Fix build after rebase * Internalize Column ctor * Fix build after rebase --- .../Data/SchemaDefinition.cs | 62 ++++------ .../DataView/DataViewConstructionUtils.cs | 6 +- .../DataView/InternalSchemaDefinition.cs | 13 +- .../SchemaDefinitionTests.cs | 114 ++++++++++++++++++ .../Scenarios/Api/TestApi.cs | 24 ++-- 5 files changed, 161 insertions(+), 58 deletions(-) create mode 100644 test/Microsoft.ML.Functional.Tests/SchemaDefinitionTests.cs diff --git a/src/Microsoft.ML.Data/Data/SchemaDefinition.cs b/src/Microsoft.ML.Data/Data/SchemaDefinition.cs index ac988e9179..c8c41903f8 100644 --- a/src/Microsoft.ML.Data/Data/SchemaDefinition.cs +++ b/src/Microsoft.ML.Data/Data/SchemaDefinition.cs @@ -203,15 +203,14 @@ public sealed class SchemaDefinition : List /// public sealed class Column { - private readonly Dictionary _annotations; - internal Dictionary Annotations { get { return _annotations; } } + internal Dictionary AnnotationInfos { get; } /// /// The name of the member the column is taken from. The API /// requires this to not be null, and a valid name of a member of /// the type for which we are creating a schema. 
/// - public string MemberName { get; set; } + public string MemberName { get; } /// /// The name of the column that's created in the data view. If this /// is null, the API uses the . @@ -223,34 +222,21 @@ public sealed class Column /// public DataViewType ColumnType { get; set; } - /// - /// Whether the column is a computed type. - /// - public bool IsComputed { get { return Generator != null; } } - /// /// The generator function. if the column is computed. /// - public Delegate Generator { get; set; } + internal Delegate Generator { get; set; } - public Type ReturnType => Generator?.GetMethodInfo().GetParameters().LastOrDefault().ParameterType.GetElementType(); + internal Type ReturnType => Generator?.GetMethodInfo().GetParameters().LastOrDefault().ParameterType.GetElementType(); - public Column(IExceptionContext ectx, string memberName, DataViewType columnType, - string columnName = null, IEnumerable annotationInfos = null, Delegate generator = null) + internal Column(string memberName, DataViewType columnType, + string columnName = null) { - ectx.CheckNonEmpty(memberName, nameof(memberName)); + Contracts.CheckNonEmpty(memberName, nameof(memberName)); MemberName = memberName; ColumnName = columnName ?? memberName; ColumnType = columnType; - Generator = generator; - _annotations = annotationInfos != null ? - annotationInfos.ToDictionary(m => m.Kind, m => m) - : new Dictionary(); - } - - public Column() - { - _annotations = _annotations ?? new Dictionary(); + AnnotationInfos = new Dictionary(); } /// @@ -262,22 +248,19 @@ public Column() /// The string identifier of the annotation. /// Value of annotation. /// Type of value. 
- public void AddAnnotation(string kind, T value, DataViewType annotationType = null) + public void AddAnnotation(string kind, T value, DataViewType annotationType) { - if (_annotations.ContainsKey(kind)) + Contracts.CheckValue(kind, nameof(kind)); + Contracts.CheckValue(annotationType, nameof(annotationType)); + + if (AnnotationInfos.ContainsKey(kind)) throw Contracts.Except("Column already contains an annotation of this kind."); - _annotations[kind] = new AnnotationInfo(kind, value, annotationType); + AnnotationInfos[kind] = new AnnotationInfo(kind, value, annotationType); } - /// - /// Remove annotation from the column if it exists. - /// - /// The string identifier of the annotation. - public void RemoveAnnotation(string kind) + internal void AddAnnotation(string kind, AnnotationInfo info) { - if (_annotations.ContainsKey(kind)) - _annotations.Remove(kind); - throw Contracts.Except("Column does not contain an annotation of kind: " + kind); + AnnotationInfos[kind] = info; } /// @@ -285,15 +268,22 @@ public void RemoveAnnotation(string kind) /// /// A dictionary with the kind of the annotation as the key, and the /// annotation type as the associated value. - public IEnumerable> GetAnnotationTypes + public DataViewSchema.Annotations Annotations { get { - return Annotations.Select(x => new KeyValuePair(x.Key, x.Value.AnnotationType)); + var builder = new DataViewSchema.Annotations.Builder(); + foreach (var kvp in AnnotationInfos) + builder.Add(kvp.Key, kvp.Value.AnnotationType, kvp.Value.GetGetterDelegate()); + return builder.ToAnnotations(); } } } + private SchemaDefinition() + { + } + /// /// Get or set the column definition by column name. 
/// If there's no such column: @@ -430,7 +420,7 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc else columnType = itemType; - cols.Add(new Column() { MemberName = memberInfo.Name, ColumnName = name, ColumnType = columnType }); + cols.Add(new Column(memberInfo.Name, columnType, name)); } return cols; } diff --git a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs index 1ca8392b36..127cc746ab 100644 --- a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs @@ -74,7 +74,7 @@ internal static SchemaDefinition GetSchemaDefinition(IHostEnvironment env, foreach (var annotation in annotations.Schema) { var info = Utils.MarshalInvoke(GetAnnotationInfo, annotation.Type.RawType, annotation.Name, annotations); - schemaDefinitionCol.Annotations.Add(annotation.Name, info); + schemaDefinitionCol.AddAnnotation(annotation.Name , info); } } } @@ -797,7 +797,7 @@ internal static DataViewSchema.DetachedColumn[] GetSchemaColumns(InternalSchemaD /// /// A single instance of annotation information, associated with a column. /// - public abstract partial class AnnotationInfo + internal abstract partial class AnnotationInfo { /// /// The type of the annotation. @@ -826,7 +826,7 @@ private protected AnnotationInfo(string kind, DataViewType annotationType) /// Strongly-typed version of , that contains the actual value of the annotation. /// /// Type of the annotation value. 
- public sealed class AnnotationInfo : AnnotationInfo + internal sealed class AnnotationInfo : AnnotationInfo { public readonly T Value; diff --git a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs index e906e091a3..d0b31e9243 100644 --- a/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs +++ b/src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs @@ -27,8 +27,7 @@ public class Column public readonly DataViewType ColumnType; public readonly bool IsComputed; public readonly Delegate Generator; - private readonly Dictionary _annotations; - public Dictionary Annotations { get { return _annotations; } } + public Dictionary Annotations { get; } public Type ComputedReturnType { get { return ReturnParameterInfo.ParameterType.GetElementType(); } } public Type FieldOrPropertyType => (MemberInfo is FieldInfo) ? (MemberInfo as FieldInfo).FieldType : (MemberInfo as PropertyInfo).PropertyType; public Type OutputType => IsComputed ? ComputedReturnType : FieldOrPropertyType; @@ -74,7 +73,7 @@ private Column(string columnName, DataViewType columnType, MemberInfo memberInfo ColumnType = columnType; IsComputed = generator != null; Generator = generator; - _annotations = metadataInfos == null ? new Dictionary() + Annotations = metadataInfos == null ? new Dictionary() : metadataInfos.ToDictionary(entry => entry.Key, entry => entry.Value); AssertRep(); @@ -218,7 +217,7 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us Type dataItemType; MemberInfo memberInfo = null; - if (!col.IsComputed) + if (col.Generator == null) { memberInfo = userType.GetField(col.MemberName); @@ -277,9 +276,9 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us colType = col.ColumnType; } - dstCols[i] = col.IsComputed ? 
- new Column(colName, colType, col.Generator, col.Annotations) - : new Column(colName, colType, memberInfo, col.Annotations); + dstCols[i] = col.Generator != null ? + new Column(colName, colType, col.Generator, col.AnnotationInfos) + : new Column(colName, colType, memberInfo, col.AnnotationInfos); } return new InternalSchemaDefinition(dstCols); diff --git a/test/Microsoft.ML.Functional.Tests/SchemaDefinitionTests.cs b/test/Microsoft.ML.Functional.Tests/SchemaDefinitionTests.cs new file mode 100644 index 0000000000..3a3c3dac86 --- /dev/null +++ b/test/Microsoft.ML.Functional.Tests/SchemaDefinitionTests.cs @@ -0,0 +1,114 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Linq; +using Microsoft.ML.Data; +using Microsoft.ML.RunTests; +using Microsoft.ML.TestFramework; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Functional.Tests +{ + public class SchemaDefinitionTests : BaseTestClass + { + private MLContext _ml; + + public SchemaDefinitionTests(ITestOutputHelper output) : base(output) + { + } + + protected override void Initialize() + { + base.Initialize(); + + _ml = new MLContext(42); + _ml.AddStandardComponents(); + } + + [Fact] + public void SchemaDefinitionForPredictionEngine() + { + var fileName = GetDataPath(TestDatasets.adult.trainFilename); + var loader = _ml.Data.CreateTextLoader(new TextLoader.Options(), new MultiFileSource(fileName)); + var data = loader.Load(new MultiFileSource(fileName)); + var pipeline1 = _ml.Transforms.Categorical.OneHotEncoding("Cat", "Workclass", maximumNumberOfKeys: 3) + .Append(_ml.Transforms.Concatenate("Features", "Cat", "NumericFeatures")); + var model1 = pipeline1.Fit(data); + + var pipeline2 = _ml.Transforms.Categorical.OneHotEncoding("Cat", "Workclass", maximumNumberOfKeys: 4) + 
.Append(_ml.Transforms.Concatenate("Features", "Cat", "NumericFeatures")); + var model2 = pipeline2.Fit(data); + + var outputSchemaDefinition = SchemaDefinition.Create(typeof(OutputData)); + outputSchemaDefinition["Features"].ColumnType = model1.GetOutputSchema(data.Schema)["Features"].Type; + var engine1 = _ml.Model.CreatePredictionEngine(model1, outputSchemaDefinition: outputSchemaDefinition); + + outputSchemaDefinition = SchemaDefinition.Create(typeof(OutputData)); + outputSchemaDefinition["Features"].ColumnType = model2.GetOutputSchema(data.Schema)["Features"].Type; + var engine2 = _ml.Model.CreatePredictionEngine(model2, outputSchemaDefinition: outputSchemaDefinition); + + var prediction = engine1.Predict(new InputData() { Workclass = "Self-emp-not-inc", NumericFeatures = new float[6] }); + Assert.Equal((engine1.OutputSchema["Features"].Type as VectorType).Size, prediction.Features.Length); + Assert.True(prediction.Features.All(x => x == 0)); + prediction = engine2.Predict(new InputData() { Workclass = "Self-emp-not-inc", NumericFeatures = new float[6] }); + Assert.Equal((engine2.OutputSchema["Features"].Type as VectorType).Size, prediction.Features.Length); + Assert.True(prediction.Features.Select((x, i) => i == 3 && x == 1 || x == 0).All(b => b)); + } + + [Fact] + public void SchemaDefinitionForCustomMapping() + { + var fileName = GetDataPath(TestDatasets.adult.trainFilename); + var data = new MultiFileSource(fileName); + var loader = _ml.Data.CreateTextLoader(new TextLoader.Options(), new MultiFileSource(fileName)); + var pipeline = _ml.Transforms.Categorical.OneHotEncoding("Categories") + .Append(_ml.Transforms.Categorical.OneHotEncoding("Workclass")) + .Append(_ml.Transforms.Concatenate("Features", "NumericFeatures", "Categories", "Workclass")) + .Append(_ml.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("Features")); + var model = pipeline.Fit(loader.Load(data)); + var schema = model.GetOutputSchema(loader.GetOutputSchema()); + + var 
inputSchemaDefinition = SchemaDefinition.Create(typeof(OutputData)); + inputSchemaDefinition["Features"].ColumnType = schema["Features"].Type; + var outputSchemaDefinition = SchemaDefinition.Create(typeof(OutputData)); + outputSchemaDefinition["Features"].ColumnType = new VectorType(NumberDataViewType.Single, (schema["Features"].Type as VectorType).Size * 2); + + var custom = _ml.Transforms.CustomMapping( + (OutputData src, OutputData dst) => + { + dst.Features = new float[src.Features.Length * 2]; + for (int i = 0; i < src.Features.Length; i++) + { + dst.Features[2 * i] = src.Features[i]; + dst.Features[2 * i + 1] = (float)Math.Log(src.Features[i]); + } + }, null, inputSchemaDefinition, outputSchemaDefinition); + + model = model.Append(custom.Fit(model.Transform(loader.Load(data))) as ITransformer); + schema = model.GetOutputSchema(loader.GetOutputSchema()); + Assert.Equal(168, (schema["Features"].Type as VectorType).Size); + } + + private sealed class InputData + { + [LoadColumn(0)] + public float Label { get; set; } + [LoadColumn(1)] + public string Workclass { get; set; } + [LoadColumn(2, 8)] + public string[] Categories { get; set; } + [LoadColumn(9, 14)] + [VectorType(6)] + public float[] NumericFeatures { get; set; } + } + + private sealed class OutputData + { + public float Label { get; set; } + public float[] Features { get; set; } + } + } +} diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs index b10852a1c5..cd3221732f 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs @@ -200,32 +200,32 @@ public void MetadataSupportInDataViewConstruction() // Create Metadata. 
var kindFloat = "Testing float as metadata."; - var valueFloat = 10; + float valueFloat = 10; var coltypeFloat = NumberDataViewType.Single; var kindString = "Testing string as metadata."; var valueString = "Strings have value."; + var coltypeString = TextDataViewType.Instance; var kindStringArray = "Testing string array as metadata."; var valueStringArray = "I really have no idea what these features entail.".Split(' '); + var coltypeStringArray = new VectorType(coltypeString, valueStringArray.Length); var kindFloatArray = "Testing float array as metadata."; var valueFloatArray = new float[] { 1, 17, 7, 19, 25, 0 }; + var coltypeFloatArray = new VectorType(coltypeFloat, valueFloatArray.Length); var kindVBuffer = "Testing VBuffer as metadata."; var valueVBuffer = new VBuffer(4, new float[] { 4, 6, 89, 5 }); - - var metaFloat = new AnnotationInfo(kindFloat, valueFloat, coltypeFloat); - var metaString = new AnnotationInfo(kindString, valueString); + var coltypeVBuffer = new VectorType(coltypeFloat, valueVBuffer.Length); // Add Metadata. 
var labelColumn = autoSchema[0]; - var labelColumnWithMetadata = new SchemaDefinition.Column(mlContext, labelColumn.MemberName, labelColumn.ColumnType, - annotationInfos: new AnnotationInfo[] { metaFloat, metaString }); + labelColumn.AddAnnotation(kindFloat, valueFloat, coltypeFloat); + labelColumn.AddAnnotation(kindString, valueString, coltypeString); - var featureColumnWithMetadata = autoSchema[1]; - featureColumnWithMetadata.AddAnnotation(kindStringArray, valueStringArray); - featureColumnWithMetadata.AddAnnotation(kindFloatArray, valueFloatArray); - featureColumnWithMetadata.AddAnnotation(kindVBuffer, valueVBuffer); + var featureColumn = autoSchema[1]; + featureColumn.AddAnnotation(kindStringArray, valueStringArray, coltypeStringArray); + featureColumn.AddAnnotation(kindFloatArray, valueFloatArray, coltypeFloatArray); + featureColumn.AddAnnotation(kindVBuffer, valueVBuffer, coltypeVBuffer); - var mySchema = new SchemaDefinition { labelColumnWithMetadata, featureColumnWithMetadata }; - var idv = mlContext.Data.LoadFromEnumerable(data, mySchema); + var idv = mlContext.Data.LoadFromEnumerable(data, autoSchema); Assert.True(idv.Schema[0].Annotations.Schema.Count == 2); Assert.True(idv.Schema[0].Annotations.Schema[0].Name == kindFloat); From c8a4c7dec32a294fef99425364874c9f16a7f559 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Wed, 20 Mar 2019 10:21:11 -0700 Subject: [PATCH 12/18] Data catalog done (#3021) * adding XML to a public AP that had no documentation. * adding a traintest split sample. Small corrections to the images doc.xml. 
--- .../Dynamic/DataOperations/BootstrapSample.cs | 2 +- .../Dynamic/DataOperations/TrainTestSplit.cs | 107 ++++++++++++++++++ .../DataLoadSave/DataOperationsCatalog.cs | 30 ++++- src/Microsoft.ML.ImageAnalytics/doc.xml | 35 +++--- 4 files changed, 154 insertions(+), 20 deletions(-) create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/TrainTestSplit.cs diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/BootstrapSample.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/BootstrapSample.cs index f765cd7729..bf56a41dbf 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/BootstrapSample.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/BootstrapSample.cs @@ -3,7 +3,7 @@ namespace Microsoft.ML.Samples.Dynamic { - public static class Bootstrap + public static class BootstrapSample { public static void Example() { diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/TrainTestSplit.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/TrainTestSplit.cs new file mode 100644 index 0000000000..eb3bdd5a5a --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/TrainTestSplit.cs @@ -0,0 +1,107 @@ +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using Microsoft.ML.Data; +using static Microsoft.ML.DataOperationsCatalog; + +namespace Microsoft.ML.Samples.Dynamic +{ + /// + /// Sample class showing how to use TrainTestSplit. + /// + public static class TrainTestSplit + { + public static void Example() + { + // Creating the ML.Net IHostEnvironment object, needed for the pipeline. + var mlContext = new MLContext(); + + // Generate some data points. + var examples = GenerateRandomDataPoints(10); + + // Convert the examples list to an IDataView object, which is consumable by ML.NET API. 
+ var dataview = mlContext.Data.LoadFromEnumerable(examples); + + // Leave out 10% of the dataset for testing.For some types of problems, for example for ranking or anomaly detection, + // we must ensure that the split leaves the rows with the same value in a particular column, in one of the splits. + // So below, we specify Group column as the column containing the sampling keys. + // Notice how keeping the rows with the same value in the Group column overrides the testFraction definition. + TrainTestData split = mlContext.Data.TrainTestSplit(dataview, testFraction: 0.1, samplingKeyColumnName: "Group"); + + PrintPreviewRows(split); + + // The data in the Train split. + // [Group, 1], [Features, 0.8173254] + // [Group, 1], [Features, 0.5581612] + // [Group, 1], [Features, 0.5588848] + // [Group, 1], [Features, 0.4421779] + // [Group, 1], [Features, 0.2737045] + + // The data in the Test split. + // [Group, 0], [Features, 0.7262433] + // [Group, 0], [Features, 0.7680227] + // [Group, 0], [Features, 0.2060332] + // [Group, 0], [Features, 0.9060271] + // [Group, 0], [Features, 0.9775497] + + // Example of a split without specifying a sampling key column. + split = mlContext.Data.TrainTestSplit(dataview, testFraction: 0.2); + PrintPreviewRows(split); + + // The data in the Train split. + // [Group, 0], [Features, 0.7262433] + // [Group, 1], [Features, 0.8173254] + // [Group, 0], [Features, 0.7680227] + // [Group, 1], [Features, 0.5581612] + // [Group, 0], [Features, 0.2060332] + // [Group, 1], [Features, 0.4421779] + // [Group, 0], [Features, 0.9775497] + // [Group, 1], [Features, 0.2737045] + + // The data in the Test split. 
+ // [Group, 1], [Features, 0.5588848] + // [Group, 0], [Features, 0.9060271] + + } + + private static IEnumerable GenerateRandomDataPoints(int count, int seed = 0) + { + var random = new Random(seed); + for (int i = 0; i < count; i++) + { + yield return new DataPoint + { + Group = i % 2, + + // Create random features that are correlated with label. + Features = (float)random.NextDouble() + }; + } + } + + // Example with label and group column. A data set is a collection of such examples. + private class DataPoint + { + public float Group { get; set; } + + public float Features { get; set; } + } + + // print helper + private static void PrintPreviewRows(TrainTestData split) + { + + var trainDataPreview = split.TrainSet.Preview(); + var testDataPreview = split.TestSet.Preview(); + + Console.WriteLine($"The data in the Train split."); + foreach (var row in trainDataPreview.RowView) + Console.WriteLine($"{row.Values[0]}, {row.Values[1]}"); + + Console.WriteLine($"\nThe data in the Test split."); + foreach (var row in testDataPreview.RowView) + Console.WriteLine($"{row.Values[0]}, {row.Values[1]}"); + } + } +} diff --git a/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs b/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs index 10f252b6c1..e889e07548 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs @@ -70,7 +70,7 @@ internal DataOperationsCatalog(IHostEnvironment env) /// /// /// /// /// @@ -82,6 +82,25 @@ public IDataView LoadFromEnumerable(IEnumerable data, SchemaDefiniti return DataViewConstructionUtils.CreateFromEnumerable(_env, data, schemaDefinition); } + /// + /// Create a new over an enumerable of the items of user-defined type, and the provided + /// which might contain more information about the schema than the type can capture. 
+ /// + /// + /// The user maintains ownership of the and the resulting data view will + /// never alter the contents of the . + /// Since is assumed to be immutable, the user is expected to support + /// multiple enumeration of the that would return the same results, unless + /// the user knows that the data will only be cursored once. + /// One typical usage for streaming data view could be: create the data view that lazily loads data + /// as needed, then apply pre-trained transformations to it and cursor through it for transformation + /// results. + /// One practical usage of this would be to supply the feature column names through the . + /// + /// The to convert to an . + /// The data with to convert to an . + /// The schema of the returned . + /// An with the given . public IDataView LoadFromEnumerable(IEnumerable data, DataViewSchema schema) where TRow : class { @@ -102,7 +121,7 @@ public IDataView LoadFromEnumerable(IEnumerable data, DataViewSchema /// /// /// /// /// @@ -381,6 +400,13 @@ public IDataView TakeRows(IDataView input, long count) /// they are guaranteed to appear in the same subset (train or test). This can be used to ensure no label leakage from the train to the test set. /// If no row grouping will be performed. /// Seed for the random number generator used to select rows for the train-test split. + /// + /// + /// + /// + /// public TrainTestData TrainTestSplit(IDataView data, double testFraction = 0.1, string samplingKeyColumnName = null, int? seed = null) { _env.CheckValue(data, nameof(data)); diff --git a/src/Microsoft.ML.ImageAnalytics/doc.xml b/src/Microsoft.ML.ImageAnalytics/doc.xml index 6254c64363..9efebe2025 100644 --- a/src/Microsoft.ML.ImageAnalytics/doc.xml +++ b/src/Microsoft.ML.ImageAnalytics/doc.xml @@ -8,16 +8,15 @@ - . 
- - For end-to-end image processing pipelines, and scenarios in your applications, see the - [examples in the machinelearning-samples github repository](https://github.com/dotnet/machinelearning-samples/tree/master/samples/csharp/getting-started). + @@ -31,10 +30,11 @@ The ImagePixelExtractingEstimator extracts the pixels from the input images and, converts them into a vector of numbers. This can be further used as feature by the algorithms added to the pipeline. - - ImagePixelExtractingEstimator expects a in the pipeline, before it is used. - For end-to-end image processing pipelines, and scenarios in your applications, see the - examples in the machinelearning-samples github repository. + + ImagePixelExtractingEstimator expects a in the pipeline, before it is used. + For end-to-end image processing pipelines, and scenarios in your applications, see the + examples in the machinelearning-samples github repository. + @@ -50,9 +50,10 @@ extract features for usage in the machine learning algorithms. Those pre-trained models have a defined width and height for their input images, so often, after getting loaded, the images will need to get resized before further processing. - - For end-to-end image processing pipelines, and scenarios in your applications, see the - examples in the machinelearning-samples github repository. + + For end-to-end image processing pipelines, and scenarios in your applications, see the + examples in the machinelearning-samples github repository. 
+ From ce564622db503d4a75e790543f8ffb997e088413 Mon Sep 17 00:00:00 2001 From: jignparm Date: Wed, 20 Mar 2019 10:37:49 -0700 Subject: [PATCH 13/18] Activate OnnxTransform unit tests for MacOS (#2695) * Update tests to run on mac ci leg * Updated OnnxRuntime version to 0.3.0 * Update OnnxTheory Attribute --- build/Dependencies.props | 2 +- test/Microsoft.ML.TestFramework/Attributes/OnnxFactAttribute.cs | 2 +- .../Attributes/OnnxTheoryAttribute.cs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/build/Dependencies.props b/build/Dependencies.props index a748124053..2e72ba1923 100644 --- a/build/Dependencies.props +++ b/build/Dependencies.props @@ -14,7 +14,7 @@ 3.5.1 2.2.3 - 0.2.1 + 0.3.0 0.0.0.9 2.1.3 4.5.0 diff --git a/test/Microsoft.ML.TestFramework/Attributes/OnnxFactAttribute.cs b/test/Microsoft.ML.TestFramework/Attributes/OnnxFactAttribute.cs index b6b5502fdc..150c74decf 100644 --- a/test/Microsoft.ML.TestFramework/Attributes/OnnxFactAttribute.cs +++ b/test/Microsoft.ML.TestFramework/Attributes/OnnxFactAttribute.cs @@ -19,7 +19,7 @@ public OnnxFactAttribute() : base("Onnx is 64-bit Windows only") /// protected override bool IsEnvironmentSupported() { - return Environment.Is64BitProcess && (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) || AttributeHelpers.CheckLibcVersionGreaterThanMinimum(new Version(2, 23))); + return Environment.Is64BitProcess && (!RuntimeInformation.IsOSPlatform(OSPlatform.Linux) || AttributeHelpers.CheckLibcVersionGreaterThanMinimum(new Version(2, 23))); } } } \ No newline at end of file diff --git a/test/Microsoft.ML.TestFramework/Attributes/OnnxTheoryAttribute.cs b/test/Microsoft.ML.TestFramework/Attributes/OnnxTheoryAttribute.cs index d3be5d3199..0bc1606be4 100644 --- a/test/Microsoft.ML.TestFramework/Attributes/OnnxTheoryAttribute.cs +++ b/test/Microsoft.ML.TestFramework/Attributes/OnnxTheoryAttribute.cs @@ -19,7 +19,7 @@ public OnnxTheoryAttribute() : base("Onnx is 64-bit Windows only") /// protected override 
bool IsEnvironmentSupported() { - return Environment.Is64BitProcess && (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) || AttributeHelpers.CheckLibcVersionGreaterThanMinimum(new Version(2, 23))); + return Environment.Is64BitProcess && (!RuntimeInformation.IsOSPlatform(OSPlatform.Linux) || AttributeHelpers.CheckLibcVersionGreaterThanMinimum(new Version(2, 23))); } } } \ No newline at end of file From e00d19dcf8ebad884f2fe527d85edc3a7e319b80 Mon Sep 17 00:00:00 2001 From: Zeeshan Ahmed <38438266+zeahmed@users.noreply.github.com> Date: Wed, 20 Mar 2019 10:46:00 -0700 Subject: [PATCH 14/18] Added tests for text featurizer options (Part1). (#3006) --- .../Text/TextFeaturizingEstimator.cs | 2 +- .../Transformers/TextFeaturizerTests.cs | 214 +++++++++++++++++- 2 files changed, 214 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs b/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs index e9884f912a..2d65de8ef8 100644 --- a/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs +++ b/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs @@ -393,7 +393,7 @@ internal TextFeaturizingEstimator(IHostEnvironment env, string name, IEnumerable if (options != null) OptionalSettings = options; - _stopWordsRemover = null; + _stopWordsRemover = OptionalSettings.StopWordsRemover; _dictionary = null; _wordFeatureExtractor = OptionalSettings.WordFeatureExtractorFactory; _charFeatureExtractor = OptionalSettings.CharFeatureExtractorFactory; diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs index 426dba07cb..e0f32bcaa1 100644 --- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs @@ -1,9 +1,10 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. 
// The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. using System; using System.IO; +using System.Text.RegularExpressions; using Microsoft.ML; using Microsoft.ML.Data; using Microsoft.ML.Data.IO; @@ -26,6 +27,217 @@ public TextFeaturizerTests(ITestOutputHelper helper) { } + private class TestClass + { + public string A; + public string[] OutputTokens; + } + + [Fact] + public void TextFeaturizerWithPredefinedStopWordRemoverTest() + { + var data = new[] { new TestClass() { A = "This is some text with english stop words", OutputTokens=null}, + new TestClass() { A = "No stop words", OutputTokens=null } }; + var dataView = ML.Data.LoadFromEnumerable(data); + + var options = new TextFeaturizingEstimator.Options() { StopWordsRemoverOptions = new StopWordsRemovingEstimator.Options(), OutputTokensColumnName = "OutputTokens" }; + var pipeline = ML.Transforms.Text.FeaturizeText("OutputText", options, "A"); + var model = pipeline.Fit(dataView); + var engine = model.CreatePredictionEngine(ML); + var prediction = engine.Predict(data[0]); + Assert.Equal("text english stop words", string.Join(" ", prediction.OutputTokens)); + + prediction = engine.Predict(data[1]); + Assert.Equal("stop words", string.Join(" ", prediction.OutputTokens)); + } + + [Fact] + public void TextFeaturizerWithCustomStopWordRemoverTest() + { + var data = new[] { new TestClass() { A = "This is some text with english stop words", OutputTokens=null}, + new TestClass() { A = "No stop words", OutputTokens=null } }; + var dataView = ML.Data.LoadFromEnumerable(data); + + var options = new TextFeaturizingEstimator.Options() + { + StopWordsRemoverOptions = new CustomStopWordsRemovingEstimator.Options() + { + StopWords = new[] { "stop", "words" } + }, + OutputTokensColumnName = "OutputTokens", + CaseMode = TextNormalizingEstimator.CaseMode.None + }; + var pipeline = ML.Transforms.Text.FeaturizeText("OutputText", options, "A"); + var 
model = pipeline.Fit(dataView); + var engine = model.CreatePredictionEngine(ML); + var prediction = engine.Predict(data[0]); + Assert.Equal("This is some text with english", string.Join(" ", prediction.OutputTokens)); + + prediction = engine.Predict(data[1]); + Assert.Equal("No", string.Join(" ", prediction.OutputTokens)); + } + + private void TestCaseMode(IDataView dataView, TestClass[] data, TextNormalizingEstimator.CaseMode caseMode) + { + var options = new TextFeaturizingEstimator.Options() + { + CaseMode = caseMode, + OutputTokensColumnName = "OutputTokens" + }; + var pipeline = ML.Transforms.Text.FeaturizeText("OutputText", options, "A"); + var model = pipeline.Fit(dataView); + var engine = model.CreatePredictionEngine(ML); + var prediction1 = engine.Predict(data[0]); + var prediction2 = engine.Predict(data[1]); + + string expected1 = null; + string expected2 = null; + if (caseMode == TextNormalizingEstimator.CaseMode.Upper) + { + expected1 = data[0].A.ToUpper(); + expected2 = data[1].A.ToUpper(); + } + else if (caseMode == TextNormalizingEstimator.CaseMode.Lower) + { + expected1 = data[0].A.ToLower(); + expected2 = data[1].A.ToLower(); + } + else if (caseMode == TextNormalizingEstimator.CaseMode.None) + { + expected1 = data[0].A; + expected2 = data[1].A; + } + + Assert.Equal(expected1, string.Join(" ", prediction1.OutputTokens)); + Assert.Equal(expected2, string.Join(" ", prediction2.OutputTokens)); + } + + [Fact] + public void TextFeaturizerWithUpperCaseTest() + { + var data = new[] { new TestClass() { A = "This is some text with english stop words", OutputTokens=null}, + new TestClass() { A = "No stop words", OutputTokens=null } }; + var dataView = ML.Data.LoadFromEnumerable(data); + + TestCaseMode(dataView, data, TextNormalizingEstimator.CaseMode.Lower); + TestCaseMode(dataView, data, TextNormalizingEstimator.CaseMode.Upper); + TestCaseMode(dataView, data, TextNormalizingEstimator.CaseMode.None); + } + + + private void TestKeepNumbers(IDataView dataView, 
TestClass[] data, bool keepNumbers) + { + var options = new TextFeaturizingEstimator.Options() + { + KeepNumbers = keepNumbers, + CaseMode = TextNormalizingEstimator.CaseMode.None, + OutputTokensColumnName = "OutputTokens" + }; + var pipeline = ML.Transforms.Text.FeaturizeText("OutputText", options, "A"); + var model = pipeline.Fit(dataView); + var engine = model.CreatePredictionEngine(ML); + var prediction1 = engine.Predict(data[0]); + var prediction2 = engine.Predict(data[1]); + + if (keepNumbers) + { + Assert.Equal(data[0].A, string.Join(" ", prediction1.OutputTokens)); + Assert.Equal(data[1].A, string.Join(" ", prediction2.OutputTokens)); + } + else + { + Assert.Equal(data[0].A.Replace("123 ", "").Replace("425", "").Replace("25", "").Replace("23", ""), string.Join(" ", prediction1.OutputTokens)); + Assert.Equal(data[1].A, string.Join(" ", prediction2.OutputTokens)); + } + } + + [Fact] + public void TextFeaturizerWithKeepNumbersTest() + { + var data = new[] { new TestClass() { A = "This is some text with numbers 123 $425 25.23", OutputTokens=null}, + new TestClass() { A = "No numbers", OutputTokens=null } }; + var dataView = ML.Data.LoadFromEnumerable(data); + + TestKeepNumbers(dataView, data, true); + TestKeepNumbers(dataView, data, false); + } + + private void TestKeepPunctuations(IDataView dataView, TestClass[] data, bool keepPunctuations) + { + var options = new TextFeaturizingEstimator.Options() + { + KeepPunctuations = keepPunctuations, + CaseMode = TextNormalizingEstimator.CaseMode.None, + OutputTokensColumnName = "OutputTokens" + }; + var pipeline = ML.Transforms.Text.FeaturizeText("OutputText", options, "A"); + var model = pipeline.Fit(dataView); + var engine = model.CreatePredictionEngine(ML); + var prediction1 = engine.Predict(data[0]); + var prediction2 = engine.Predict(data[1]); + + if (keepPunctuations) + { + Assert.Equal(data[0].A, string.Join(" ", prediction1.OutputTokens)); + Assert.Equal(data[1].A, string.Join(" ", prediction2.OutputTokens)); + 
} + else + { + var expected = Regex.Replace(data[0].A, "[,|_|'|\"|;|\\.]", ""); + Assert.Equal(expected, string.Join(" ", prediction1.OutputTokens)); + Assert.Equal(data[1].A, string.Join(" ", prediction2.OutputTokens)); + } + } + + [Fact] + public void TextFeaturizerWithKeepPunctuationsTest() + { + var data = new[] { new TestClass() { A = "This, is; some_ ,text 'with\" punctuations.", OutputTokens=null}, + new TestClass() { A = "No punctuations", OutputTokens=null } }; + var dataView = ML.Data.LoadFromEnumerable(data); + + TestKeepPunctuations(dataView, data, true); + TestKeepPunctuations(dataView, data, false); + } + + private void TestKeepDiacritics(IDataView dataView, TestClass[] data, bool keepDiacritics) + { + var options = new TextFeaturizingEstimator.Options() + { + KeepDiacritics = keepDiacritics, + CaseMode = TextNormalizingEstimator.CaseMode.None, + OutputTokensColumnName = "OutputTokens" + }; + var pipeline = ML.Transforms.Text.FeaturizeText("OutputText", options, "A"); + var model = pipeline.Fit(dataView); + var engine = model.CreatePredictionEngine(ML); + var prediction1 = engine.Predict(data[0]); + var prediction2 = engine.Predict(data[1]); + + if (keepDiacritics) + { + Assert.Equal(data[0].A, string.Join(" ", prediction1.OutputTokens)); + Assert.Equal(data[1].A, string.Join(" ", prediction2.OutputTokens)); + } + else + { + Assert.Equal("This is some text with diacritics", string.Join(" ", prediction1.OutputTokens)); + Assert.Equal(data[1].A, string.Join(" ", prediction2.OutputTokens)); + } + } + + [Fact] + public void TextFeaturizerWithKeepDiacriticsTest() + { + var data = new[] { new TestClass() { A = "Thîs îs sóme text with diácrîtîcs", OutputTokens=null}, + new TestClass() { A = "No diacritics", OutputTokens=null } }; + var dataView = ML.Data.LoadFromEnumerable(data); + + TestKeepDiacritics(dataView, data, true); + TestKeepDiacritics(dataView, data, false); + } + + [Fact] public void TextFeaturizerWorkout() { From 
a2d7987e8d334b6137c494d6407b782d79f4d3c0 Mon Sep 17 00:00:00 2001 From: Shahab Moradi Date: Wed, 20 Mar 2019 15:07:15 -0400 Subject: [PATCH 15/18] Binary FastTree/Forest samples using T4 templates. (#3035) * Binary FastTree/Forest samples using T4 templates. * Addressed comments. --- .../BinaryClassification/FastForest.cs | 99 +++++++++++++++ .../BinaryClassification/FastForest.tt | 24 ++++ .../FastForestWithOptions.cs | 112 +++++++++++++++++ .../FastForestWithOptions.tt | 33 +++++ .../Trainers/BinaryClassification/FastTree.cs | 103 ++++++++++++++++ .../Trainers/BinaryClassification/FastTree.tt | 27 ++++ .../FastTreeWithOptions.cs | 115 ++++++++++++++++++ .../FastTreeWithOptions.tt | 36 ++++++ .../TreeSamplesTemplate.ttinclude | 99 +++++++++++++++ .../Microsoft.ML.Samples.csproj | 66 ++++++++++ .../TreeTrainersCatalog.cs | 29 +++++ 11 files changed, 743 insertions(+) create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForest.cs create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForest.tt create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForestWithOptions.cs create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForestWithOptions.tt create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastTree.cs create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastTree.tt create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastTreeWithOptions.cs create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastTreeWithOptions.tt create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/TreeSamplesTemplate.ttinclude diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForest.cs 
b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForest.cs new file mode 100644 index 0000000000..8432d2b0d3 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForest.cs @@ -0,0 +1,99 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification +{ + public static class FastForest + { + // This example requires installation of additional NuGet package + // Microsoft.ML.FastTree. + public static void Example() + { + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + // Setting the seed to a fixed number in this example to make outputs deterministic. + var mlContext = new MLContext(seed: 0); + + // Create a list of training examples. + var examples = GenerateRandomDataPoints(1000); + + // Convert the examples list to an IDataView object, which is consumable by ML.NET API. + var trainingData = mlContext.Data.LoadFromEnumerable(examples); + + // Define the trainer. + var pipeline = mlContext.BinaryClassification.Trainers.FastForest(); + + // Train the model. + var model = pipeline.Fit(trainingData); + + // Create testing examples. Use different random seed to make it different from training data. + var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed:123)); + + // Run the model on test data set. + var transformedTestData = model.Transform(testData); + + // Convert IDataView object to a list. 
+ var predictions = mlContext.Data.CreateEnumerable(transformedTestData, reuseRowObject: false).ToList(); + + // Look at 5 predictions + foreach (var p in predictions.Take(5)) + Console.WriteLine($"Label: {p.Label}, Prediction: {p.PredictedLabel}"); + + // Expected output: + // Label: True, Prediction: True + // Label: False, Prediction: False + // Label: True, Prediction: True + // Label: True, Prediction: True + // Label: False, Prediction: False + + // Evaluate the overall metrics + var metrics = mlContext.BinaryClassification.EvaluateNonCalibrated(transformedTestData); + SamplesUtils.ConsoleUtils.PrintMetrics(metrics); + + // Expected output: + // Accuracy: 0.74 + // AUC: 0.83 + // F1 Score: 0.74 + // Negative Precision: 0.78 + // Negative Recall: 0.71 + // Positive Precision: 0.71 + // Positive Recall: 0.78 + } + + private static IEnumerable GenerateRandomDataPoints(int count, int seed=0) + { + var random = new Random(seed); + float randomFloat() => (float)random.NextDouble(); + for (int i = 0; i < count; i++) + { + var label = randomFloat() > 0.5f; + yield return new DataPoint + { + Label = label, + // Create random features that are correlated with label. + Features = Enumerable.Repeat(label, 50).Select(x => x ? randomFloat() : randomFloat() + 0.03f).ToArray() + }; + } + } + + // Example with label and 50 feature values. A data set is a collection of such examples. + private class DataPoint + { + public bool Label { get; set; } + [VectorType(50)] + public float[] Features { get; set; } + } + + // Class used to capture predictions. + private class Prediction + { + // Original label. + public bool Label { get; set; } + // Predicted label from the trainer. 
+ public bool PredictedLabel { get; set; } + } + } +} + diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForest.tt b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForest.tt new file mode 100644 index 0000000000..9648bd8410 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForest.tt @@ -0,0 +1,24 @@ +<#@ include file="TreeSamplesTemplate.ttinclude"#> + +<#+ +string ClassName="FastForest"; +string Trainer = "FastForest"; +string TrainerOptions = null; +bool IsCalibrated = false; + +string ExpectedOutputPerInstance= @"// Expected output: + // Label: True, Prediction: True + // Label: False, Prediction: False + // Label: True, Prediction: True + // Label: True, Prediction: True + // Label: False, Prediction: False"; + +string ExpectedOutput = @"// Expected output: + // Accuracy: 0.74 + // AUC: 0.83 + // F1 Score: 0.74 + // Negative Precision: 0.78 + // Negative Recall: 0.71 + // Positive Precision: 0.71 + // Positive Recall: 0.78"; +#> \ No newline at end of file diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForestWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForestWithOptions.cs new file mode 100644 index 0000000000..d243c54c69 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForestWithOptions.cs @@ -0,0 +1,112 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; +using Microsoft.ML.Trainers.FastTree; + +namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification +{ + public static class FastForestWithOptions + { + // This example requires installation of additional NuGet package + // Microsoft.ML.FastTree. + public static void Example() + { + // Create a new context for ML.NET operations. 
It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + // Setting the seed to a fixed number in this example to make outputs deterministic. + var mlContext = new MLContext(seed: 0); + + // Create a list of training data points. + var dataPoints = GenerateRandomDataPoints(1000); + + // Convert the list of data points to an IDataView object, which is consumable by ML.NET API. + var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints); + + // Define trainer options. + var options = new FastForestBinaryTrainer.Options + { + // Only use 80% of features to reduce over-fitting. + FeatureFraction = 0.8, + // Create a simpler model by penalizing usage of new features. + FeatureFirstUsePenalty = 0.1, + // Reduce the number of trees to 50. + NumberOfTrees = 50 + }; + + // Define the trainer. + var pipeline = mlContext.BinaryClassification.Trainers.FastForest(options); + + // Train the model. + var model = pipeline.Fit(trainingData); + + // Create testing data. Use different random seed to make it different from training data. + var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed:123)); + + // Run the model on test data set. + var transformedTestData = model.Transform(testData); + + // Convert IDataView object to a list. 
+ var predictions = mlContext.Data.CreateEnumerable(transformedTestData, reuseRowObject: false).ToList(); + + // Look at 5 predictions + foreach (var p in predictions.Take(5)) + Console.WriteLine($"Label: {p.Label}, Prediction: {p.PredictedLabel}"); + + // Expected output: + // Label: True, Prediction: True + // Label: False, Prediction: False + // Label: True, Prediction: True + // Label: True, Prediction: True + // Label: False, Prediction: True + + // Evaluate the overall metrics + var metrics = mlContext.BinaryClassification.EvaluateNonCalibrated(transformedTestData); + SamplesUtils.ConsoleUtils.PrintMetrics(metrics); + + // Expected output: + // Accuracy: 0.73 + // AUC: 0.81 + // F1 Score: 0.73 + // Negative Precision: 0.77 + // Negative Recall: 0.68 + // Positive Precision: 0.69 + // Positive Recall: 0.78 + } + + private static IEnumerable GenerateRandomDataPoints(int count, int seed=0) + { + var random = new Random(seed); + float randomFloat() => (float)random.NextDouble(); + for (int i = 0; i < count; i++) + { + var label = randomFloat() > 0.5f; + yield return new DataPoint + { + Label = label, + // Create random features that are correlated with the label. + // For data points with false label, the feature values are slightly increased by adding a constant. + Features = Enumerable.Repeat(label, 50).Select(x => x ? randomFloat() : randomFloat() + 0.03f).ToArray() + }; + } + } + + // Example with label and 50 feature values. A data set is a collection of such examples. + private class DataPoint + { + public bool Label { get; set; } + [VectorType(50)] + public float[] Features { get; set; } + } + + // Class used to capture predictions. + private class Prediction + { + // Original label. + public bool Label { get; set; } + // Predicted label from the trainer. 
+ public bool PredictedLabel { get; set; } + } + } +} + diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForestWithOptions.tt b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForestWithOptions.tt new file mode 100644 index 0000000000..14bc406273 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForestWithOptions.tt @@ -0,0 +1,33 @@ +<#@ include file="TreeSamplesTemplate.ttinclude"#> + +<#+ +string ClassName="FastForestWithOptions"; +string Trainer = "FastForest"; +bool IsCalibrated = false; + +string TrainerOptions = @"FastForestBinaryTrainer.Options + { + // Only use 80% of features to reduce over-fitting. + FeatureFraction = 0.8, + // Create a simpler model by penalizing usage of new features. + FeatureFirstUsePenalty = 0.1, + // Reduce the number of trees to 50. + NumberOfTrees = 50 + }"; + +string ExpectedOutputPerInstance= @"// Expected output: + // Label: True, Prediction: True + // Label: False, Prediction: False + // Label: True, Prediction: True + // Label: True, Prediction: True + // Label: False, Prediction: True"; + +string ExpectedOutput = @"// Expected output: + // Accuracy: 0.73 + // AUC: 0.81 + // F1 Score: 0.73 + // Negative Precision: 0.77 + // Negative Recall: 0.68 + // Positive Precision: 0.69 + // Positive Recall: 0.78"; +#> \ No newline at end of file diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastTree.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastTree.cs new file mode 100644 index 0000000000..eae52f11ab --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastTree.cs @@ -0,0 +1,103 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification +{ + public static class FastTree + { + // This example requires 
installation of additional NuGet package + // Microsoft.ML.FastTree. + public static void Example() + { + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + // Setting the seed to a fixed number in this example to make outputs deterministic. + var mlContext = new MLContext(seed: 0); + + // Create a list of training data points. + var dataPoints = GenerateRandomDataPoints(1000); + + // Convert the list of data points to an IDataView object, which is consumable by ML.NET API. + var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints); + + // Define the trainer. + var pipeline = mlContext.BinaryClassification.Trainers.FastTree(); + + // Train the model. + var model = pipeline.Fit(trainingData); + + // Create testing data. Use different random seed to make it different from training data. + var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed:123)); + + // Run the model on test data set. + var transformedTestData = model.Transform(testData); + + // Convert IDataView object to a list. 
+ var predictions = mlContext.Data.CreateEnumerable(transformedTestData, reuseRowObject: false).ToList(); + + // Look at 5 predictions + foreach (var p in predictions.Take(5)) + Console.WriteLine($"Label: {p.Label}, Prediction: {p.PredictedLabel}"); + + // Expected output: + // Label: True, Prediction: True + // Label: False, Prediction: False + // Label: True, Prediction: True + // Label: True, Prediction: True + // Label: False, Prediction: False + + // Evaluate the overall metrics + var metrics = mlContext.BinaryClassification.Evaluate(transformedTestData); + SamplesUtils.ConsoleUtils.PrintMetrics(metrics); + + // Expected output: + // Accuracy: 0.81 + // AUC: 0.91 + // F1 Score: 0.80 + // Negative Precision: 0.82 + // Negative Recall: 0.80 + // Positive Precision: 0.79 + // Positive Recall: 0.81 + // Log Loss: 0.59 + // Log Loss Reduction: 41.04 + // Entropy: 1.00 + } + + private static IEnumerable GenerateRandomDataPoints(int count, int seed=0) + { + var random = new Random(seed); + float randomFloat() => (float)random.NextDouble(); + for (int i = 0; i < count; i++) + { + var label = randomFloat() > 0.5f; + yield return new DataPoint + { + Label = label, + // Create random features that are correlated with the label. + // For data points with false label, the feature values are slightly increased by adding a constant. + Features = Enumerable.Repeat(label, 50).Select(x => x ? randomFloat() : randomFloat() + 0.03f).ToArray() + }; + } + } + + // Example with label and 50 feature values. A data set is a collection of such examples. + private class DataPoint + { + public bool Label { get; set; } + [VectorType(50)] + public float[] Features { get; set; } + } + + // Class used to capture predictions. + private class Prediction + { + // Original label. + public bool Label { get; set; } + // Predicted label from the trainer. 
+ public bool PredictedLabel { get; set; } + } + } +} + diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastTree.tt b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastTree.tt new file mode 100644 index 0000000000..58f233a04a --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastTree.tt @@ -0,0 +1,27 @@ +<#@ include file="TreeSamplesTemplate.ttinclude"#> + +<#+ +string ClassName="FastTree"; +string Trainer = "FastTree"; +string TrainerOptions = null; +bool IsCalibrated = true; + +string ExpectedOutputPerInstance= @"// Expected output: + // Label: True, Prediction: True + // Label: False, Prediction: False + // Label: True, Prediction: True + // Label: True, Prediction: True + // Label: False, Prediction: False"; + +string ExpectedOutput = @"// Expected output: + // Accuracy: 0.81 + // AUC: 0.91 + // F1 Score: 0.80 + // Negative Precision: 0.82 + // Negative Recall: 0.80 + // Positive Precision: 0.79 + // Positive Recall: 0.81 + // Log Loss: 0.59 + // Log Loss Reduction: 41.04 + // Entropy: 1.00"; +#> \ No newline at end of file diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastTreeWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastTreeWithOptions.cs new file mode 100644 index 0000000000..26493f0b12 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastTreeWithOptions.cs @@ -0,0 +1,115 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; +using Microsoft.ML.Trainers.FastTree; + +namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification +{ + public static class FastTreeWithOptions + { + // This example requires installation of additional NuGet package + // Microsoft.ML.FastTree. + public static void Example() + { + // Create a new context for ML.NET operations. 
It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + // Setting the seed to a fixed number in this example to make outputs deterministic. + var mlContext = new MLContext(seed: 0); + + // Create a list of training data points. + var dataPoints = GenerateRandomDataPoints(1000); + + // Convert the list of data points to an IDataView object, which is consumable by ML.NET API. + var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints); + + // Define trainer options. + var options = new FastTreeBinaryTrainer.Options + { + // Use L2Norm for early stopping. + EarlyStoppingMetric = EarlyStoppingMetric.L2Norm, + // Create a simpler model by penalizing usage of new features. + FeatureFirstUsePenalty = 0.1, + // Reduce the number of trees to 50. + NumberOfTrees = 50 + }; + + // Define the trainer. + var pipeline = mlContext.BinaryClassification.Trainers.FastTree(options); + + // Train the model. + var model = pipeline.Fit(trainingData); + + // Create testing data. Use different random seed to make it different from training data. + var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed:123)); + + // Run the model on test data set. + var transformedTestData = model.Transform(testData); + + // Convert IDataView object to a list. 
+ var predictions = mlContext.Data.CreateEnumerable(transformedTestData, reuseRowObject: false).ToList(); + + // Look at 5 predictions + foreach (var p in predictions.Take(5)) + Console.WriteLine($"Label: {p.Label}, Prediction: {p.PredictedLabel}"); + + // Expected output: + // Label: True, Prediction: True + // Label: False, Prediction: False + // Label: True, Prediction: True + // Label: True, Prediction: True + // Label: False, Prediction: False + + // Evaluate the overall metrics + var metrics = mlContext.BinaryClassification.Evaluate(transformedTestData); + SamplesUtils.ConsoleUtils.PrintMetrics(metrics); + + // Expected output: + // Accuracy: 0.78 + // AUC: 0.88 + // F1 Score: 0.79 + // Negative Precision: 0.83 + // Negative Recall: 0.74 + // Positive Precision: 0.74 + // Positive Recall: 0.84 + // Log Loss: 0.62 + // Log Loss Reduction: 37.77 + // Entropy: 1.00 + } + + private static IEnumerable GenerateRandomDataPoints(int count, int seed=0) + { + var random = new Random(seed); + float randomFloat() => (float)random.NextDouble(); + for (int i = 0; i < count; i++) + { + var label = randomFloat() > 0.5f; + yield return new DataPoint + { + Label = label, + // Create random features that are correlated with the label. + // For data points with false label, the feature values are slightly increased by adding a constant. + Features = Enumerable.Repeat(label, 50).Select(x => x ? randomFloat() : randomFloat() + 0.03f).ToArray() + }; + } + } + + // Example with label and 50 feature values. A data set is a collection of such examples. + private class DataPoint + { + public bool Label { get; set; } + [VectorType(50)] + public float[] Features { get; set; } + } + + // Class used to capture predictions. + private class Prediction + { + // Original label. + public bool Label { get; set; } + // Predicted label from the trainer. 
+ public bool PredictedLabel { get; set; } + } + } +} + diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastTreeWithOptions.tt b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastTreeWithOptions.tt new file mode 100644 index 0000000000..d27f82391e --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastTreeWithOptions.tt @@ -0,0 +1,36 @@ +<#@ include file="TreeSamplesTemplate.ttinclude"#> + +<#+ +string ClassName="FastTreeWithOptions"; +string Trainer = "FastTree"; +bool IsCalibrated = true; + +string TrainerOptions = @"FastTreeBinaryTrainer.Options + { + // Use L2Norm for early stopping. + EarlyStoppingMetric = EarlyStoppingMetric.L2Norm, + // Create a simpler model by penalizing usage of new features. + FeatureFirstUsePenalty = 0.1, + // Reduce the number of trees to 50. + NumberOfTrees = 50 + }"; + +string ExpectedOutputPerInstance= @"// Expected output: + // Label: True, Prediction: True + // Label: False, Prediction: False + // Label: True, Prediction: True + // Label: True, Prediction: True + // Label: False, Prediction: False"; + +string ExpectedOutput = @"// Expected output: + // Accuracy: 0.78 + // AUC: 0.88 + // F1 Score: 0.79 + // Negative Precision: 0.83 + // Negative Recall: 0.74 + // Positive Precision: 0.74 + // Positive Recall: 0.84 + // Log Loss: 0.62 + // Log Loss Reduction: 37.77 + // Entropy: 1.00"; +#> \ No newline at end of file diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/TreeSamplesTemplate.ttinclude b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/TreeSamplesTemplate.ttinclude new file mode 100644 index 0000000000..b25c54f2a8 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/TreeSamplesTemplate.ttinclude @@ -0,0 +1,99 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; +<# if 
(TrainerOptions != null) { #> +using Microsoft.ML.Trainers.FastTree; +<# } #> + +namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification +{ + public static class <#=ClassName#> + { + // This example requires installation of additional NuGet package + // Microsoft.ML.FastTree. + public static void Example() + { + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + // Setting the seed to a fixed number in this example to make outputs deterministic. + var mlContext = new MLContext(seed: 0); + + // Create a list of training data points. + var dataPoints = GenerateRandomDataPoints(1000); + + // Convert the list of data points to an IDataView object, which is consumable by ML.NET API. + var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints); + +<# if (TrainerOptions == null) { #> + // Define the trainer. + var pipeline = mlContext.BinaryClassification.Trainers.<#=Trainer#>(); +<# } else { #> + // Define trainer options. + var options = new <#=TrainerOptions#>; + + // Define the trainer. + var pipeline = mlContext.BinaryClassification.Trainers.<#=Trainer#>(options); +<# } #> + + // Train the model. + var model = pipeline.Fit(trainingData); + + // Create testing data. Use different random seed to make it different from training data. + var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed:123)); + + // Run the model on test data set. + var transformedTestData = model.Transform(testData); + + // Convert IDataView object to a list. + var predictions = mlContext.Data.CreateEnumerable(transformedTestData, reuseRowObject: false).ToList(); + + // Look at 5 predictions + foreach (var p in predictions.Take(5)) + Console.WriteLine($"Label: {p.Label}, Prediction: {p.PredictedLabel}"); + + <#=ExpectedOutputPerInstance#> + <# string Evaluator = IsCalibrated ? 
"Evaluate" : "EvaluateNonCalibrated"; #> + + // Evaluate the overall metrics + var metrics = mlContext.BinaryClassification.<#=Evaluator#>(transformedTestData); + SamplesUtils.ConsoleUtils.PrintMetrics(metrics); + + <#=ExpectedOutput#> + } + + private static IEnumerable GenerateRandomDataPoints(int count, int seed=0) + { + var random = new Random(seed); + float randomFloat() => (float)random.NextDouble(); + for (int i = 0; i < count; i++) + { + var label = randomFloat() > 0.5f; + yield return new DataPoint + { + Label = label, + // Create random features that are correlated with the label. + // For data points with false label, the feature values are slightly increased by adding a constant. + Features = Enumerable.Repeat(label, 50).Select(x => x ? randomFloat() : randomFloat() + 0.03f).ToArray() + }; + } + } + + // Example with label and 50 feature values. A data set is a collection of such examples. + private class DataPoint + { + public bool Label { get; set; } + [VectorType(50)] + public float[] Features { get; set; } + } + + // Class used to capture predictions. + private class Prediction + { + // Original label. + public bool Label { get; set; } + // Predicted label from the trainer. 
+ public bool PredictedLabel { get; set; } + } + } +} \ No newline at end of file diff --git a/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj b/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj index a316dd1abf..7062f985a4 100644 --- a/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj +++ b/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj @@ -26,6 +26,26 @@ + + True + True + FastForest.tt + + + True + True + FastForestWithOptions.tt + + + True + True + FastTree.tt + + + True + True + FastTreeWithOptions.tt + @@ -35,4 +55,50 @@ + + + TextTemplatingFileGenerator + FastForest.cs + + + TextTemplatingFileGenerator + FastForestWithOptions.cs + + + TextTemplatingFileGenerator + FastTree.cs + + + TextTemplatingFileGenerator + FastTreeWithOptions.cs + + + + + + + + + + True + True + FastForest.tt + + + True + True + FastForestWithOptions.tt + + + True + True + FastTree.tt + + + True + True + FastTreeWithOptions.tt + + + diff --git a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs index 5d14420148..9e39208e36 100644 --- a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs @@ -78,6 +78,13 @@ public static FastTreeRegressionTrainer FastTree(this RegressionCatalog.Regressi /// The maximum number of leaves per decision tree. /// The minimal number of data points required to form a new tree leaf. /// The learning rate. + /// + /// + /// + /// + /// public static FastTreeBinaryTrainer FastTree(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, @@ -97,6 +104,13 @@ public static FastTreeBinaryTrainer FastTree(this BinaryClassificationCatalog.Bi /// /// The . /// Trainer options. 
+ /// + /// + /// + /// + /// public static FastTreeBinaryTrainer FastTree(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, FastTreeBinaryTrainer.Options options) { @@ -351,6 +365,13 @@ public static FastForestRegressionTrainer FastForest(this RegressionCatalog.Regr /// Total number of decision trees to create in the ensemble. /// The maximum number of leaves per decision tree. /// The minimal number of data points required to form a new tree leaf. + /// + /// + /// + /// + /// public static FastForestBinaryTrainer FastForest(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, @@ -369,6 +390,14 @@ public static FastForestBinaryTrainer FastForest(this BinaryClassificationCatalo /// /// The . /// Trainer options. + /// + /// + /// + /// + /// + public static FastForestBinaryTrainer FastForest(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, FastForestBinaryTrainer.Options options) { From 77be9d9f5cfd64812d6474dadb76921d8f42aa5b Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Wed, 20 Mar 2019 20:59:46 -0700 Subject: [PATCH 16/18] Polish standard trainers' catalog (Just rename some variables) (#3029) --- .../BinaryClassification/FastForest.cs | 13 +-- ...hasticDualCoordinateAscentNonCalibrated.cs | 2 +- ...GradientDescentNonCalibratedWithOptions.cs | 2 +- .../Metrics/AnomalyDetectionMetrics.cs | 4 +- src/Microsoft.ML.Data/TrainCatalog.cs | 2 +- .../TreeTrainersCatalog.cs | 18 ++-- .../LogisticRegression/LbfgsPredictorBase.cs | 3 +- .../LogisticRegression/LogisticRegression.cs | 16 ++-- .../Standard/SdcaBinary.cs | 55 +++++------ .../Standard/SdcaMulticlass.cs | 12 +-- .../StandardTrainersCatalog.cs | 94 +++++++++---------- .../EvaluatorStaticExtensions.cs | 8 +- .../SdcaStaticExtensions.cs | 28 +++--- src/Microsoft.ML.StaticPipe/SgdStatic.cs | 20 ++-- test/Microsoft.ML.Functional.Tests/Common.cs | 
2 +- .../Training.cs | 6 +- .../AnomalyDetectionTests.cs | 2 +- .../TrainerEstimators/MetalinearEstimators.cs | 2 +- .../TrainerEstimators/SdcaTests.cs | 4 +- .../TrainerEstimators/TrainerEstimators.cs | 4 +- 20 files changed, 149 insertions(+), 148 deletions(-) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForest.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForest.cs index 8432d2b0d3..587499997d 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForest.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FastForest.cs @@ -16,11 +16,11 @@ public static void Example() // Setting the seed to a fixed number in this example to make outputs deterministic. var mlContext = new MLContext(seed: 0); - // Create a list of training examples. - var examples = GenerateRandomDataPoints(1000); + // Create a list of training data points. + var dataPoints = GenerateRandomDataPoints(1000); - // Convert the examples list to an IDataView object, which is consumable by ML.NET API. - var trainingData = mlContext.Data.LoadFromEnumerable(examples); + // Convert the list of data points to an IDataView object, which is consumable by ML.NET API. + var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints); // Define the trainer. var pipeline = mlContext.BinaryClassification.Trainers.FastForest(); @@ -28,7 +28,7 @@ public static void Example() // Train the model. var model = pipeline.Fit(trainingData); - // Create testing examples. Use different random seed to make it different from training data. + // Create testing data. Use different random seed to make it different from training data. var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed:123)); // Run the model on test data set. 
@@ -72,7 +72,8 @@ private static IEnumerable GenerateRandomDataPoints(int count, int se yield return new DataPoint { Label = label, - // Create random features that are correlated with label. + // Create random features that are correlated with the label. + // For data points with false label, the feature values are slightly increased by adding a constant. Features = Enumerable.Repeat(label, 50).Select(x => x ? randomFloat() : randomFloat() + 0.03f).ToArray() }; } diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/StochasticDualCoordinateAscentNonCalibrated.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/StochasticDualCoordinateAscentNonCalibrated.cs index 3459becb1c..964add1503 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/StochasticDualCoordinateAscentNonCalibrated.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/StochasticDualCoordinateAscentNonCalibrated.cs @@ -41,7 +41,7 @@ public static void Example() // Step 2: Create a binary classifier. This trainer may produce a logistic regression model. // We set the "Label" column as the label of the dataset, and the "Features" column as the features column. var pipeline = mlContext.BinaryClassification.Trainers.SdcaNonCalibrated( - labelColumnName: "Label", featureColumnName: "Features", loss: new HingeLoss(), l2Regularization: 0.001f); + labelColumnName: "Label", featureColumnName: "Features", lossFunction: new HingeLoss(), l2Regularization: 0.001f); // Step 3: Train the pipeline created. 
var model = pipeline.Fit(data); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/StochasticGradientDescentNonCalibratedWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/StochasticGradientDescentNonCalibratedWithOptions.cs index 93ce6299dc..826f0a6bc7 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/StochasticGradientDescentNonCalibratedWithOptions.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/StochasticGradientDescentNonCalibratedWithOptions.cs @@ -25,7 +25,7 @@ public static void Example() .Trainers.SgdNonCalibrated( new SgdNonCalibratedTrainer.Options { - InitialLearningRate = 0.01, + LearningRate = 0.01, NumberOfIterations = 10, L2Regularization = 1e-7f } diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/AnomalyDetectionMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/AnomalyDetectionMetrics.cs index 6f1b01c991..8bb3a5d84b 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/AnomalyDetectionMetrics.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/AnomalyDetectionMetrics.cs @@ -33,13 +33,13 @@ public sealed class AnomalyDetectionMetrics /// Predicted Anomalies : TP | FP /// Predicted Non-Anomalies : FN | TN /// - public double DetectionRateAtKFalsePositives { get; } + public double DetectionRateAtFalsePositiveCount { get; } internal AnomalyDetectionMetrics(IExceptionContext ectx, DataViewRow overallResult) { double FetchDouble(string name) => RowCursorUtils.Fetch(ectx, overallResult, name); AreaUnderRocCurve = FetchDouble(BinaryClassifierEvaluator.Auc); - DetectionRateAtKFalsePositives = FetchDouble(AnomalyDetectionEvaluator.OverallMetrics.DrAtK); + DetectionRateAtFalsePositiveCount = FetchDouble(AnomalyDetectionEvaluator.OverallMetrics.DrAtK); } } } diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs index 9b3c6f27a9..8cbf7b836a 100644 --- 
a/src/Microsoft.ML.Data/TrainCatalog.cs +++ b/src/Microsoft.ML.Data/TrainCatalog.cs @@ -673,7 +673,7 @@ internal AnomalyDetectionTrainers(AnomalyDetectionCatalog catalog) /// The name of the label column in . /// The name of the score column in . /// The name of the predicted label column in . - /// The number of false positives to compute the metric. + /// The number of false positives to compute the metric. /// Evaluation results. public AnomalyDetectionMetrics Evaluate(IDataView data, string labelColumnName = DefaultColumnNames.Label, string scoreColumnName = DefaultColumnNames.Score, string predictedLabelColumnName = DefaultColumnNames.PredictedLabel, int k = 10) diff --git a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs index 9e39208e36..02ab57a7fa 100644 --- a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs @@ -258,7 +258,7 @@ public static GamRegressionTrainer Gam(this RegressionCatalog.RegressionTrainers /// The name of the example weight column (optional). /// Total number of decision trees to create in the ensemble. /// The maximum number of leaves per decision tree. - /// The minimal number of data points required to form a new tree leaf. + /// The minimal number of data points required to form a new tree leaf. /// The learning rate. 
/// /// @@ -273,12 +273,12 @@ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionCatalog.Regr string exampleWeightColumnName = null, int numberOfLeaves = Defaults.NumberOfLeaves, int numberOfTrees = Defaults.NumberOfTrees, - int minDatapointsInLeaves = Defaults.MinimumExampleCountPerLeaf, + int minimumExampleCountPerLeaf = Defaults.MinimumExampleCountPerLeaf, double learningRate = Defaults.LearningRate) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); - return new FastTreeTweedieTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, numberOfLeaves, numberOfTrees, minDatapointsInLeaves, learningRate); + return new FastTreeTweedieTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, numberOfLeaves, numberOfTrees, minimumExampleCountPerLeaf, learningRate); } /// @@ -312,7 +312,7 @@ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionCatalog.Regr /// The name of the example weight column (optional). /// The maximum number of leaves per decision tree. /// Total number of decision trees to create in the ensemble. - /// The minimal number of data points required to form a new tree leaf. + /// The minimal number of data points required to form a new tree leaf. /// /// /// @@ -364,7 +364,7 @@ public static FastForestRegressionTrainer FastForest(this RegressionCatalog.Regr /// The name of the example weight column (optional). /// Total number of decision trees to create in the ensemble. /// The maximum number of leaves per decision tree. - /// The minimal number of data points required to form a new tree leaf. + /// The minimal number of data points required to form a new tree leaf. 
/// /// /// diff --git a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/LbfgsPredictorBase.cs index 4bb402f376..23762fce57 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/LbfgsPredictorBase.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/LbfgsPredictorBase.cs @@ -50,8 +50,7 @@ public abstract class OptionsBase : TrainerInputBaseWithWeight /// /// Number of previous iterations to remember for estimate of Hessian. /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Memory size for L-BFGS. Low=faster, less accurate", - ShortName = "m, MemorySize", SortOrder = 50)] + [Argument(ArgumentType.AtMostOnce, HelpText = "Memory size for L-BFGS. Low=faster, less accurate", ShortName = "m, MemorySize", SortOrder = 50)] [TGUI(Description = "Memory size for L-BFGS", SuggestedSweeps = "5,20,50")] [TlcModule.SweepableDiscreteParamAttribute("MemorySize", new object[] { 5, 20, 50 })] public int HistorySize = Defaults.HistorySize; diff --git a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/LogisticRegression.cs b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/LogisticRegression.cs index 3795fbafb7..50adf5b01f 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/LogisticRegression.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/LogisticRegression.cs @@ -69,23 +69,23 @@ public sealed class Options : OptionsBase /// The environment to use. /// The name of the label column. /// The name of the feature column. - /// The name for the example weight column. + /// The name for the example weight column. /// Enforce non-negative weights. - /// Weight of L1 regularizer term. - /// Weight of L2 regularizer term. + /// Weight of L1 regularizer term. + /// Weight of L2 regularizer term. /// Memory size for . Low=faster, less accurate. /// Threshold for optimizer convergence. 
internal LogisticRegressionBinaryTrainer(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, - string weights = null, - float l1Weight = Options.Defaults.L1Regularization, - float l2Weight = Options.Defaults.L2Regularization, + string exampleWeightColumnName = null, + float l1Regularization = Options.Defaults.L1Regularization, + float l2Regularization = Options.Defaults.L2Regularization, float optimizationTolerance = Options.Defaults.OptimizationTolerance, int memorySize = Options.Defaults.HistorySize, bool enforceNoNegativity = Options.Defaults.EnforceNonNegativity) - : base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), weights, - l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity) + : base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), exampleWeightColumnName, + l1Regularization, l2Regularization, optimizationTolerance, memorySize, enforceNoNegativity) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); diff --git a/src/Microsoft.ML.StandardTrainers/Standard/SdcaBinary.cs b/src/Microsoft.ML.StandardTrainers/Standard/SdcaBinary.cs index f38722b42c..7116199185 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/SdcaBinary.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/SdcaBinary.cs @@ -168,10 +168,11 @@ public abstract class OptionsBase : TrainerInputBaseWithLabel /// /// The L1 regularization hyperparameter. /// - [Argument(ArgumentType.AtMostOnce, HelpText = "L1 soft threshold (L1/L2). Note that it is easier to control and sweep using the threshold parameter than the raw L1-regularizer constant. By default the l1 threshold is automatically inferred based on data set.", NullName = "", ShortName = "l1", SortOrder = 2)] + [Argument(ArgumentType.AtMostOnce, HelpText = "L1 soft threshold (L1/L2). 
Note that it is easier to control and sweep using the threshold parameter than the raw L1-regularizer constant. By default the l1 threshold is automatically inferred based on data set.", + NullName = "", Name = "L1Threshold", ShortName = "l1", SortOrder = 2)] [TGUI(Label = "L1 Soft Threshold", SuggestedSweeps = ",0,0.25,0.5,0.75,1")] [TlcModule.SweepableDiscreteParam("L1Threshold", new object[] { "", 0f, 0.25f, 0.5f, 0.75f, 1f })] - public float? L1Threshold; + public float? L1Regularization; /// /// The degree of lock-free parallelism. @@ -234,7 +235,7 @@ internal virtual void Check(IHostEnvironment env) { Contracts.AssertValue(env); env.CheckUserArg(L2Regularization == null || L2Regularization >= 0, nameof(L2Regularization), "L2 constant must be non-negative."); - env.CheckUserArg(L1Threshold == null || L1Threshold >= 0, nameof(L1Threshold), "L1 threshold must be non-negative."); + env.CheckUserArg(L1Regularization == null || L1Regularization >= 0, nameof(L1Regularization), "L1 threshold must be non-negative."); env.CheckUserArg(MaximumNumberOfIterations == null || MaximumNumberOfIterations > 0, nameof(MaximumNumberOfIterations), "Max number of iterations must be positive."); env.CheckUserArg(ConvergenceTolerance > 0 && ConvergenceTolerance <= 1, nameof(ConvergenceTolerance), "Convergence tolerance must be positive and no larger than 1."); @@ -302,7 +303,7 @@ internal SdcaTrainerBase(IHostEnvironment env, TOptions options, SchemaShape.Col { SdcaTrainerOptions = options; SdcaTrainerOptions.L2Regularization = l2Const ?? options.L2Regularization; - SdcaTrainerOptions.L1Threshold = l1Threshold ?? options.L1Threshold; + SdcaTrainerOptions.L1Regularization = l1Threshold ?? options.L1Regularization; SdcaTrainerOptions.MaximumNumberOfIterations = maxIterations ?? 
options.MaximumNumberOfIterations; SdcaTrainerOptions.Check(env); } @@ -450,11 +451,11 @@ private protected sealed override TModel TrainCore(IChannel ch, RoleMappedData d SdcaTrainerOptions.L2Regularization = TuneDefaultL2(ch, SdcaTrainerOptions.MaximumNumberOfIterations.Value, count, numThreads); Contracts.Assert(SdcaTrainerOptions.L2Regularization.HasValue); - if (SdcaTrainerOptions.L1Threshold == null) - SdcaTrainerOptions.L1Threshold = TuneDefaultL1(ch, numFeatures); + if (SdcaTrainerOptions.L1Regularization == null) + SdcaTrainerOptions.L1Regularization = TuneDefaultL1(ch, numFeatures); - ch.Assert(SdcaTrainerOptions.L1Threshold.HasValue); - var l1Threshold = SdcaTrainerOptions.L1Threshold.Value; + ch.Assert(SdcaTrainerOptions.L1Regularization.HasValue); + var l1Threshold = SdcaTrainerOptions.L1Regularization.Value; var l1ThresholdZero = l1Threshold == 0; var weights = new VBuffer[weightSetCount]; var bestWeights = new VBuffer[weightSetCount]; @@ -783,12 +784,12 @@ private protected virtual void TrainWithoutLock(IProgressChannelProvider progres VBuffer[] weights, float[] biasUnreg, VBuffer[] l1IntermediateWeights, float[] l1IntermediateBias, float[] featureNormSquared) { Contracts.AssertValueOrNull(progress); - Contracts.Assert(SdcaTrainerOptions.L1Threshold.HasValue); + Contracts.Assert(SdcaTrainerOptions.L1Regularization.HasValue); Contracts.AssertValueOrNull(idToIdx); Contracts.AssertValueOrNull(invariants); Contracts.AssertValueOrNull(featureNormSquared); int maxUpdateTrials = 2 * numThreads; - var l1Threshold = SdcaTrainerOptions.L1Threshold.Value; + var l1Threshold = SdcaTrainerOptions.L1Regularization.Value; bool l1ThresholdZero = l1Threshold == 0; var lr = SdcaTrainerOptions.BiasLearningRate * SdcaTrainerOptions.L2Regularization.Value; var pch = progress != null ? 
progress.StartProgressChannel("Dual update") : null; @@ -979,9 +980,9 @@ private protected virtual bool CheckConvergence( } Contracts.Assert(SdcaTrainerOptions.L2Regularization.HasValue); - Contracts.Assert(SdcaTrainerOptions.L1Threshold.HasValue); + Contracts.Assert(SdcaTrainerOptions.L1Regularization.HasValue); Double l2Const = SdcaTrainerOptions.L2Regularization.Value; - Double l1Threshold = SdcaTrainerOptions.L1Threshold.Value; + Double l1Threshold = SdcaTrainerOptions.L1Regularization.Value; Double l1Regularizer = l1Threshold * l2Const * (VectorUtils.L1Norm(in weights[0]) + Math.Abs(biasReg[0])); var l2Regularizer = l2Const * (VectorUtils.NormSquared(weights[0]) + biasReg[0] * biasReg[0]) * 0.5; var newLoss = lossSum.Sum / count + l2Regularizer + l1Regularizer; @@ -993,7 +994,7 @@ private protected virtual bool CheckConvergence( var dualityGap = metrics[(int)MetricKind.DualityGap] = newLoss - newDualLoss; metrics[(int)MetricKind.BiasUnreg] = biasUnreg[0]; metrics[(int)MetricKind.BiasReg] = biasReg[0]; - metrics[(int)MetricKind.L1Sparsity] = SdcaTrainerOptions.L1Threshold == 0 ? 1 : (Double)firstWeights.GetValues().Count(w => w != 0) / weights.Length; + metrics[(int)MetricKind.L1Sparsity] = SdcaTrainerOptions.L1Regularization == 0 ? 1 : (Double)firstWeights.GetValues().Count(w => w != 0) / weights.Length; bool converged = dualityGap / newLoss < SdcaTrainerOptions.ConvergenceTolerance; @@ -1811,9 +1812,9 @@ public class OptionsBase : TrainerInputBaseWithWeight /// /// The initial learning rate used by SGD. 
/// - [Argument(ArgumentType.AtMostOnce, HelpText = "Initial learning rate (only used by SGD)", ShortName = "ilr,lr, InitLearningRate")] + [Argument(ArgumentType.AtMostOnce, HelpText = "Initial learning rate (only used by SGD)", Name = "InitialLearningRate", ShortName = "ilr,lr,InitLearningRate")] [TGUI(Label = "Initial Learning Rate (for SGD)")] - public double InitialLearningRate = Defaults.InitialLearningRate; + public double LearningRate = Defaults.LearningRate; /// /// Determines whether to shuffle data for each training iteration. @@ -1848,16 +1849,16 @@ internal void Check(IHostEnvironment env) { Contracts.CheckValue(env, nameof(env)); env.CheckUserArg(L2Regularization >= 0, nameof(L2Regularization), "Must be non-negative."); - env.CheckUserArg(InitialLearningRate > 0, nameof(InitialLearningRate), "Must be positive."); + env.CheckUserArg(LearningRate > 0, nameof(LearningRate), "Must be positive."); env.CheckUserArg(NumberOfIterations > 0, nameof(NumberOfIterations), "Must be positive."); env.CheckUserArg(PositiveInstanceWeight > 0, nameof(PositiveInstanceWeight), "Must be positive"); - if (InitialLearningRate * L2Regularization >= 1) + if (LearningRate * L2Regularization >= 1) { using (var ch = env.Start("Argument Adjustment")) { - ch.Warning("{0} {1} set too high; reducing to {1}", nameof(InitialLearningRate), - InitialLearningRate, InitialLearningRate = (float)0.5 / L2Regularization); + ch.Warning("{0} {1} set too high; reducing to {1}", nameof(LearningRate), + LearningRate, LearningRate = (float)0.5 / L2Regularization); } } @@ -1870,7 +1871,7 @@ internal static class Defaults { public const float L2Regularization = 1e-6f; public const int NumberOfIterations = 20; - public const double InitialLearningRate = 0.01; + public const double LearningRate = 0.01; } } @@ -1901,7 +1902,7 @@ internal SgdBinaryTrainerBase(IHostEnvironment env, string weightColumn = null, IClassificationLoss loss = null, int maxIterations = OptionsBase.Defaults.NumberOfIterations, - 
double initLearningRate = OptionsBase.Defaults.InitialLearningRate, + double initLearningRate = OptionsBase.Defaults.LearningRate, float l2Weight = OptionsBase.Defaults.L2Regularization) : base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), weightColumn) { @@ -1910,7 +1911,7 @@ internal SgdBinaryTrainerBase(IHostEnvironment env, _options = new OptionsBase(); _options.NumberOfIterations = maxIterations; - _options.InitialLearningRate = initLearningRate; + _options.LearningRate = initLearningRate; _options.L2Regularization = l2Weight; _options.FeatureColumnName = featureColumn; @@ -2039,7 +2040,7 @@ private protected override TModel TrainCore(IChannel ch, RoleMappedData data, Li var trainingTasks = new Action[_options.NumberOfIterations]; var rands = new Random[_options.NumberOfIterations]; - var ilr = _options.InitialLearningRate; + var ilr = _options.LearningRate; long t = 0; for (int epoch = 1; epoch <= _options.NumberOfIterations; epoch++) { @@ -2190,7 +2191,7 @@ internal SgdCalibratedTrainer(IHostEnvironment env, string featureColumn = DefaultColumnNames.Features, string weightColumn = null, int maxIterations = Options.Defaults.NumberOfIterations, - double initLearningRate = Options.Defaults.InitialLearningRate, + double initLearningRate = Options.Defaults.LearningRate, float l2Weight = Options.Defaults.L2Regularization) : base(env, labelColumn, featureColumn, weightColumn, new LogLoss(), maxIterations, initLearningRate, l2Weight) { @@ -2247,7 +2248,7 @@ public sealed class Options : OptionsBase /// The loss function to use. Default is . 
/// [Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] - public IClassificationLoss Loss = new LogLoss(); + public IClassificationLoss LossFunction = new LogLoss(); } internal SgdNonCalibratedTrainer(IHostEnvironment env, @@ -2255,7 +2256,7 @@ internal SgdNonCalibratedTrainer(IHostEnvironment env, string featureColumn = DefaultColumnNames.Features, string weightColumn = null, int maxIterations = Options.Defaults.NumberOfIterations, - double initLearningRate = Options.Defaults.InitialLearningRate, + double initLearningRate = Options.Defaults.LearningRate, float l2Weight = Options.Defaults.L2Regularization, IClassificationLoss loss = null) : base(env, labelColumn, featureColumn, weightColumn, loss, maxIterations, initLearningRate, l2Weight) @@ -2268,7 +2269,7 @@ internal SgdNonCalibratedTrainer(IHostEnvironment env, /// The environment to use. /// Advanced arguments to the algorithm. internal SgdNonCalibratedTrainer(IHostEnvironment env, Options options) - : base(env, options, loss: options.Loss, doCalibration: false) + : base(env, options, loss: options.LossFunction, doCalibration: false) { } diff --git a/src/Microsoft.ML.StandardTrainers/Standard/SdcaMulticlass.cs b/src/Microsoft.ML.StandardTrainers/Standard/SdcaMulticlass.cs index 54c7748067..ad6f87d10e 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/SdcaMulticlass.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/SdcaMulticlass.cs @@ -127,7 +127,7 @@ private protected override void TrainWithoutLock(IProgressChannelProvider progre VBuffer[] weights, float[] biasUnreg, VBuffer[] l1IntermediateWeights, float[] l1IntermediateBias, float[] featureNormSquared) { Contracts.AssertValueOrNull(progress); - Contracts.Assert(SdcaTrainerOptions.L1Threshold.HasValue); + Contracts.Assert(SdcaTrainerOptions.L1Regularization.HasValue); Contracts.AssertValueOrNull(idToIdx); Contracts.AssertValueOrNull(invariants); Contracts.AssertValueOrNull(featureNormSquared); @@ -136,7 
+136,7 @@ private protected override void TrainWithoutLock(IProgressChannelProvider progre Contracts.Assert(Utils.Size(biasUnreg) == numClasses); int maxUpdateTrials = 2 * numThreads; - var l1Threshold = SdcaTrainerOptions.L1Threshold.Value; + var l1Threshold = SdcaTrainerOptions.L1Regularization.Value; bool l1ThresholdZero = l1Threshold == 0; var lr = SdcaTrainerOptions.BiasLearningRate * SdcaTrainerOptions.L2Regularization.Value; @@ -358,9 +358,9 @@ private protected override bool CheckConvergence( } Contracts.Assert(SdcaTrainerOptions.L2Regularization.HasValue); - Contracts.Assert(SdcaTrainerOptions.L1Threshold.HasValue); + Contracts.Assert(SdcaTrainerOptions.L1Regularization.HasValue); Double l2Const = SdcaTrainerOptions.L2Regularization.Value; - Double l1Threshold = SdcaTrainerOptions.L1Threshold.Value; + Double l1Threshold = SdcaTrainerOptions.L1Regularization.Value; Double weightsL1Norm = 0; Double weightsL2NormSquared = 0; @@ -372,7 +372,7 @@ private protected override bool CheckConvergence( biasRegularizationAdjustment += biasReg[iClass] * biasUnreg[iClass]; } - Double l1Regularizer = SdcaTrainerOptions.L1Threshold.Value * l2Const * weightsL1Norm; + Double l1Regularizer = SdcaTrainerOptions.L1Regularization.Value * l2Const * weightsL1Norm; var l2Regularizer = l2Const * weightsL2NormSquared * 0.5; var newLoss = lossSum.Sum / count + l2Regularizer + l1Regularizer; @@ -384,7 +384,7 @@ private protected override bool CheckConvergence( metrics[(int)MetricKind.DualityGap] = dualityGap; metrics[(int)MetricKind.BiasUnreg] = biasUnreg[0]; metrics[(int)MetricKind.BiasReg] = biasReg[0]; - metrics[(int)MetricKind.L1Sparsity] = SdcaTrainerOptions.L1Threshold == 0 ? 1 : weights.Sum( + metrics[(int)MetricKind.L1Sparsity] = SdcaTrainerOptions.L1Regularization == 0 ? 
1 : weights.Sum( weight => weight.GetValues().Count(w => w != 0)) / (numClasses * numFeatures); bool converged = dualityGap / newLoss < SdcaTrainerOptions.ConvergenceTolerance; diff --git a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs index 09ac97c694..dde848a0bd 100644 --- a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs +++ b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs @@ -25,7 +25,7 @@ public static class StandardTrainersCatalog /// The features, or independent variables. /// The name of the example weight column (optional). /// The maximum number of passes through the training dataset; set to 1 to simulate online learning. - /// The initial learning rate used by SGD. + /// The initial learning rate used by SGD. /// The L2 weight for regularization. /// /// @@ -39,13 +39,13 @@ public static SgdCalibratedTrainer SgdCalibrated(this BinaryClassificationCatalo string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, int numberOfIterations = SgdCalibratedTrainer.Options.Defaults.NumberOfIterations, - double initialLearningRate = SgdCalibratedTrainer.Options.Defaults.InitialLearningRate, + double learningRate = SgdCalibratedTrainer.Options.Defaults.LearningRate, float l2Regularization = SgdCalibratedTrainer.Options.Defaults.L2Regularization) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); return new SgdCalibratedTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, - numberOfIterations, initialLearningRate, l2Regularization); + numberOfIterations, learningRate, l2Regularization); } /// @@ -79,9 +79,9 @@ public static SgdCalibratedTrainer SgdCalibrated(this BinaryClassificationCatalo /// The name of the label column, or dependent variable. /// The features, or independent variables. /// The name of the example weight column (optional). 
- /// The loss function minimized in the training process. Using, for example, leads to a support vector machine trainer. + /// The loss function minimized in the training process. Using, for example, leads to a support vector machine trainer. /// The maximum number of passes through the training dataset; set to 1 to simulate online learning. - /// The initial learning rate used by SGD. + /// The initial learning rate used by SGD. /// The L2 weight for regularization. /// /// @@ -94,15 +94,15 @@ public static SgdNonCalibratedTrainer SgdNonCalibrated(this BinaryClassification string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, - IClassificationLoss loss = null, + IClassificationLoss lossFunction = null, int numberOfIterations = SgdNonCalibratedTrainer.Options.Defaults.NumberOfIterations, - double initialLearningRate = SgdNonCalibratedTrainer.Options.Defaults.InitialLearningRate, + double learningRate = SgdNonCalibratedTrainer.Options.Defaults.LearningRate, float l2Regularization = SgdNonCalibratedTrainer.Options.Defaults.L2Regularization) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); return new SgdNonCalibratedTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, - numberOfIterations, initialLearningRate, l2Regularization, loss); + numberOfIterations, learningRate, l2Regularization, lossFunction); } /// @@ -135,10 +135,10 @@ public static SgdNonCalibratedTrainer SgdNonCalibrated(this BinaryClassification /// The name of the label column. /// The name of the feature column. /// The name of the example weight column (optional). - /// The L2 regularization hyperparameter. - /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. + /// The loss function minimized in the training process. Using, for example, its default leads to a least square trainer. 
+ /// The L2 weight for regularization. + /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// The custom loss, if unspecified will be . /// /// /// @@ -187,8 +187,8 @@ public static SdcaRegressionTrainer Sdca(this RegressionCatalog.RegressionTraine /// The name of the label column. /// The name of the feature column. /// The name of the example weight column (optional). - /// The L2 regularization hyperparameter. - /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. + /// The L2 weight for regularization. + /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. /// /// @@ -202,12 +202,12 @@ public static SdcaCalibratedBinaryTrainer SdcaCalibrated( string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, float? l2Regularization = null, - float? l1Threshold = null, + float? l1Regularization = null, int? maximumNumberOfIterations = null) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); - return new SdcaCalibratedBinaryTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, l2Regularization, l1Threshold, maximumNumberOfIterations); + return new SdcaCalibratedBinaryTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, l2Regularization, l1Regularization, maximumNumberOfIterations); } /// @@ -239,9 +239,9 @@ public static SdcaCalibratedBinaryTrainer SdcaCalibrated( /// The name of the label column. /// The name of the feature column. /// The name of the example weight column (optional). - /// The custom loss. Defaults to if not specified. - /// The L2 regularization hyperparameter. - /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. 
+ /// The loss function minimized in the training process. Defaults to if not specified. + /// The L2 weight for regularization. + /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. /// /// @@ -254,14 +254,14 @@ public static SdcaNonCalibratedBinaryTrainer SdcaNonCalibrated( string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, - ISupportSdcaClassificationLoss loss = null, + ISupportSdcaClassificationLoss lossFunction = null, float? l2Regularization = null, - float? l1Threshold = null, + float? l1Regularization = null, int? maximumNumberOfIterations = null) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); - return new SdcaNonCalibratedBinaryTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, loss, l2Regularization, l1Threshold, maximumNumberOfIterations); + return new SdcaNonCalibratedBinaryTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, lossFunction, l2Regularization, l1Regularization, maximumNumberOfIterations); } /// @@ -287,8 +287,8 @@ public static SdcaNonCalibratedBinaryTrainer SdcaNonCalibrated( /// The name of the label column. /// The name of the feature column. /// The name of the example weight column (optional). - /// The L2 regularization hyperparameter. - /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. + /// The L2 weight for regularization. + /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. /// /// @@ -301,12 +301,12 @@ public static SdcaCalibratedMulticlassTrainer SdcaCalibrated(this MulticlassClas string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, float? 
l2Regularization = null, - float? l1Threshold = null, + float? l1Regularization = null, int? maximumNumberOfIterations = null) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); - return new SdcaCalibratedMulticlassTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, l2Regularization, l1Threshold, maximumNumberOfIterations); + return new SdcaCalibratedMulticlassTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, l2Regularization, l1Regularization, maximumNumberOfIterations); } /// @@ -337,9 +337,9 @@ public static SdcaCalibratedMulticlassTrainer SdcaCalibrated(this MulticlassClas /// The name of the label column. /// The name of the feature column. /// The name of the example weight column (optional). - /// Loss function to be minimized. Defaults to if not specified. - /// The L2 regularization hyperparameter. - /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. + /// The loss function to be minimized. Defaults to if not specified. + /// The L2 weight for regularization. + /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. /// /// @@ -351,14 +351,14 @@ public static SdcaNonCalibratedMulticlassTrainer SdcaNonCalibrated(this Multicla string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, - ISupportSdcaClassificationLoss loss = null, + ISupportSdcaClassificationLoss lossFunction = null, float? l2Regularization = null, - float? l1Threshold = null, + float? l1Regularization = null, int? 
maximumNumberOfIterations = null) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); - return new SdcaNonCalibratedMulticlassTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, loss, l2Regularization, l1Threshold, maximumNumberOfIterations); + return new SdcaNonCalibratedMulticlassTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, lossFunction, l2Regularization, l1Regularization, maximumNumberOfIterations); } /// @@ -388,8 +388,8 @@ public static SdcaNonCalibratedMulticlassTrainer SdcaNonCalibrated(this Multicla /// The binary classification catalog trainer object. /// The name of the label column. /// The name of the feature column. - /// A custom loss. If , hinge loss will be used resulting in max-margin averaged perceptron. - /// Learning rate. + /// The loss function minimized in the training process. If , would be used and lead to a max-margin averaged perceptron trainer. + /// The initial learning rate used by SGD. /// /// to decrease the as iterations progress; otherwise, . /// Default is . @@ -462,8 +462,8 @@ public IClassificationLoss CreateComponent(IHostEnvironment env) /// The regression catalog trainer object. /// The name of the label column. /// The name of the feature column. - /// The custom loss. Defaults to if not provided. - /// The learning Rate. + /// The loss function minimized in the training process. Using, for example, leads to a least square trainer. + /// The initial learning rate used by SGD. /// Decrease learning rate as iterations progress. /// The L2 weight for regularization. /// Number of training iterations through the data. @@ -505,8 +505,8 @@ public static OnlineGradientDescentTrainer OnlineGradientDescent(this Regression /// The name of the feature column. /// The name of the example weight column (optional). /// Enforce non-negative weights. - /// Weight of L1 regularization term. - /// Weight of L2 regularization term. 
+ /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. + /// The L2 weight for regularization. /// Memory size for . Low=faster, less accurate. /// Threshold for optimizer convergence. /// @@ -552,8 +552,8 @@ public static LogisticRegressionBinaryTrainer LogisticRegression(this BinaryClas /// The name of the label column. /// The name of the feature column. /// The name of the example weight column (optional). - /// Weight of L1 regularization term. - /// Weight of L2 regularization term. + /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. + /// The L2 weight for regularization. /// Threshold for optimizer convergence. /// Memory size for . Low=faster, less accurate. /// Enforce non-negative weights. @@ -594,8 +594,8 @@ public static PoissonRegressionTrainer PoissonRegression(this RegressionCatalog. /// The name of the feature column. /// The name of the example weight column (optional). /// Enforce non-negative weights. - /// Weight of L1 regularization term. - /// Weight of L2 regularization term. + /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. + /// The L2 weight for regularization. /// Memory size for . Low=faster, less accurate. /// Threshold for optimizer convergence. public static LbfgsMaximumEntropyTrainer LbfgsMaximumEntropy(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, @@ -678,7 +678,7 @@ private static ICalibratorTrainer GetCalibratorTrainerOrThrow(IExceptionContext /// The calibrator. If a calibrator is not explicitely provided, it will default to /// The name of the label colum. /// Whether to treat missing labels as having negative labels, instead of keeping them missing. - /// Number of instances to train the calibrator. + /// Number of instances to train the calibrator. /// Use probabilities (vs. raw outputs) to identify top-score category. /// The type of the model. 
This type parameter will usually be inferred automatically from . public static OneVersusAllTrainer OneVersusAll(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, @@ -686,7 +686,7 @@ public static OneVersusAllTrainer OneVersusAll(this MulticlassClassifica string labelColumnName = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, IEstimator> calibrator = null, - int maxCalibrationExamples = 1000000000, + int maximumCalibrationExampleCount = 1000000000, bool useProbabilities = true) where TModel : class { @@ -694,7 +694,7 @@ public static OneVersusAllTrainer OneVersusAll(this MulticlassClassifica var env = CatalogUtils.GetEnvironment(catalog); if (!(binaryEstimator is ITrainerEstimator>, IPredictorProducing> est)) throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model."); - return new OneVersusAllTrainer(env, est, labelColumnName, imputeMissingLabelsAsNegative, GetCalibratorTrainerOrThrow(env, calibrator), maxCalibrationExamples, useProbabilities); + return new OneVersusAllTrainer(env, est, labelColumnName, imputeMissingLabelsAsNegative, GetCalibratorTrainerOrThrow(env, calibrator), maximumCalibrationExampleCount, useProbabilities); } /// diff --git a/src/Microsoft.ML.StaticPipe/EvaluatorStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/EvaluatorStaticExtensions.cs index b46c41289f..33cefe6e07 100644 --- a/src/Microsoft.ML.StaticPipe/EvaluatorStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/EvaluatorStaticExtensions.cs @@ -176,14 +176,14 @@ private sealed class TrivialRegressionLossFactory : ISupportRegressionLossFactor /// The data to evaluate. /// The index delegate for the label column. /// The index delegate for predicted score column. - /// Potentially custom loss function. If left unspecified defaults to . + /// Potentially custom loss function. If left unspecified defaults to . /// The evaluation metrics. 
public static RegressionMetrics Evaluate( this RegressionCatalog catalog, DataView data, Func> label, Func> score, - IRegressionLoss loss = null) + IRegressionLoss lossFunction = null) { Contracts.CheckValue(data, nameof(data)); var env = StaticPipeUtils.GetEnvironment(data); @@ -196,8 +196,8 @@ public static RegressionMetrics Evaluate( string scoreName = indexer.Get(score(indexer.Indices)); var args = new RegressionEvaluator.Arguments() { }; - if (loss != null) - args.LossFunction = new TrivialRegressionLossFactory(loss); + if (lossFunction != null) + args.LossFunction = new TrivialRegressionLossFactory(lossFunction); return new RegressionEvaluator(env, args).Evaluate(data.AsDynamic, labelName, scoreName); } diff --git a/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs index 88f8dcda91..326dd54dec 100644 --- a/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/SdcaStaticExtensions.cs @@ -24,7 +24,7 @@ public static class SdcaStaticExtensions /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// The custom loss, if unspecified will be . + /// The custom loss, if unspecified will be . /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -42,7 +42,7 @@ public static Scalar Sdca(this RegressionCatalog.RegressionTrainers catal float? l2Regularization = null, float? l1Threshold = null, int? 
numberOfIterations = null, - ISupportSdcaRegressionLoss loss = null, + ISupportSdcaRegressionLoss lossFunction = null, Action onFit = null) { Contracts.CheckValue(label, nameof(label)); @@ -51,13 +51,13 @@ public static Scalar Sdca(this RegressionCatalog.RegressionTrainers catal Contracts.CheckParam(!(l2Regularization < 0), nameof(l2Regularization), "Must not be negative, if specified."); Contracts.CheckParam(!(l1Threshold < 0), nameof(l1Threshold), "Must not be negative, if specified."); Contracts.CheckParam(!(numberOfIterations < 1), nameof(numberOfIterations), "Must be positive if specified"); - Contracts.CheckValueOrNull(loss); + Contracts.CheckValueOrNull(lossFunction); Contracts.CheckValueOrNull(onFit); var rec = new TrainerEstimatorReconciler.Regression( (env, labelName, featuresName, weightsName) => { - var trainer = new SdcaRegressionTrainer(env, labelName, featuresName, weightsName, loss, l2Regularization, l1Threshold, numberOfIterations); + var trainer = new SdcaRegressionTrainer(env, labelName, featuresName, weightsName, lossFunction, l2Regularization, l1Threshold, numberOfIterations); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); return trainer; @@ -229,7 +229,7 @@ public static (Scalar score, Scalar probability, Scalar pred /// The binary classification catalog trainer object. /// The label, or dependent variable. /// The features, or independent variables. - /// The custom loss. + /// The custom loss. /// The optional example weights. /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. 
@@ -244,7 +244,7 @@ public static (Scalar score, Scalar probability, Scalar pred public static (Scalar score, Scalar predictedLabel) SdcaNonCalibrated( this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, Scalar label, Vector features, - ISupportSdcaClassificationLoss loss, + ISupportSdcaClassificationLoss lossFunction, Scalar weights = null, float? l2Regularization = null, float? l1Threshold = null, @@ -253,7 +253,7 @@ public static (Scalar score, Scalar predictedLabel) SdcaNonCalibrat { Contracts.CheckValue(label, nameof(label)); Contracts.CheckValue(features, nameof(features)); - Contracts.CheckValue(loss, nameof(loss)); + Contracts.CheckValue(lossFunction, nameof(lossFunction)); Contracts.CheckValueOrNull(weights); Contracts.CheckParam(!(l2Regularization < 0), nameof(l2Regularization), "Must not be negative, if specified."); Contracts.CheckParam(!(l1Threshold < 0), nameof(l1Threshold), "Must not be negative, if specified."); @@ -263,7 +263,7 @@ public static (Scalar score, Scalar predictedLabel) SdcaNonCalibrat var rec = new TrainerEstimatorReconciler.BinaryClassifierNoCalibration( (env, labelName, featuresName, weightsName) => { - var trainer = new SdcaNonCalibratedBinaryTrainer(env, labelName, featuresName, weightsName, loss, l2Regularization, l1Threshold, numberOfIterations); + var trainer = new SdcaNonCalibratedBinaryTrainer(env, labelName, featuresName, weightsName, lossFunction, l2Regularization, l1Threshold, numberOfIterations); if (onFit != null) { return trainer.WithOnFitDelegate(trans => @@ -285,7 +285,7 @@ public static (Scalar score, Scalar predictedLabel) SdcaNonCalibrat /// The binary classification catalog trainer object. /// The label, or dependent variable. /// The features, or independent variables. - /// The custom loss. + /// The custom loss. /// The optional example weights. /// Advanced arguments to the algorithm. 
/// A delegate that is called every time the @@ -298,7 +298,7 @@ public static (Scalar score, Scalar predictedLabel) SdcaNonCalibrat public static (Scalar score, Scalar predictedLabel) SdcaNonCalibrated( this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, Scalar label, Vector features, Scalar weights, - ISupportSdcaClassificationLoss loss, + ISupportSdcaClassificationLoss lossFunction, SdcaNonCalibratedBinaryTrainer.Options options, Action onFit = null) { @@ -423,7 +423,7 @@ public static (Vector score, Key predictedLabel) Sdca( /// The multiclass classification catalog trainer object. /// The label, or dependent variable. /// The features, or independent variables. - /// The custom loss, for example, . + /// The custom loss, for example, . /// The optional example weights. /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. @@ -438,7 +438,7 @@ public static (Vector score, Key predictedLabel) SdcaNonCalib this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, Key label, Vector features, - ISupportSdcaClassificationLoss loss, + ISupportSdcaClassificationLoss lossFunction, Scalar weights = null, float? l2Regularization = null, float? 
l1Threshold = null, @@ -447,7 +447,7 @@ public static (Vector score, Key predictedLabel) SdcaNonCalib { Contracts.CheckValue(label, nameof(label)); Contracts.CheckValue(features, nameof(features)); - Contracts.CheckValueOrNull(loss); + Contracts.CheckValueOrNull(lossFunction); Contracts.CheckValueOrNull(weights); Contracts.CheckParam(!(l2Regularization < 0), nameof(l2Regularization), "Must not be negative, if specified."); Contracts.CheckParam(!(l1Threshold < 0), nameof(l1Threshold), "Must not be negative, if specified."); @@ -457,7 +457,7 @@ public static (Vector score, Key predictedLabel) SdcaNonCalib var rec = new TrainerEstimatorReconciler.MulticlassClassificationReconciler( (env, labelName, featuresName, weightsName) => { - var trainer = new SdcaNonCalibratedMulticlassTrainer(env, labelName, featuresName, weightsName, loss, l2Regularization, l1Threshold, numberOfIterations); + var trainer = new SdcaNonCalibratedMulticlassTrainer(env, labelName, featuresName, weightsName, lossFunction, l2Regularization, l1Threshold, numberOfIterations); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); return trainer; diff --git a/src/Microsoft.ML.StaticPipe/SgdStatic.cs b/src/Microsoft.ML.StaticPipe/SgdStatic.cs index 3765238597..e228a064ef 100644 --- a/src/Microsoft.ML.StaticPipe/SgdStatic.cs +++ b/src/Microsoft.ML.StaticPipe/SgdStatic.cs @@ -21,8 +21,8 @@ public static class SgdStaticExtensions /// The name of the feature column. /// The name for the example weight column. /// The maximum number of iterations; set to 1 to simulate online learning. - /// The initial learning rate used by SGD. - /// The L2 regularization constant. + /// The initial learning rate used by SGD. + /// The L2 weight for regularization. /// A delegate that is called every time the /// method is called on the /// instance created out of this. 
This delegate will receive @@ -35,14 +35,14 @@ public static (Scalar score, Scalar probability, Scalar pred Vector features, Scalar weights = null, int numberOfIterations = SgdCalibratedTrainer.Options.Defaults.NumberOfIterations, - double initialLearningRate = SgdCalibratedTrainer.Options.Defaults.InitialLearningRate, + double learningRate = SgdCalibratedTrainer.Options.Defaults.LearningRate, float l2Regularization = SgdCalibratedTrainer.Options.Defaults.L2Regularization, Action> onFit = null) { var rec = new TrainerEstimatorReconciler.BinaryClassifier( (env, labelName, featuresName, weightsName) => { - var trainer = new SgdCalibratedTrainer(env, labelName, featuresName, weightsName, numberOfIterations, initialLearningRate, l2Regularization); + var trainer = new SgdCalibratedTrainer(env, labelName, featuresName, weightsName, numberOfIterations, learningRate, l2Regularization); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); @@ -101,9 +101,9 @@ public static (Scalar score, Scalar probability, Scalar pred /// The name of the feature column. /// The name for the example weight column. /// The maximum number of iterations; set to 1 to simulate online learning. - /// The initial learning rate used by SGD. - /// The L2 regularization constant. - /// The loss function to use. + /// The initial learning rate used by SGD. + /// The L2 weight for regularization. + /// The loss function to use. /// A delegate that is called every time the /// method is called on the /// instance created out of this. 
This delegate will receive @@ -116,16 +116,16 @@ public static (Scalar score, Scalar predictedLabel) StochasticGradi Vector features, Scalar weights = null, int numberOfIterations = SgdNonCalibratedTrainer.Options.Defaults.NumberOfIterations, - double initialLearningRate = SgdNonCalibratedTrainer.Options.Defaults.InitialLearningRate, + double learningRate = SgdNonCalibratedTrainer.Options.Defaults.LearningRate, float l2Regularization = SgdNonCalibratedTrainer.Options.Defaults.L2Regularization, - IClassificationLoss loss = null, + IClassificationLoss lossFunction = null, Action onFit = null) { var rec = new TrainerEstimatorReconciler.BinaryClassifierNoCalibration( (env, labelName, featuresName, weightsName) => { var trainer = new SgdNonCalibratedTrainer(env, labelName, featuresName, weightsName, - numberOfIterations, initialLearningRate, l2Regularization, loss); + numberOfIterations, learningRate, l2Regularization, lossFunction); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); diff --git a/test/Microsoft.ML.Functional.Tests/Common.cs b/test/Microsoft.ML.Functional.Tests/Common.cs index 6b6722cf5d..648f01351f 100644 --- a/test/Microsoft.ML.Functional.Tests/Common.cs +++ b/test/Microsoft.ML.Functional.Tests/Common.cs @@ -166,7 +166,7 @@ public static void AssertEqual(TypeTestData testType1, TypeTestData testType2) public static void AssertMetrics(AnomalyDetectionMetrics metrics) { Assert.InRange(metrics.AreaUnderRocCurve, 0, 1); - Assert.InRange(metrics.DetectionRateAtKFalsePositives, 0, 1); + Assert.InRange(metrics.DetectionRateAtFalsePositiveCount, 0, 1); } /// diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs index 25898a5702..171cfbb0af 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs @@ -434,7 +434,7 @@ public void SdcaMulticlassSvm() .Append(r => (r.label, preds: 
catalog.Trainers.SdcaNonCalibrated( r.label, r.features, - loss: new HingeLoss(), + lossFunction: new HingeLoss(), numberOfIterations: 2, onFit: p => pred = p))); @@ -1132,7 +1132,7 @@ public void HogwildSGDSupportVectorMachine() var est = reader.MakeNewEstimator() .Append(r => (r.label, preds: catalog.Trainers.StochasticGradientDescentNonCalibratedClassificationTrainer(r.label, r.features, null, - new SgdNonCalibratedTrainer.Options { L2Regularization = 0, NumberOfThreads = 1, Loss = new HingeLoss()}, + new SgdNonCalibratedTrainer.Options { L2Regularization = 0, NumberOfThreads = 1, LossFunction = new HingeLoss()}, onFit: (p) => { pred = p; }))); var pipe = reader.Append(est); @@ -1167,7 +1167,7 @@ public void HogwildSGDSupportVectorMachineSimple() LinearBinaryModelParameters pred = null; var est = reader.MakeNewEstimator() - .Append(r => (r.label, preds: catalog.Trainers.StochasticGradientDescentNonCalibratedClassificationTrainer(r.label, r.features, loss: new HingeLoss(), onFit: (p) => { pred = p; }))); + .Append(r => (r.label, preds: catalog.Trainers.StochasticGradientDescentNonCalibratedClassificationTrainer(r.label, r.features, lossFunction: new HingeLoss(), onFit: (p) => { pred = p; }))); var pipe = reader.Append(est); diff --git a/test/Microsoft.ML.Tests/AnomalyDetectionTests.cs b/test/Microsoft.ML.Tests/AnomalyDetectionTests.cs index 8b94d289fc..10973ad109 100644 --- a/test/Microsoft.ML.Tests/AnomalyDetectionTests.cs +++ b/test/Microsoft.ML.Tests/AnomalyDetectionTests.cs @@ -33,7 +33,7 @@ public void RandomizedPcaTrainerBaselineTest() var metrics = ML.AnomalyDetection.Evaluate(transformedData, k: 5); Assert.Equal(0.98667, metrics.AreaUnderRocCurve, 5); - Assert.Equal(0.90000, metrics.DetectionRateAtKFalsePositives, 5); + Assert.Equal(0.90000, metrics.DetectionRateAtFalsePositiveCount, 5); } /// diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs index 
75b17d881c..0ec033b4c5 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs @@ -25,7 +25,7 @@ public void OVAWithAllConstructorArgs() new AveragedPerceptronTrainer.Options { Shuffle = true }); var ova = ML.MulticlassClassification.Trainers.OneVersusAll(averagePerceptron, imputeMissingLabelsAsNegative: true, - calibrator: calibrator, maxCalibrationExamples: 10000, useProbabilities: true); + calibrator: calibrator, maximumCalibrationExampleCount: 10000, useProbabilities: true); pipeline = pipeline.Append(ova) .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs index 9d4ce78811..653805f130 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs @@ -108,7 +108,7 @@ public void SdcaSupportVectorMachine() // Step 2: Create a binary classifier. // We set the "Label" column as the label of the dataset, and the "Features" column as the features column. var pipeline = mlContext.BinaryClassification.Trainers.SdcaNonCalibrated( - labelColumnName: "Label", featureColumnName: "Features", loss: new HingeLoss(), l2Regularization: 0.001f); + labelColumnName: "Label", featureColumnName: "Features", lossFunction: new HingeLoss(), l2Regularization: 0.001f); // Step 3: Train the pipeline created. var model = pipeline.Fit(data); @@ -185,7 +185,7 @@ public void SdcaMulticlassSupportVectorMachine() // Step 2: Create a binary classifier. // We set the "Label" column as the label of the dataset, and the "Features" column as the features column. var pipeline = mlContext.Transforms.Conversion.MapValueToKey("LabelIndex", "Label"). 
- Append(mlContext.MulticlassClassification.Trainers.SdcaNonCalibrated(labelColumnName: "LabelIndex", featureColumnName: "Features", loss: new HingeLoss(), l2Regularization: 0.001f)); + Append(mlContext.MulticlassClassification.Trainers.SdcaNonCalibrated(labelColumnName: "LabelIndex", featureColumnName: "Features", lossFunction: new HingeLoss(), l2Regularization: 0.001f)); // Step 3: Train the pipeline created. var model = pipeline.Fit(data); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs index d891a55c39..489d557ebe 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs @@ -120,8 +120,8 @@ public void TestEstimatorHogwildSGD() [Fact] public void TestEstimatorHogwildSGDNonCalibrated() { - var trainers = new[] { ML.BinaryClassification.Trainers.SgdNonCalibrated(loss : new SmoothedHingeLoss()), - ML.BinaryClassification.Trainers.SgdNonCalibrated(new Trainers.SgdNonCalibratedTrainer.Options() { Loss = new HingeLoss() }) }; + var trainers = new[] { ML.BinaryClassification.Trainers.SgdNonCalibrated(lossFunction : new SmoothedHingeLoss()), + ML.BinaryClassification.Trainers.SgdNonCalibrated(new Trainers.SgdNonCalibratedTrainer.Options() { LossFunction = new HingeLoss() }) }; foreach (var trainer in trainers) { From 5b22420d28c0cacc9b265d043555b6d11a017b91 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Wed, 20 Mar 2019 22:41:28 -0700 Subject: [PATCH 17/18] Polish train catalog (renaming only) (#3030) --- .../Metrics/MulticlassClassificationMetrics.cs | 16 ++++++++-------- src/Microsoft.ML.Data/TrainCatalog.cs | 15 ++++++++------- .../EvaluatorStaticExtensions.cs | 10 +++++----- .../PermutationFeatureImportanceExtensions.cs | 6 +++--- test/Microsoft.ML.Tests/AnomalyDetectionTests.cs | 2 +- test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs | 2 +- 
.../Scenarios/IrisPlantClassificationTests.cs | 2 +- ...risPlantClassificationWithStringLabelTests.cs | 2 +- .../TrainerEstimators/SdcaTests.cs | 4 ++-- 9 files changed, 30 insertions(+), 29 deletions(-) diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs index 446c839775..843dc00284 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs @@ -60,15 +60,15 @@ public sealed class MulticlassClassificationMetrics public double MicroAccuracy { get; } /// - /// If is positive, this is the relative number of examples where + /// If is positive, this is the relative number of examples where /// the true label is one of the top-k predicted labels by the predictor. /// public double TopKAccuracy { get; } /// - /// If positive, this is the top-K for which the is calculated. + /// If positive, this indicates the K in . /// - public int TopK { get; } + public int TopKPredictionCount { get; } /// /// Gets the log-loss of the classifier for each class. 
@@ -82,15 +82,15 @@ public sealed class MulticlassClassificationMetrics /// public IReadOnlyList PerClassLogLoss { get; } - internal MulticlassClassificationMetrics(IExceptionContext ectx, DataViewRow overallResult, int topK) + internal MulticlassClassificationMetrics(IExceptionContext ectx, DataViewRow overallResult, int topKPredictionCount) { double FetchDouble(string name) => RowCursorUtils.Fetch(ectx, overallResult, name); MicroAccuracy = FetchDouble(MulticlassClassificationEvaluator.AccuracyMicro); MacroAccuracy = FetchDouble(MulticlassClassificationEvaluator.AccuracyMacro); LogLoss = FetchDouble(MulticlassClassificationEvaluator.LogLoss); LogLossReduction = FetchDouble(MulticlassClassificationEvaluator.LogLossReduction); - TopK = topK; - if (topK > 0) + TopKPredictionCount = topKPredictionCount; + if (topKPredictionCount > 0) TopKAccuracy = FetchDouble(MulticlassClassificationEvaluator.TopKAccuracy); var perClassLogLoss = RowCursorUtils.Fetch>(ectx, overallResult, MulticlassClassificationEvaluator.PerClassLogLoss); @@ -98,13 +98,13 @@ internal MulticlassClassificationMetrics(IExceptionContext ectx, DataViewRow ove } internal MulticlassClassificationMetrics(double accuracyMicro, double accuracyMacro, double logLoss, double logLossReduction, - int topK, double topKAccuracy, double[] perClassLogLoss) + int topKPredictionCount, double topKAccuracy, double[] perClassLogLoss) { MicroAccuracy = accuracyMicro; MacroAccuracy = accuracyMacro; LogLoss = logLoss; LogLossReduction = logLossReduction; - TopK = topK; + TopKPredictionCount = topKPredictionCount; TopKAccuracy = topKAccuracy; PerClassLogLoss = perClassLogLoss.ToImmutableArray(); } diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs index 8cbf7b836a..dbff614330 100644 --- a/src/Microsoft.ML.Data/TrainCatalog.cs +++ b/src/Microsoft.ML.Data/TrainCatalog.cs @@ -484,21 +484,22 @@ internal MulticlassClassificationTrainers(MulticlassClassificationCatalog catalo /// The name of 
the label column in . /// The name of the score column in . /// The name of the predicted label column in . - /// If given a positive value, the will be filled with + /// If given a positive value, the will be filled with /// the top-K accuracy, that is, the accuracy assuming we consider an example with the correct class within /// the top-K values as being stored "correctly." /// The evaluation results for these calibrated outputs. public MulticlassClassificationMetrics Evaluate(IDataView data, string labelColumnName = DefaultColumnNames.Label, string scoreColumnName = DefaultColumnNames.Score, - string predictedLabelColumnName = DefaultColumnNames.PredictedLabel, int topK = 0) + string predictedLabelColumnName = DefaultColumnNames.PredictedLabel, int topKPredictionCount = 0) { Environment.CheckValue(data, nameof(data)); Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName)); Environment.CheckNonEmpty(scoreColumnName, nameof(scoreColumnName)); Environment.CheckNonEmpty(predictedLabelColumnName, nameof(predictedLabelColumnName)); + Environment.CheckUserArg(topKPredictionCount >= 0, nameof(topKPredictionCount), "Must be non-negative"); var args = new MulticlassClassificationEvaluator.Arguments() { }; - if (topK > 0) - args.OutputTopKAcc = topK; + if (topKPredictionCount > 0) + args.OutputTopKAcc = topKPredictionCount; var eval = new MulticlassClassificationEvaluator(Environment, args); return eval.Evaluate(data, labelColumnName, scoreColumnName, predictedLabelColumnName); } @@ -673,10 +674,10 @@ internal AnomalyDetectionTrainers(AnomalyDetectionCatalog catalog) /// The name of the label column in . /// The name of the score column in . /// The name of the predicted label column in . - /// The number of false positives to compute the metric. + /// The number of false positives to compute the metric. /// Evaluation results. 
public AnomalyDetectionMetrics Evaluate(IDataView data, string labelColumnName = DefaultColumnNames.Label, string scoreColumnName = DefaultColumnNames.Score, - string predictedLabelColumnName = DefaultColumnNames.PredictedLabel, int k = 10) + string predictedLabelColumnName = DefaultColumnNames.PredictedLabel, int falsePositiveCount = 10) { Environment.CheckValue(data, nameof(data)); Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName)); @@ -684,7 +685,7 @@ public AnomalyDetectionMetrics Evaluate(IDataView data, string labelColumnName = Environment.CheckNonEmpty(predictedLabelColumnName, nameof(predictedLabelColumnName)); var args = new AnomalyDetectionEvaluator.Arguments(); - args.K = k; + args.K = falsePositiveCount; var eval = new AnomalyDetectionEvaluator(Environment, args); return eval.Evaluate(data, labelColumnName, scoreColumnName, predictedLabelColumnName); diff --git a/src/Microsoft.ML.StaticPipe/EvaluatorStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/EvaluatorStaticExtensions.cs index 33cefe6e07..7c803f7d05 100644 --- a/src/Microsoft.ML.StaticPipe/EvaluatorStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/EvaluatorStaticExtensions.cs @@ -127,7 +127,7 @@ public static ClusteringMetrics Evaluate( /// The index delegate for the label column. /// The index delegate for columns from the prediction of a multiclass classifier. /// Under typical scenarios, this will just be the same tuple of results returned from the trainer. - /// If given a positive value, the will be filled with + /// If given a positive value, the will be filled with /// the top-K accuracy, that is, the accuracy assuming we consider an example with the correct class within /// the top-K values as being stored "correctly." /// The evaluation metrics. 
@@ -136,14 +136,14 @@ public static MulticlassClassificationMetrics Evaluate( DataView data, Func> label, Func score, Key predictedLabel)> pred, - int topK = 0) + int topKPredictionCount = 0) { Contracts.CheckValue(data, nameof(data)); var env = StaticPipeUtils.GetEnvironment(data); Contracts.AssertValue(env); env.CheckValue(label, nameof(label)); env.CheckValue(pred, nameof(pred)); - env.CheckParam(topK >= 0, nameof(topK), "Must not be negative."); + env.CheckParam(topKPredictionCount >= 0, nameof(topKPredictionCount), "Must not be negative."); var indexer = StaticPipeUtils.GetIndexer(data); string labelName = indexer.Get(label(indexer.Indices)); @@ -154,8 +154,8 @@ public static MulticlassClassificationMetrics Evaluate( string predName = indexer.Get(predCol); var args = new MulticlassClassificationEvaluator.Arguments() { }; - if (topK > 0) - args.OutputTopKAcc = topK; + if (topKPredictionCount > 0) + args.OutputTopKAcc = topKPredictionCount; var eval = new MulticlassClassificationEvaluator(env, args); return eval.Evaluate(data.AsDynamic, labelName, scoreName, predName); diff --git a/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs b/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs index e356769ac5..9e111dac5a 100644 --- a/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs +++ b/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs @@ -228,8 +228,8 @@ public static ImmutableArray private static MulticlassClassificationMetrics MulticlassClassificationDelta( MulticlassClassificationMetrics a, MulticlassClassificationMetrics b) { - if (a.TopK != b.TopK) - Contracts.Assert(a.TopK == b.TopK, "TopK to compare must be the same length."); + if (a.TopKPredictionCount != b.TopKPredictionCount) + Contracts.Assert(a.TopKPredictionCount == b.TopKPredictionCount, "TopK to compare must be the same length."); var perClassLogLoss = ComputeSequenceDeltas(a.PerClassLogLoss, b.PerClassLogLoss); @@ -238,7 
+238,7 @@ private static MulticlassClassificationMetrics MulticlassClassificationDelta( accuracyMacro: a.MacroAccuracy - b.MacroAccuracy, logLoss: a.LogLoss - b.LogLoss, logLossReduction: a.LogLossReduction - b.LogLossReduction, - topK: a.TopK, + topKPredictionCount: a.TopKPredictionCount, topKAccuracy: a.TopKAccuracy - b.TopKAccuracy, perClassLogLoss: perClassLogLoss ); diff --git a/test/Microsoft.ML.Tests/AnomalyDetectionTests.cs b/test/Microsoft.ML.Tests/AnomalyDetectionTests.cs index 10973ad109..7db783b145 100644 --- a/test/Microsoft.ML.Tests/AnomalyDetectionTests.cs +++ b/test/Microsoft.ML.Tests/AnomalyDetectionTests.cs @@ -30,7 +30,7 @@ public void RandomizedPcaTrainerBaselineTest() var transformedData = DetectAnomalyInMnistOneClass(trainPath, testPath); // Evaluate - var metrics = ML.AnomalyDetection.Evaluate(transformedData, k: 5); + var metrics = ML.AnomalyDetection.Evaluate(transformedData, falsePositiveCount: 5); Assert.Equal(0.98667, metrics.AreaUnderRocCurve, 5); Assert.Equal(0.90000, metrics.DetectionRateAtFalsePositiveCount, 5); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs index cd3221732f..062939aa6e 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs @@ -336,7 +336,7 @@ public void TestTrainTestSplit() // Let's do same thing, but this time we will choose different seed. // Stratification column should still break dataset properly without same values in both subsets. - var stratSeed = mlContext.Data.TrainTestSplit(input, samplingKeyColumnName:"Workclass", seed: 1000000); + var stratSeed = mlContext.Data.TrainTestSplit(input, samplingKeyColumnName: "Workclass", seed: 1000000); var stratTrainWithSeedWorkclass = getWorkclass(stratSeed.TrainSet); var stratTestWithSeedWorkClass = getWorkclass(stratSeed.TestSet); // Let's get unique values for "Workclass" column from train subset. 
diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs index e934aa694f..838729bbaf 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs @@ -84,7 +84,7 @@ public void TrainAndPredictIrisModelTest() // Evaluate the trained pipeline var predicted = trainedModel.Transform(testData); - var metrics = mlContext.MulticlassClassification.Evaluate(predicted, topK: 3); + var metrics = mlContext.MulticlassClassification.Evaluate(predicted, topKPredictionCount: 3); Assert.Equal(.98, metrics.MacroAccuracy); Assert.Equal(.98, metrics.MicroAccuracy, 2); diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs index 13a4d45db6..1093506819 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs @@ -87,7 +87,7 @@ public void TrainAndPredictIrisModelWithStringLabelTest() // Evaluate the trained pipeline var predicted = trainedModel.Transform(testData); - var metrics = mlContext.MulticlassClassification.Evaluate(predicted, topK: 3); + var metrics = mlContext.MulticlassClassification.Evaluate(predicted, topKPredictionCount: 3); Assert.Equal(.98, metrics.MacroAccuracy); Assert.Equal(.98, metrics.MicroAccuracy, 2); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs index 653805f130..66186109ba 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs @@ -158,7 +158,7 @@ public void SdcaMulticlassLogisticRegression() // Step 4: Make prediction and evaluate its quality (on training set). 
var prediction = model.Transform(data); - var metrics = mlContext.MulticlassClassification.Evaluate(prediction, labelColumnName: "LabelIndex", topK: 1); + var metrics = mlContext.MulticlassClassification.Evaluate(prediction, labelColumnName: "LabelIndex", topKPredictionCount: 1); // Check a few metrics to make sure the trained model is ok. Assert.InRange(metrics.TopKAccuracy, 0.8, 1); @@ -192,7 +192,7 @@ public void SdcaMulticlassSupportVectorMachine() // Step 4: Make prediction and evaluate its quality (on training set). var prediction = model.Transform(data); - var metrics = mlContext.MulticlassClassification.Evaluate(prediction, labelColumnName: "LabelIndex", topK: 1); + var metrics = mlContext.MulticlassClassification.Evaluate(prediction, labelColumnName: "LabelIndex", topKPredictionCount: 1); // Check a few metrics to make sure the trained model is ok. Assert.InRange(metrics.TopKAccuracy, 0.8, 1); From 62dda6f6b94c20d9d9517fa3da30a96e49817db2 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Thu, 21 Mar 2019 14:33:35 -0700 Subject: [PATCH 18/18] Add more checks for the syntax of the embedded TextLoader options --- src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs index 01c70b06da..14eb918779 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs @@ -1288,7 +1288,9 @@ private static bool TryParseSchema(IHost host, IMultiStreamSource files, ch.Assert(h.Loader == null || h.Loader is ICommandLineComponentFactory); var loader = h.Loader as ICommandLineComponentFactory; - if (loader == null || string.IsNullOrWhiteSpace(loader.Name)) + // Make sure that the schema is described using either the syntax TextLoader{} or the syntax Text{}, + // where "settings" is a string that can be parsed by CmdParser into an object of 
type TextLoader.Options. + if (loader == null || string.IsNullOrWhiteSpace(loader.Name) || (loader.Name != LoaderSignature && loader.Name != "Text")) goto LDone; var optionsNew = new Options();