From f233c901caba4e3e7eb3daaee1a978cd90236ee0 Mon Sep 17 00:00:00 2001 From: Rogan Carr Date: Tue, 12 Mar 2019 13:47:40 -0700 Subject: [PATCH 1/5] Adding an option to provide FM with an empty arguments list. --- .../FactorizationMachineCatalog.cs | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs b/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs index c0dbce1724..cfaf37e92b 100644 --- a/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs +++ b/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs @@ -13,6 +13,29 @@ namespace Microsoft.ML /// public static class FactorizationMachineExtensions { + /// + /// Predict a target using a field-aware factorization machine algorithm. + /// + /// The binary classification catalog trainer object. + /// The name(s) of the feature columns. + /// The name of the label column. + /// The name of the example weight column (optional). + /// + /// + /// + /// + public static FieldAwareFactorizationMachineBinaryClassificationTrainer FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, + string featureColumnName = DefaultColumnNames.Features, + string labelColumnName = DefaultColumnNames.Label, + string exampleWeightColumnName = null) + { + Contracts.CheckValue(catalog, nameof(catalog)); + var env = CatalogUtils.GetEnvironment(catalog); + return new FieldAwareFactorizationMachineBinaryClassificationTrainer(env, new string[] { featureColumnName }, labelColumnName, exampleWeightColumnName); + } + /// /// Predict a target using a field-aware factorization machine algorithm. 
/// From dbcd0f0fe9f2bbc288da340bfd7ffcdb97b6bb6a Mon Sep 17 00:00:00 2001 From: Rogan Carr Date: Tue, 12 Mar 2019 13:50:25 -0700 Subject: [PATCH 2/5] Update the doc strings --- .../FactorizationMachine/FactorizationMachineCatalog.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs b/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs index cfaf37e92b..0d327b243c 100644 --- a/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs +++ b/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs @@ -40,7 +40,7 @@ public static FieldAwareFactorizationMachineBinaryClassificationTrainer FieldAwa /// Predict a target using a field-aware factorization machine algorithm. /// /// The binary classification catalog trainer object. - /// The name(s) of the feature columns. + /// The name of the feature column. /// The name of the label column. /// The name of the example weight column (optional). /// From 2b12747fe33999242c1374a2fbd55c794bbdd55e Mon Sep 17 00:00:00 2001 From: Rogan Carr Date: Tue, 12 Mar 2019 13:51:11 -0700 Subject: [PATCH 3/5] Update the doc strings --- .../FactorizationMachine/FactorizationMachineCatalog.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs b/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs index 0d327b243c..50de61c7ae 100644 --- a/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs +++ b/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs @@ -17,7 +17,7 @@ public static class FactorizationMachineExtensions /// Predict a target using a field-aware factorization machine algorithm. /// /// The binary classification catalog trainer object. 
- /// The name(s) of the feature columns. + /// The name of the feature column. /// The name of the label column. /// The name of the example weight column (optional). /// @@ -40,7 +40,7 @@ public static FieldAwareFactorizationMachineBinaryClassificationTrainer FieldAwa /// Predict a target using a field-aware factorization machine algorithm. /// /// The binary classification catalog trainer object. - /// The name of the feature column. + /// The name(s) of the feature columns. /// The name of the label column. /// The name of the example weight column (optional). /// From 5fd2e6c98f06e1608b9fff398fe16740461141ec Mon Sep 17 00:00:00 2001 From: Rogan Carr Date: Wed, 13 Mar 2019 13:16:07 -0700 Subject: [PATCH 4/5] Addressing PR Comments --- ...areFactorizationMachineWithoutArguments.cs | 80 +++++++++++++++++++ .../FactorizationMachineCatalog.cs | 5 +- .../TrainerEstimators/FAFMEstimator.cs | 21 +++++ 3 files changed, 105 insertions(+), 1 deletion(-) create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FieldAwareFactorizationMachineWithoutArguments.cs diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FieldAwareFactorizationMachineWithoutArguments.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FieldAwareFactorizationMachineWithoutArguments.cs new file mode 100644 index 0000000000..1f4d6bd5be --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FieldAwareFactorizationMachineWithoutArguments.cs @@ -0,0 +1,80 @@ +using System; +using System.Linq; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Samples.Dynamic +{ + public static class FFMBinaryClassificationWithoutArguments + { + public static void Example() + { + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. 
+ var mlContext = new MLContext(); + + // Download and featurize the dataset. + var dataviews = SamplesUtils.DatasetUtils.LoadFeaturizedSentimentDataset(mlContext); + var trainData = dataviews[0]; + var testData = dataviews[1]; + + // ML.NET doesn't cache data sets by default. Therefore, if one reads a data set from a file and accesses it many times, it can be slow due to + // expensive featurization and disk operations. When the considered data can fit into memory, a solution is to cache the data in memory. Caching is especially + // helpful when working with iterative algorithms that need many data passes. Since the field-aware factorization machine trainer is such an algorithm, we cache. Inserting a + // cache step in a pipeline is also possible; please see the construction of the pipeline below. + trainData = mlContext.Data.Cache(trainData); + + // Step 2: Pipeline + // Create the 'FieldAwareFactorizationMachine' binary classifier, setting the "Sentiment" column as the label of the dataset, and + // the "Features" column as the features column. + var pipeline = mlContext.Transforms.CopyColumns("Label", "Sentiment") + .AppendCacheCheckpoint(mlContext) + .Append(mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine()); + + // Fit the model. + var model = pipeline.Fit(trainData); + + // Let's get the model parameters from the model. + var modelParams = model.LastTransformer.Model; + + // Let's inspect the model parameters. 
+ var featureCount = modelParams.FeatureCount; + var fieldCount = modelParams.FieldCount; + var latentDim = modelParams.LatentDimension; + var linearWeights = modelParams.GetLinearWeights(); + var latentWeights = modelParams.GetLatentWeights(); + + Console.WriteLine("The feature count is: " + featureCount); + Console.WriteLine("The number of fields is: " + fieldCount); + Console.WriteLine("The latent dimension is: " + latentDim); + Console.WriteLine("The linear weights of some of the features are: " + + string.Concat(Enumerable.Range(1, 10).Select(i => $"{linearWeights[i]:F4} "))); + Console.WriteLine("The weights of some of the latent features are: " + + string.Concat(Enumerable.Range(1, 10).Select(i => $"{latentWeights[i]:F4} "))); + + // Expected Output: + // The feature count is: 9374 + // The number of fields is: 1 + // The latent dimension is: 20 + // The linear weights of some of the features are: 0.0188 0.0000 -0.0048 -0.0184 0.0000 0.0031 0.0914 0.0112 -0.0152 0.0110 + // The weights of some of the latent features are: 0.0631 0.0041 -0.0333 0.0694 0.1330 0.0790 0.1168 -0.0848 0.0431 0.0411 + + // Evaluate how the model is doing on the test data. 
+ var dataWithPredictions = model.Transform(testData); + + var metrics = mlContext.BinaryClassification.Evaluate(dataWithPredictions, "Sentiment"); + SamplesUtils.ConsoleUtils.PrintMetrics(metrics); + + // Expected output: + // Accuracy: 0.61 + // AUC: 0.72 + // F1 Score: 0.59 + // Negative Precision: 0.60 + // Negative Recall: 0.67 + // Positive Precision: 0.63 + // Positive Recall: 0.56 + // Log Loss: 1.21 + // Log Loss Reduction: -21.20 + // Entropy: 1.00 + } + } +} diff --git a/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs b/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs index 50de61c7ae..7bfbe181b3 100644 --- a/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs +++ b/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs @@ -16,6 +16,9 @@ public static class FactorizationMachineExtensions /// /// Predict a target using a field-aware factorization machine algorithm. /// + /// + /// Note that because there is only one feature column, the underlying model is equivalent to a standard factorization machine. + /// /// The binary classification catalog trainer object. /// The name of the feature column. /// The name of the label column. 
@@ -23,7 +26,7 @@ public static class FactorizationMachineExtensions /// /// /// /// public static FieldAwareFactorizationMachineBinaryClassificationTrainer FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs index 2ea971e096..e55e80be41 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs @@ -12,6 +12,27 @@ namespace Microsoft.ML.Tests.TrainerEstimators { public partial class TrainerEstimators : TestDataPipeBase { + [Fact] + public void FfmBinaryClassificationWithoutArguments() + { + var mlContext = new MLContext(seed: 0); + var data = DatasetUtils.GenerateFfmSamples(500); + var dataView = mlContext.Data.LoadFromEnumerable(data); + + var pipeline = mlContext.Transforms.CopyColumns(DefaultColumnNames.Features, nameof(DatasetUtils.FfmExample.Field0)) + .Append(mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine()); + + var model = pipeline.Fit(dataView); + var prediction = model.Transform(dataView); + + var metrics = mlContext.BinaryClassification.Evaluate(prediction); + + // Run a sanity check against a few of the metrics. + Assert.InRange(metrics.Accuracy, 0.6, 1); + Assert.InRange(metrics.AreaUnderRocCurve, 0.7, 1); + Assert.InRange(metrics.AreaUnderPrecisionRecallCurve, 0.65, 1); + } + [Fact] public void FfmBinaryClassificationWithAdvancedArguments() { From 14a46f0dbb286ed591f1678fa8107be1a284e10b Mon Sep 17 00:00:00 2001 From: Rogan Carr Date: Wed, 13 Mar 2019 14:07:33 -0700 Subject: [PATCH 5/5] Fixing breaking changes in master. 
--- .../FactorizationMachine/FactorizationMachineCatalog.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs b/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs index e06769753f..4b797e4221 100644 --- a/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs +++ b/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs @@ -29,14 +29,14 @@ public static class FactorizationMachineExtensions /// [!code-csharp[FieldAwareFactorizationMachineWithoutArguments](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FieldAwareFactorizationMachineWithoutArguments.cs)] /// ]]> /// - public static FieldAwareFactorizationMachineBinaryClassificationTrainer FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, + public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, string featureColumnName = DefaultColumnNames.Features, string labelColumnName = DefaultColumnNames.Label, string exampleWeightColumnName = null) { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); - return new FieldAwareFactorizationMachineBinaryClassificationTrainer(env, new string[] { featureColumnName }, labelColumnName, exampleWeightColumnName); + return new FieldAwareFactorizationMachineTrainer(env, new string[] { featureColumnName }, labelColumnName, exampleWeightColumnName); } ///