diff --git a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs index f99fefe378..701864c577 100644 --- a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs @@ -23,315 +23,35 @@ public partial class ScenariosTests [Fact] public void TrainAndPredictSentimentModelTest() { - string dataPath = GetDataPath(SentimentDataPath); - var pipeline = new LearningPipeline(); - - pipeline.Add(new Data.TextLoader(dataPath) - { - Arguments = new TextLoaderArguments - { - Separator = new[] { '\t' }, - HasHeader = true, - Column = new[] - { - new TextLoaderColumn() - { - Name = "Label", - Source = new [] { new TextLoaderRange(0) }, - Type = Runtime.Data.DataKind.Num - }, - - new TextLoaderColumn() - { - Name = "SentimentText", - Source = new [] { new TextLoaderRange(1) }, - Type = Runtime.Data.DataKind.Text - } - } - } - }); - - pipeline.Add(new TextFeaturizer("Features", "SentimentText") - { - KeepDiacritics = false, - KeepPunctuations = false, - TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, - OutputTokens = true, - StopWordsRemover = new PredefinedStopWordsRemover(), - VectorNormalizer = TextTransformTextNormKind.L2, - CharFeatureExtractor = new NGramNgramExtractor() { NgramLength = 3, AllLengths = false }, - WordFeatureExtractor = new NGramNgramExtractor() { NgramLength = 2, AllLengths = true } - }); - - pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 }); - pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); - - PredictionModel model = pipeline.Train(); - IEnumerable sentiments = new[] - { - new SentimentData - { - SentimentText = "Please refrain from adding nonsense to Wikipedia." - }, - new SentimentData - { - SentimentText = "He is a CHEATER, and the article should say that." - } - }; - - IEnumerable predictions = model.Predict(sentiments); - - Assert.Equal(2, predictions.Count()); - Assert.True(predictions.ElementAt(0).Sentiment.IsFalse); - Assert.True(predictions.ElementAt(1).Sentiment.IsTrue); - - string testDataPath = GetDataPath(SentimentTestPath); - var testData = new Data.TextLoader(testDataPath) - { - Arguments = new TextLoaderArguments - { - Separator = new[] { '\t' }, - HasHeader = true, - Column = new[] - { - new TextLoaderColumn() - { - Name = "Label", - Source = new [] { new TextLoaderRange(0) }, - Type = Runtime.Data.DataKind.Num - }, - - new TextLoaderColumn() - { - Name = "SentimentText", - Source = new [] { new TextLoaderRange(1) }, - Type = Runtime.Data.DataKind.Text - } - } - } - }; - + var pipeline = PreparePipeline(); + var model = pipeline.Train(); + var testData = PrepareTextLoaderTestData(); var evaluator = new BinaryClassificationEvaluator(); BinaryClassificationMetrics metrics = evaluator.Evaluate(model, testData); - - Assert.Equal(.5556, metrics.Accuracy, 4); - Assert.Equal(.8, metrics.Auc, 1); - Assert.Equal(.87, metrics.Auprc, 2); - Assert.Equal(1, metrics.Entropy, 3); - Assert.Equal(.6923, metrics.F1Score, 4); - Assert.Equal(.969, metrics.LogLoss, 3); - Assert.Equal(3.083, metrics.LogLossReduction, 3); - Assert.Equal(1, metrics.NegativePrecision, 3); - Assert.Equal(.111, metrics.NegativeRecall, 3); - Assert.Equal(.529, metrics.PositivePrecision, 3); - Assert.Equal(1, metrics.PositiveRecall); - - ConfusionMatrix matrix = metrics.ConfusionMatrix; - Assert.Equal(2, matrix.Order); - Assert.Equal(2, matrix.ClassNames.Count); - Assert.Equal("positive", matrix.ClassNames[0]); - Assert.Equal("negative", matrix.ClassNames[1]); - - Assert.Equal(9, matrix[0, 0]); - Assert.Equal(9, matrix["positive", "positive"]); - Assert.Equal(0, matrix[0, 1]); - Assert.Equal(0, matrix["positive", "negative"]); - - Assert.Equal(8, matrix[1, 0]); - Assert.Equal(8, matrix["negative", "positive"]); - Assert.Equal(1, matrix[1, 1]); - Assert.Equal(1, matrix["negative", "negative"]); + ValidateExamples(model); + ValidateBinaryMetrics(metrics); } [Fact] public void TrainTestPredictSentimentModelTest() { - string dataPath = GetDataPath(SentimentDataPath); - var pipeline = new LearningPipeline(); - - pipeline.Add(new Data.TextLoader(dataPath) - { - Arguments = new TextLoaderArguments - { - Separator = new[] { '\t' }, - HasHeader = true, - Column = new[] - { - new TextLoaderColumn() - { - Name = "Label", - Source = new [] { new TextLoaderRange(0) }, - Type = Runtime.Data.DataKind.Num - }, - - new TextLoaderColumn() - { - Name = "SentimentText", - Source = new [] { new TextLoaderRange(1) }, - Type = Runtime.Data.DataKind.Text - } - } - } - }); - - pipeline.Add(new TextFeaturizer("Features", "SentimentText") - { - KeepDiacritics = false, - KeepPunctuations = false, - TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, - OutputTokens = true, - StopWordsRemover = new PredefinedStopWordsRemover(), - VectorNormalizer = TextTransformTextNormKind.L2, - CharFeatureExtractor = new NGramNgramExtractor() { NgramLength = 3, AllLengths = false }, - WordFeatureExtractor = new NGramNgramExtractor() { NgramLength = 2, AllLengths = true } - }); - - pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 }); - pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); - + var pipeline = PreparePipeline(); PredictionModel model = pipeline.Train(); - IEnumerable sentiments = new[] - { - new SentimentData - { - SentimentText = "Please refrain from adding nonsense to Wikipedia." - }, - new SentimentData - { - SentimentText = "He is a CHEATER, and the article should say that." - } - }; - - string testDataPath = GetDataPath(SentimentTestPath); - var testData = new Data.TextLoader(testDataPath) - { - Arguments = new TextLoaderArguments - { - Separator = new[] { '\t' }, - HasHeader = true, - Column = new[] - { - new TextLoaderColumn() - { - Name = "Label", - Source = new [] { new TextLoaderRange(0) }, - Type = Runtime.Data.DataKind.Num - }, - - new TextLoaderColumn() - { - Name = "SentimentText", - Source = new [] { new TextLoaderRange(1) }, - Type = Runtime.Data.DataKind.Text - } - } - } - }; - + var testData = PrepareTextLoaderTestData(); var tt = new TrainTestEvaluator().TrainTestEvaluate(pipeline, testData); Assert.Null(tt.ClassificationMetrics); Assert.Null(tt.RegressionMetrics); Assert.NotNull(tt.BinaryClassificationMetrics); Assert.NotNull(tt.PredictorModels); - - BinaryClassificationMetrics metrics = tt.BinaryClassificationMetrics; - Assert.Equal(.5556, metrics.Accuracy, 4); - Assert.Equal(.8, metrics.Auc, 1); - Assert.Equal(.87, metrics.Auprc, 2); - Assert.Equal(1, metrics.Entropy, 3); - Assert.Equal(.6923, metrics.F1Score, 4); - Assert.Equal(.969, metrics.LogLoss, 3); - Assert.Equal(3.083, metrics.LogLossReduction, 3); - Assert.Equal(1, metrics.NegativePrecision, 3); - Assert.Equal(.111, metrics.NegativeRecall, 3); - Assert.Equal(.529, metrics.PositivePrecision, 3); - Assert.Equal(1, metrics.PositiveRecall); - - ConfusionMatrix matrix = metrics.ConfusionMatrix; - Assert.Equal(2, matrix.Order); - Assert.Equal(2, matrix.ClassNames.Count); - Assert.Equal("positive", matrix.ClassNames[0]); - Assert.Equal("negative", matrix.ClassNames[1]); - - Assert.Equal(9, matrix[0, 0]); - Assert.Equal(9, matrix["positive", "positive"]); - Assert.Equal(0, matrix[0, 1]); - Assert.Equal(0, matrix["positive", "negative"]); - - Assert.Equal(8, matrix[1, 0]); - Assert.Equal(8, matrix["negative", "positive"]); - Assert.Equal(1, matrix[1, 1]); - Assert.Equal(1, matrix["negative", "negative"]); - - IEnumerable predictions = tt.PredictorModels.Predict(sentiments); - Assert.Equal(2, predictions.Count()); - Assert.True(predictions.ElementAt(0).Sentiment.IsFalse); - Assert.True(predictions.ElementAt(1).Sentiment.IsTrue); - - predictions = tt.PredictorModels.Predict(sentiments); - Assert.Equal(2, predictions.Count()); - Assert.True(predictions.ElementAt(0).Sentiment.IsFalse); - Assert.True(predictions.ElementAt(1).Sentiment.IsTrue); + ValidateExamples(tt.PredictorModels); + ValidateBinaryMetrics(tt.BinaryClassificationMetrics); } [Fact] public void CrossValidateSentimentModelTest() { - string dataPath = GetDataPath(SentimentDataPath); - var pipeline = new LearningPipeline(); - - pipeline.Add(new Data.TextLoader(dataPath) - { - Arguments = new TextLoaderArguments - { - Separator = new[] { '\t' }, - HasHeader = true, - Column = new[] - { - new TextLoaderColumn() - { - Name = "Label", - Source = new [] { new TextLoaderRange(0) }, - Type = Runtime.Data.DataKind.Num - }, - - new TextLoaderColumn() - { - Name = "SentimentText", - Source = new [] { new TextLoaderRange(1) }, - Type = Runtime.Data.DataKind.Text - } - } - } - }); - - pipeline.Add(new TextFeaturizer("Features", "SentimentText") - { - KeepDiacritics = false, - KeepPunctuations = false, - TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, - OutputTokens = true, - StopWordsRemover = new PredefinedStopWordsRemover(), - VectorNormalizer = TextTransformTextNormKind.L2, - CharFeatureExtractor = new NGramNgramExtractor() { NgramLength = 3, AllLengths = false }, - WordFeatureExtractor = new NGramNgramExtractor() { NgramLength = 2, AllLengths = true } - }); - - pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 }); - pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); - - IEnumerable sentiments = new[] - { - new SentimentData - { - SentimentText = "Please refrain from adding nonsense to Wikipedia." - }, - new SentimentData - { - SentimentText = "He is a CHEATER, and the article should say that." - } - }; + var pipeline = PreparePipeline(); var cv = new CrossValidator().CrossValidate(pipeline); @@ -343,7 +63,7 @@ public void CrossValidateSentimentModelTest() Assert.Equal(4, cv.BinaryClassificationMetrics.Count()); //Avergae of all folds. - BinaryClassificationMetrics metrics = cv.BinaryClassificationMetrics[0]; + var metrics = cv.BinaryClassificationMetrics[0]; Assert.Equal(0.57023626091422708, metrics.Accuracy, 4); Assert.Equal(0.54960689910161487, metrics.Auc, 1); Assert.Equal(0.67048277219704255, metrics.Auprc, 2); @@ -386,7 +106,7 @@ public void CrossValidateSentimentModelTest() Assert.Equal(0.58252427184466016, metrics.PositivePrecision, 3); Assert.Equal(0.759493670886076, metrics.PositiveRecall); - ConfusionMatrix matrix = metrics.ConfusionMatrix; + var matrix = metrics.ConfusionMatrix; Assert.Equal(2, matrix.Order); Assert.Equal(2, matrix.ClassNames.Count); Assert.Equal("positive", matrix.ClassNames[0]); @@ -432,7 +152,8 @@ public void CrossValidateSentimentModelTest() Assert.Equal(13, matrix[1, 1]); Assert.Equal(13, matrix["negative", "negative"]); - IEnumerable predictions = cv.PredictorModels[0].Predict(sentiments); + var sentiments = GetTestData(); + var predictions = cv.PredictorModels[0].Predict(sentiments); Assert.Equal(2, predictions.Count()); Assert.True(predictions.ElementAt(0).Sentiment.IsTrue); Assert.True(predictions.ElementAt(1).Sentiment.IsTrue); @@ -443,6 +164,138 @@ public void CrossValidateSentimentModelTest() Assert.True(predictions.ElementAt(1).Sentiment.IsTrue); } + private void ValidateBinaryMetrics(BinaryClassificationMetrics metrics) + { + Assert.Equal(.5556, metrics.Accuracy, 4); + Assert.Equal(.8, metrics.Auc, 1); + Assert.Equal(.87, metrics.Auprc, 2); + Assert.Equal(1, metrics.Entropy, 3); + Assert.Equal(.6923, metrics.F1Score, 4); + Assert.Equal(.969, metrics.LogLoss, 3); + Assert.Equal(3.083, metrics.LogLossReduction, 3); + Assert.Equal(1, metrics.NegativePrecision, 3); + Assert.Equal(.111, metrics.NegativeRecall, 3); + Assert.Equal(.529, metrics.PositivePrecision, 3); + Assert.Equal(1, metrics.PositiveRecall); + + var matrix = metrics.ConfusionMatrix; + Assert.Equal(2, matrix.Order); + Assert.Equal(2, matrix.ClassNames.Count); + Assert.Equal("positive", matrix.ClassNames[0]); + Assert.Equal("negative", matrix.ClassNames[1]); + + Assert.Equal(9, matrix[0, 0]); + Assert.Equal(9, matrix["positive", "positive"]); + Assert.Equal(0, matrix[0, 1]); + Assert.Equal(0, matrix["positive", "negative"]); + + Assert.Equal(8, matrix[1, 0]); + Assert.Equal(8, matrix["negative", "positive"]); + Assert.Equal(1, matrix[1, 1]); + Assert.Equal(1, matrix["negative", "negative"]); + } + + private LearningPipeline PreparePipeline() + { + var dataPath = GetDataPath(SentimentDataPath); + var pipeline = new LearningPipeline(); + + pipeline.Add(new Data.TextLoader(dataPath) + { + Arguments = new TextLoaderArguments + { + Separator = new[] { '\t' }, + HasHeader = true, + Column = new[] + { + new TextLoaderColumn() + { + Name = "Label", + Source = new [] { new TextLoaderRange(0) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "SentimentText", + Source = new [] { new TextLoaderRange(1) }, + Type = Runtime.Data.DataKind.Text + } + } + } + }); + + pipeline.Add(new TextFeaturizer("Features", "SentimentText") + { + KeepDiacritics = false, + KeepPunctuations = false, + TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, + OutputTokens = true, + StopWordsRemover = new PredefinedStopWordsRemover(), + VectorNormalizer = TextTransformTextNormKind.L2, + CharFeatureExtractor = new NGramNgramExtractor() { NgramLength = 3, AllLengths = false }, + WordFeatureExtractor = new NGramNgramExtractor() { NgramLength = 2, AllLengths = true } + }); + + pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 }); + pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); + return pipeline; + } + + private void ValidateExamples(PredictionModel model) + { + var sentiments = GetTestData(); + var predictions = model.Predict(sentiments); + Assert.Equal(2, predictions.Count()); + Assert.True(predictions.ElementAt(0).Sentiment.IsFalse); + Assert.True(predictions.ElementAt(1).Sentiment.IsTrue); + } + + private Data.TextLoader PrepareTextLoaderTestData() + { + var testDataPath = GetDataPath(SentimentTestPath); + var testData = new Data.TextLoader(testDataPath) + { + Arguments = new TextLoaderArguments + { + Separator = new[] { '\t' }, + HasHeader = true, + Column = new[] + { + new TextLoaderColumn() + { + Name = "Label", + Source = new [] { new TextLoaderRange(0) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "SentimentText", + Source = new [] { new TextLoaderRange(1) }, + Type = Runtime.Data.DataKind.Text + } + } + } + }; + return testData; + } + + private IEnumerable GetTestData() + { + return new[] + { + new SentimentData + { + SentimentText = "Please refrain from adding nonsense to Wikipedia." + }, + new SentimentData + { + SentimentText = "He is a CHEATER, and the article should say that." + } + }; + } + public class SentimentData { [Column(ordinal: "0", name: "Label")]