From ab146fc4cf6bcc31c1b926bbd5e56719a4b1fe1d Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Mon, 8 Apr 2019 16:27:41 -0700 Subject: [PATCH 1/9] Exposing the confusion matrix --- .../Evaluators/BinaryClassifierEvaluator.cs | 14 ++- .../Evaluators/EvaluatorUtils.cs | 89 ++++++++++++------- .../Metrics/BinaryClassificationMetrics.cs | 8 +- .../CalibratedBinaryClassificationMetrics.cs | 6 +- .../Evaluators/Metrics/ConfusionMatrix.cs | 49 ++++++++++ .../MulticlassClassificationMetrics.cs | 10 ++- .../MulticlassClassificationEvaluator.cs | 3 +- 7 files changed, 133 insertions(+), 46 deletions(-) create mode 100644 src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs diff --git a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs index 7fb2d9ea22..4ac5d0e773 100644 --- a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs @@ -815,16 +815,18 @@ public CalibratedBinaryClassificationMetrics Evaluate(IDataView data, string lab var resultDict = ((IEvaluator)this).Evaluate(roles); Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics)); var overall = resultDict[MetricKinds.OverallMetrics]; + var confusionMatrix = resultDict[MetricKinds.ConfusionMatrix]; CalibratedBinaryClassificationMetrics result; using (var cursor = overall.GetRowCursorForAllColumns()) { var moved = cursor.MoveNext(); Host.Assert(moved); - result = new CalibratedBinaryClassificationMetrics(Host, cursor); + result = new CalibratedBinaryClassificationMetrics(Host, cursor, confusionMatrix); moved = cursor.MoveNext(); Host.Assert(!moved); } + return result; } @@ -879,13 +881,14 @@ public CalibratedBinaryClassificationMetrics EvaluateWithPRCurve( } } prCurve = prCurveResult; + var confusionMatrix = resultDict[MetricKinds.ConfusionMatrix]; CalibratedBinaryClassificationMetrics result; using (var cursor = overall.GetRowCursorForAllColumns()) { var moved = cursor.MoveNext(); Host.Assert(moved); - result = new CalibratedBinaryClassificationMetrics(Host, cursor); + result = new CalibratedBinaryClassificationMetrics(Host, cursor, confusionMatrix); moved = cursor.MoveNext(); Host.Assert(!moved); } @@ -939,16 +942,18 @@ public BinaryClassificationMetrics Evaluate(IDataView data, string label, string var resultDict = ((IEvaluator)this).Evaluate(roles); Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics)); var overall = resultDict[MetricKinds.OverallMetrics]; + var confusionMatrix = resultDict[MetricKinds.ConfusionMatrix]; BinaryClassificationMetrics result; using (var cursor = overall.GetRowCursorForAllColumns()) { var moved = cursor.MoveNext(); Host.Assert(moved); - result = new BinaryClassificationMetrics(Host, cursor); + result = new BinaryClassificationMetrics(Host, cursor, confusionMatrix); moved = cursor.MoveNext(); Host.Assert(!moved); } + return result; } @@ -985,6 +990,7 @@ public BinaryClassificationMetrics EvaluateWithPRCurve( var prCurveView = resultDict[MetricKinds.PrCurve]; Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics)); var overall = resultDict[MetricKinds.OverallMetrics]; + var confusionMatrix = resultDict[MetricKinds.ConfusionMatrix]; var prCurveResult = new List(); using (var cursor = prCurveView.GetRowCursorForAllColumns()) @@ -1007,7 +1013,7 @@ public BinaryClassificationMetrics EvaluateWithPRCurve( { var moved = cursor.MoveNext(); Host.Assert(moved); - result = new BinaryClassificationMetrics(Host, cursor); + result = new BinaryClassificationMetrics(Host, cursor, confusionMatrix); moved = cursor.MoveNext(); Host.Assert(!moved); } diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs index d3d36d737d..cf5d4b453a 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs @@ -1353,14 +1353,43 @@ public static string GetConfusionTable(IHost host, IDataView confusionDataView, host.CheckValue(confusionDataView, nameof(confusionDataView)); host.CheckParam(sample == -1 || sample >= 2, nameof(sample), "Should be -1 to indicate no sampling, or at least 2"); - // Get the class names. - int countCol; - host.Check(confusionDataView.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.Count, out countCol), "Did not find the count column"); - var type = confusionDataView.Schema[countCol].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type as VectorDataViewType; + var weightColumn = confusionDataView.Schema.GetColumnOrNull(MetricKinds.ColumnNames.Weight); + bool isWeighted = weightColumn.HasValue; + + var confusionMatrix = GetConfusionTableAsType(host, confusionDataView, binary, sample, false); + var confusionTableString = GetConfusionTableAsString(confusionMatrix, false); + + // if there is a Weight column, return the weighted confusionMatrix as well, from this function. + if (isWeighted) + { + confusionMatrix = GetConfusionTableAsType(host, confusionDataView, binary, sample, false); + weightedConfusionTable = GetConfusionTableAsString(confusionMatrix, true); + } + else + weightedConfusionTable = null; + + return confusionTableString; + } + + public static ConfusionMatrix GetConfusionTableAsType(IHost host, IDataView confusionDataView, bool binary = true, int sample = -1, bool getWeighted = false) + { + host.CheckValue(confusionDataView, nameof(confusionDataView)); + host.CheckParam(sample == -1 || sample >= 2, nameof(sample), "Should be -1 to indicate no sampling, or at least 2"); + + // check that there is a Weight column, if isWeighted parameter is set to true. + var weightColumn = confusionDataView.Schema.GetColumnOrNull(MetricKinds.ColumnNames.Weight); + if (getWeighted) + host.CheckParam(weightColumn.HasValue, nameof(getWeighted), "There is no Weight column in the confusionMatrix data view."); + + // Get the counts names. + var countColumn = confusionDataView.Schema.GetColumnOrNull(MetricKinds.ColumnNames.Count); + host.Check(countColumn.HasValue, "Did not find the count column"); + var type = countColumn.Value.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type as VectorDataViewType; host.Check(type != null && type.IsKnownSize && type.ItemType is TextDataViewType, "The Count column does not have a text vector metadata of kind SlotNames."); + // Get the class names var labelNames = default(VBuffer>); - confusionDataView.Schema[countCol].Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref labelNames); + countColumn.Value.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref labelNames); host.Check(labelNames.IsDense, "Slot names vector must be dense"); int numConfusionTableLabels = sample < 0 ? labelNames.Length : Math.Min(labelNames.Length, sample); @@ -1387,27 +1416,20 @@ public static string GetConfusionTable(IHost host, IDataView confusionDataView, double[] precisionSums; double[] recallSums; - var confusionTable = GetConfusionTableAsArray(confusionDataView, countCol, labelNames.Length, + double[][] confusionTable; + + if(getWeighted) + confusionTable = GetConfusionTableAsArray(confusionDataView, weightColumn.Value.Index, labelNames.Length, + labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums); + else + confusionTable = GetConfusionTableAsArray(confusionDataView, countColumn.Value.Index, labelNames.Length, labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums); var predictedLabelNames = GetPredictedLabelNames(in labelNames, labelIndexToConfIndexMap); - var confusionTableString = GetConfusionTableAsString(confusionTable, recallSums, precisionSums, - predictedLabelNames, - sampled: numConfusionTableLabels < labelNames.Length, binary: binary); + bool sampled = numConfusionTableLabels < labelNames.Length; - int weightIndex; - if (confusionDataView.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.Weight, out weightIndex)) - { - confusionTable = GetConfusionTableAsArray(confusionDataView, weightIndex, labelNames.Length, - labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums); - weightedConfusionTable = GetConfusionTableAsString(confusionTable, recallSums, precisionSums, - predictedLabelNames, - sampled: numConfusionTableLabels < labelNames.Length, prefix: "Weighted ", binary: binary); - } - else - weightedConfusionTable = null; + return new ConfusionMatrix(host, precisionSums, recallSums, confusionTable, predictedLabelNames, sampled, binary); - return confusionTableString; } private static List> GetPredictedLabelNames(in VBuffer> labelNames, int[] labelIndexToConfIndexMap) @@ -1553,13 +1575,13 @@ private static string GetFoldMetricsAsString(IHostEnvironment env, IDataView dat } // Get a string representation of a confusion table. - private static string GetConfusionTableAsString(double[][] confusionTable, double[] rowSums, double[] columnSums, - List> predictedLabelNames, string prefix = "", bool sampled = false, bool binary = true) + internal static string GetConfusionTableAsString(ConfusionMatrix confusionMatrix, bool isWeighted) { - int numLabels = Utils.Size(confusionTable); + string prefix = isWeighted ? "Weighted " : ""; + int numLabels = Utils.Size(confusionMatrix.ConfusionTableCounts); int colWidth = numLabels == 2 ? 8 : 5; - int maxNameLen = predictedLabelNames.Max(name => name.Length); + int maxNameLen = confusionMatrix.PredictedLabelNames.Max(name => name.Length); // If the names are too long to fit in the column header, we back off to using class indices // in the header. This will also require putting the indices in the row, but it's better than // the alternative of having ambiguous abbreviated column headers, or having a table potentially @@ -1572,7 +1594,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl { // The row label will also include the index, so a user can easily match against the header. // In such a case, a label like "Foo" would be presented as something like "5. Foo". - rowDigitLen = Math.Max(predictedLabelNames.Count - 1, 0).ToString().Length; + rowDigitLen = Math.Max(confusionMatrix.PredictedLabelNames.Count - 1, 0).ToString().Length; Contracts.Assert(rowDigitLen >= 1); rowLabelLen += rowDigitLen + 2; } @@ -1591,10 +1613,11 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl else rowLabelFormat = string.Format("{{1,{0}}} ||", paddingLen); + var confusionTable = confusionMatrix.ConfusionTableCounts; var sb = new StringBuilder(); - if (numLabels == 2 && binary) + if (numLabels == 2 && confusionMatrix.Binary) { - var positiveCaps = predictedLabelNames[0].ToString().ToUpper(); + var positiveCaps = confusionMatrix.PredictedLabelNames[0].ToString().ToUpper(); var numTruePos = confusionTable[0][0]; var numFalseNeg = confusionTable[0][1]; @@ -1607,7 +1630,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl sb.AppendLine(); sb.AppendFormat("{0}Confusion table", prefix); - if (sampled) + if (confusionMatrix.Sampled) sb.AppendLine(" (sampled)"); else sb.AppendLine(); @@ -1619,7 +1642,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl sb.AppendFormat("PREDICTED {0}||", pad); string format = string.Format(" {{{0},{1}}} |", useNumbersInHeader ? 0 : 1, colWidth); for (int i = 0; i < numLabels; i++) - sb.AppendFormat(format, i, predictedLabelNames[i]); + sb.AppendFormat(format, i, confusionMatrix.PredictedLabelNames[i]); sb.AppendLine(" Recall"); sb.AppendFormat("TRUTH {0}||", pad); for (int i = 0; i < numLabels; i++) @@ -1631,10 +1654,10 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl string.IsNullOrWhiteSpace(prefix) ? "N0" : "F1"); for (int i = 0; i < numLabels; i++) { - sb.AppendFormat(rowLabelFormat, i, predictedLabelNames[i]); + sb.AppendFormat(rowLabelFormat, i, confusionMatrix.PredictedLabelNames[i]); for (int j = 0; j < numLabels; j++) sb.AppendFormat(format2, confusionTable[i][j]); - Double recall = rowSums[i] > 0 ? confusionTable[i][i] / rowSums[i] : 0; + Double recall = confusionMatrix.RecallSums[i] > 0 ? confusionTable[i][i] / confusionMatrix.RecallSums[i] : 0; sb.AppendFormat(" {0,5:F4}", recall); sb.AppendLine(); } @@ -1646,7 +1669,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl format = string.Format("{{0,{0}:N4}} |", colWidth + 1); for (int i = 0; i < numLabels; i++) { - Double precision = columnSums[i] > 0 ? confusionTable[i][i] / columnSums[i] : 0; + Double precision = confusionMatrix.PrecisionSums[i] > 0 ? confusionTable[i][i] / confusionMatrix.PrecisionSums[i] : 0; sb.AppendFormat(format, precision); } sb.AppendLine(); diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs index 2f7cf2914d..1497d7df70 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Linq; using Microsoft.ML.Runtime; namespace Microsoft.ML.Data @@ -74,6 +75,8 @@ public class BinaryClassificationMetrics /// public double AreaUnderPrecisionRecallCurve { get; } + public ConfusionMatrix ConfusionMatrix { get; } + private protected static T Fetch(IExceptionContext ectx, DataViewRow row, string name) { var column = row.Schema.GetColumnOrNull(name); @@ -84,9 +87,9 @@ private protected static T Fetch(IExceptionContext ectx, DataViewRow row, str return val; } - internal BinaryClassificationMetrics(IExceptionContext ectx, DataViewRow overallResult) + internal BinaryClassificationMetrics(IHost host, DataViewRow overallResult, IDataView confusionMatrix) { - double Fetch(string name) => Fetch(ectx, overallResult, name); + double Fetch(string name) => Fetch(host, overallResult, name); AreaUnderRocCurve = Fetch(BinaryClassifierEvaluator.Auc); Accuracy = Fetch(BinaryClassifierEvaluator.Accuracy); PositivePrecision = Fetch(BinaryClassifierEvaluator.PosPrecName); @@ -95,6 +98,7 @@ internal BinaryClassificationMetrics(IExceptionContext ectx, DataViewRow overall NegativeRecall = Fetch(BinaryClassifierEvaluator.NegRecallName); F1Score = Fetch(BinaryClassifierEvaluator.F1); AreaUnderPrecisionRecallCurve = Fetch(BinaryClassifierEvaluator.AuPrc); + ConfusionMatrix = MetricWriter.GetConfusionTableAsType(host, confusionMatrix); } [BestFriend] diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/CalibratedBinaryClassificationMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/CalibratedBinaryClassificationMetrics.cs index 7c3d283137..2ca20fc998 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/CalibratedBinaryClassificationMetrics.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/CalibratedBinaryClassificationMetrics.cs @@ -41,10 +41,10 @@ public sealed class CalibratedBinaryClassificationMetrics : BinaryClassification /// public double Entropy { get; } - internal CalibratedBinaryClassificationMetrics(IExceptionContext ectx, DataViewRow overallResult) - : base(ectx, overallResult) + internal CalibratedBinaryClassificationMetrics(IHost host, DataViewRow overallResult, IDataView confusionMatrix) + : base(host, overallResult, confusionMatrix) { - double Fetch(string name) => Fetch(ectx, overallResult, name); + double Fetch(string name) => Fetch(host, overallResult, name); LogLoss = Fetch(BinaryClassifierEvaluator.LogLoss); LogLossReduction = Fetch(BinaryClassifierEvaluator.LogLossReduction); Entropy = Fetch(BinaryClassifierEvaluator.Entropy); diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs new file mode 100644 index 0000000000..bf62602140 --- /dev/null +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs @@ -0,0 +1,49 @@ +using System; +using System.Collections.Generic; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Data +{ + public sealed class ConfusionMatrix + { + public double[] PrecisionSums { get; } + public double[] RecallSums { get; } + public ReadOnlyMemory[] Labels => PredictedLabelNames.ToArray(); //sefilipi: return as string[] + + public double[][] ConfusionTableCounts { get; } + + internal readonly List> PredictedLabelNames; + internal readonly bool Sampled; + internal readonly bool Binary; + + private readonly IHost _host; + + internal ConfusionMatrix(IHost host, double[] precisionSums, double[] recallSums, double[][] confusionTableCounts, + List> labelNames, bool sampled, bool binary) + { + _host = host; + + PrecisionSums = precisionSums; + RecallSums = recallSums; + ConfusionTableCounts = confusionTableCounts; + PredictedLabelNames = labelNames; + Sampled = sampled; + Binary = binary; + } + + public override string ToString() => MetricWriter.GetConfusionTableAsString(this, false); + + public double GetCountForClassPair(string predictedLabel, string actualLabel) + { + int predictedLabelIndex = PredictedLabelNames.IndexOf(predictedLabel.AsMemory()); + int actualLabelIndex = PredictedLabelNames.IndexOf(actualLabel.AsMemory()); + + _host.CheckParam(predictedLabelIndex > -1, nameof(predictedLabel), "Unknown given PredictedLabel."); + _host.CheckParam(actualLabelIndex > -1, nameof(actualLabel), "Unknown given ActualLabel."); + + _host.Assert(predictedLabelIndex < ConfusionTableCounts.Length && actualLabelIndex < ConfusionTableCounts.Length); + + return ConfusionTableCounts[actualLabelIndex][predictedLabelIndex]; + } + } +} diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs index 843dc00284..3b62225da8 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs @@ -82,9 +82,12 @@ public sealed class MulticlassClassificationMetrics /// public IReadOnlyList PerClassLogLoss { get; } - internal MulticlassClassificationMetrics(IExceptionContext ectx, DataViewRow overallResult, int topKPredictionCount) + // The confusion matrix. + public ConfusionMatrix ConfusionMatrix { get; } + + internal MulticlassClassificationMetrics(IHost host, DataViewRow overallResult, int topKPredictionCount, IDataView confusionMatrix) { - double FetchDouble(string name) => RowCursorUtils.Fetch(ectx, overallResult, name); + double FetchDouble(string name) => RowCursorUtils.Fetch(host, overallResult, name); MicroAccuracy = FetchDouble(MulticlassClassificationEvaluator.AccuracyMicro); MacroAccuracy = FetchDouble(MulticlassClassificationEvaluator.AccuracyMacro); LogLoss = FetchDouble(MulticlassClassificationEvaluator.LogLoss); @@ -93,8 +96,9 @@ internal MulticlassClassificationMetrics(IExceptionContext ectx, DataViewRow ove if (topKPredictionCount > 0) TopKAccuracy = FetchDouble(MulticlassClassificationEvaluator.TopKAccuracy); - var perClassLogLoss = RowCursorUtils.Fetch>(ectx, overallResult, MulticlassClassificationEvaluator.PerClassLogLoss); + var perClassLogLoss = RowCursorUtils.Fetch>(host, overallResult, MulticlassClassificationEvaluator.PerClassLogLoss); PerClassLogLoss = perClassLogLoss.DenseValues().ToImmutableArray(); + ConfusionMatrix = MetricWriter.GetConfusionTableAsType(host, confusionMatrix, binary: false, perClassLogLoss.Length, getWeighted:false); } internal MulticlassClassificationMetrics(double accuracyMicro, double accuracyMacro, double logLoss, double logLossReduction, diff --git a/src/Microsoft.ML.Data/Evaluators/MulticlassClassificationEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MulticlassClassificationEvaluator.cs index 0f7c448ab3..dc4500a963 100644 --- a/src/Microsoft.ML.Data/Evaluators/MulticlassClassificationEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MulticlassClassificationEvaluator.cs @@ -518,13 +518,14 @@ public MulticlassClassificationMetrics Evaluate(IDataView data, string label, st var resultDict = ((IEvaluator)this).Evaluate(roles); Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics)); var overall = resultDict[MetricKinds.OverallMetrics]; + var confusionMatrix = resultDict[MetricKinds.ConfusionMatrix]; MulticlassClassificationMetrics result; using (var cursor = overall.GetRowCursorForAllColumns()) { var moved = cursor.MoveNext(); Host.Assert(moved); - result = new MulticlassClassificationMetrics(Host, cursor, _outputTopKAcc ?? 0); + result = new MulticlassClassificationMetrics(Host, cursor, _outputTopKAcc ?? 0, confusionMatrix); moved = cursor.MoveNext(); Host.Assert(!moved); } From 076c04fb175aab75a10fe25f77b0971122b0e444 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Thu, 11 Apr 2019 19:53:49 -0700 Subject: [PATCH 2/9] PR comments --- src/Microsoft.ML.Core/Data/AnnotationUtils.cs | 4 +- .../Evaluators/EvaluatorUtils.cs | 51 ++++++------- .../Metrics/BinaryClassificationMetrics.cs | 1 - .../Evaluators/Metrics/ConfusionMatrix.cs | 72 +++++++++++++------ test/Microsoft.ML.Functional.Tests/Common.cs | 33 +++++++++ .../Evaluation.cs | 2 +- .../Api/Estimators/PredictAndMetadata.cs | 63 ++++++++++++++++ 7 files changed, 177 insertions(+), 49 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/AnnotationUtils.cs b/src/Microsoft.ML.Core/Data/AnnotationUtils.cs index 9ffdd04517..17671e6154 100644 --- a/src/Microsoft.ML.Core/Data/AnnotationUtils.cs +++ b/src/Microsoft.ML.Core/Data/AnnotationUtils.cs @@ -440,8 +440,8 @@ public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int co /// Label column. public static IEnumerable AnnotationsForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null) { - var cols = new List(); - if (labelColumn.HasValue && labelColumn.Value.IsKey) + var cols = new List(); + if (labelColumn != null && labelColumn.Value.IsKey) { if (labelColumn.Value.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol) && metaCol.Kind == SchemaShape.Column.VectorKind.Vector) diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs index cf5d4b453a..45cd4316b4 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs @@ -1382,15 +1382,14 @@ public static ConfusionMatrix GetConfusionTableAsType(IHost host, IDataView conf host.CheckParam(weightColumn.HasValue, nameof(getWeighted), "There is no Weight column in the confusionMatrix data view."); // Get the counts names. - var countColumn = confusionDataView.Schema.GetColumnOrNull(MetricKinds.ColumnNames.Count); - host.Check(countColumn.HasValue, "Did not find the count column"); - var type = countColumn.Value.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type as VectorDataViewType; - host.Check(type != null && type.IsKnownSize && type.ItemType is TextDataViewType, "The Count column does not have a text vector metadata of kind SlotNames."); + var countColumn = confusionDataView.Schema[MetricKinds.ColumnNames.Count]; + var type = countColumn.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type as VectorDataViewType; + host.Assert(type != null && type.IsKnownSize && type.ItemType is TextDataViewType, "The Count column does not have a text vector metadata of kind SlotNames."); // Get the class names var labelNames = default(VBuffer>); - countColumn.Value.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref labelNames); - host.Check(labelNames.IsDense, "Slot names vector must be dense"); + countColumn.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref labelNames); + host.Assert(labelNames.IsDense, "Slot names vector must be dense"); int numConfusionTableLabels = sample < 0 ? labelNames.Length : Math.Min(labelNames.Length, sample); @@ -1418,23 +1417,30 @@ public static ConfusionMatrix GetConfusionTableAsType(IHost host, IDataView conf double[] recallSums; double[][] confusionTable; - if(getWeighted) + if (getWeighted) confusionTable = GetConfusionTableAsArray(confusionDataView, weightColumn.Value.Index, labelNames.Length, labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums); else - confusionTable = GetConfusionTableAsArray(confusionDataView, countColumn.Value.Index, labelNames.Length, + confusionTable = GetConfusionTableAsArray(confusionDataView, countColumn.Index, labelNames.Length, labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums); + double[] precision = new double[numConfusionTableLabels]; + double[] recall = new double[numConfusionTableLabels]; + for (int i = 0; i < numConfusionTableLabels; i++) + { + recall[i] = recallSums[i] > 0 ? confusionTable[i][i] / recallSums[i] : 0; + precision[i] = precisionSums[i] > 0 ? confusionTable[i][i] / precisionSums[i] : 0; + } + var predictedLabelNames = GetPredictedLabelNames(in labelNames, labelIndexToConfIndexMap); bool sampled = numConfusionTableLabels < labelNames.Length; - return new ConfusionMatrix(host, precisionSums, recallSums, confusionTable, predictedLabelNames, sampled, binary); - + return new ConfusionMatrix(host, precision, recall, confusionTable, predictedLabelNames, sampled, binary, countColumn.Annotations); } private static List> GetPredictedLabelNames(in VBuffer> labelNames, int[] labelIndexToConfIndexMap) { - List> result = new List>(); + List > result = new List>(); var values = labelNames.GetValues(); for (int i = 0; i < values.Length; i++) { @@ -1578,10 +1584,10 @@ private static string GetFoldMetricsAsString(IHostEnvironment env, IDataView dat internal static string GetConfusionTableAsString(ConfusionMatrix confusionMatrix, bool isWeighted) { string prefix = isWeighted ? "Weighted " : ""; - int numLabels = Utils.Size(confusionMatrix.ConfusionTableCounts); + int numLabels = confusionMatrix?.ConfusionTableCounts == null? 0: confusionMatrix.ConfusionTableCounts.Length; int colWidth = numLabels == 2 ? 8 : 5; - int maxNameLen = confusionMatrix.PredictedLabelNames.Max(name => name.Length); + int maxNameLen = confusionMatrix.PredictedClassesIndicators.Max(name => name.Length); // If the names are too long to fit in the column header, we back off to using class indices // in the header. This will also require putting the indices in the row, but it's better than // the alternative of having ambiguous abbreviated column headers, or having a table potentially @@ -1594,7 +1600,7 @@ internal static string GetConfusionTableAsString(ConfusionMatrix confusionMatrix { // The row label will also include the index, so a user can easily match against the header. // In such a case, a label like "Foo" would be presented as something like "5. Foo". - rowDigitLen = Math.Max(confusionMatrix.PredictedLabelNames.Count - 1, 0).ToString().Length; + rowDigitLen = Math.Max(confusionMatrix.PredictedClassesIndicators.Count - 1, 0).ToString().Length; Contracts.Assert(rowDigitLen >= 1); rowLabelLen += rowDigitLen + 2; } @@ -1617,7 +1623,7 @@ internal static string GetConfusionTableAsString(ConfusionMatrix confusionMatrix var sb = new StringBuilder(); if (numLabels == 2 && confusionMatrix.Binary) { - var positiveCaps = confusionMatrix.PredictedLabelNames[0].ToString().ToUpper(); + var positiveCaps = confusionMatrix.PredictedClassesIndicators[0].ToString().ToUpper(); var numTruePos = confusionTable[0][0]; var numFalseNeg = confusionTable[0][1]; @@ -1642,7 +1648,7 @@ internal static string GetConfusionTableAsString(ConfusionMatrix confusionMatrix sb.AppendFormat("PREDICTED {0}||", pad); string format = string.Format(" {{{0},{1}}} |", useNumbersInHeader ? 0 : 1, colWidth); for (int i = 0; i < numLabels; i++) - sb.AppendFormat(format, i, confusionMatrix.PredictedLabelNames[i]); + sb.AppendFormat(format, i, confusionMatrix.PredictedClassesIndicators[i]); sb.AppendLine(" Recall"); sb.AppendFormat("TRUTH {0}||", pad); for (int i = 0; i < numLabels; i++) @@ -1654,11 +1660,10 @@ internal static string GetConfusionTableAsString(ConfusionMatrix confusionMatrix string.IsNullOrWhiteSpace(prefix) ? "N0" : "F1"); for (int i = 0; i < numLabels; i++) { - sb.AppendFormat(rowLabelFormat, i, confusionMatrix.PredictedLabelNames[i]); + sb.AppendFormat(rowLabelFormat, i, confusionMatrix.PredictedClassesIndicators[i]); for (int j = 0; j < numLabels; j++) sb.AppendFormat(format2, confusionTable[i][j]); - Double recall = confusionMatrix.RecallSums[i] > 0 ? confusionTable[i][i] / confusionMatrix.RecallSums[i] : 0; - sb.AppendFormat(" {0,5:F4}", recall); + sb.AppendFormat(" {0,5:F4}", confusionMatrix.PerClassRecall[i]); sb.AppendLine(); } sb.AppendFormat(" {0}||", pad); @@ -1668,10 +1673,8 @@ internal static string GetConfusionTableAsString(ConfusionMatrix confusionMatrix sb.AppendFormat("Precision {0}||", pad); format = string.Format("{{0,{0}:N4}} |", colWidth + 1); for (int i = 0; i < numLabels; i++) - { - Double precision = confusionMatrix.PrecisionSums[i] > 0 ? confusionTable[i][i] / confusionMatrix.PrecisionSums[i] : 0; - sb.AppendFormat(format, precision); - } + sb.AppendFormat(format, confusionMatrix.PerClassPrecision[i]); + sb.AppendLine(); return sb.ToString(); } @@ -1724,7 +1727,7 @@ public static void PrintWarnings(IChannel ch, Dictionary metr if (metrics.TryGetValue(MetricKinds.Warnings, out warnings)) { var warningTextColumn = warnings.Schema.GetColumnOrNull(MetricKinds.ColumnNames.WarningText); - if (warningTextColumn !=null && warningTextColumn.HasValue && warningTextColumn.Value.Type is TextDataViewType) + if (warningTextColumn != null && warningTextColumn.HasValue && warningTextColumn.Value.Type is TextDataViewType) { using (var cursor = warnings.GetRowCursor(warnings.Schema[MetricKinds.ColumnNames.WarningText])) { diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs index 1497d7df70..24a8bccea1 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System.Linq; using Microsoft.ML.Runtime; namespace Microsoft.ML.Data diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs index bf62602140..ce3b67787b 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs @@ -1,49 +1,79 @@ using System; using System.Collections.Generic; +using System.Collections.Immutable; using Microsoft.ML.Runtime; namespace Microsoft.ML.Data { + /// + /// Represents the confusion matrix of the classification results. + /// public sealed class ConfusionMatrix { - public double[] PrecisionSums { get; } - public double[] RecallSums { get; } - public ReadOnlyMemory[] Labels => PredictedLabelNames.ToArray(); //sefilipi: return as string[] + /// + /// The calculated value of precision for each class. + /// + public ImmutableArray PerClassPrecision { get; } - public double[][] ConfusionTableCounts { get; } + /// + /// The calculated value of recall for each class. + /// + public ImmutableArray PerClassRecall { get; } + + /// + /// The confsion matrix counts for the combinations actual class/predicted class. + /// The actual classes are in the rows of the table, and the predicted classes in the columns. + /// + public ImmutableArray> ConfusionTableCounts { get; } + + /// + /// The indicators of the predicted classes. + /// It might be the classes names, or just indices of the predicted classes, if the name mapping is missing. + /// + public IReadOnlyList> PredictedClassesIndicators; - internal readonly List> PredictedLabelNames; internal readonly bool Sampled; internal readonly bool Binary; private readonly IHost _host; - internal ConfusionMatrix(IHost host, double[] precisionSums, double[] recallSums, double[][] confusionTableCounts, - List> labelNames, bool sampled, bool binary) + internal ConfusionMatrix(IHost host, double[] precision, double[] recall, double[][] confusionTableCounts, + List> labelNames, bool sampled, bool binary, DataViewSchema.Annotations classIndicators) { _host = host; - PrecisionSums = precisionSums; - RecallSums = recallSums; - ConfusionTableCounts = confusionTableCounts; - PredictedLabelNames = labelNames; + PerClassPrecision = precision.ToImmutableArray(); + PerClassRecall = recall.ToImmutableArray(); Sampled = sampled; Binary = binary; - } + PredictedClassesIndicators = labelNames.AsReadOnly(); - public override string ToString() => MetricWriter.GetConfusionTableAsString(this, false); + var classNumber = confusionTableCounts.Length; + List> counts = new List>(classNumber); - public double GetCountForClassPair(string predictedLabel, string actualLabel) - { - int predictedLabelIndex = PredictedLabelNames.IndexOf(predictedLabel.AsMemory()); - int actualLabelIndex = PredictedLabelNames.IndexOf(actualLabel.AsMemory()); + for (int i = 0; i < classNumber; i++) + counts.Add(ImmutableArray.Create(confusionTableCounts[i])); - _host.CheckParam(predictedLabelIndex > -1, nameof(predictedLabel), "Unknown given PredictedLabel."); - _host.CheckParam(actualLabelIndex > -1, nameof(actualLabel), "Unknown given ActualLabel."); + ConfusionTableCounts = counts.ToImmutableArray(); - _host.Assert(predictedLabelIndex < ConfusionTableCounts.Length && actualLabelIndex < ConfusionTableCounts.Length); + } + + /// + /// Returns a human readable representation of the confusion table. + /// + /// + public string GetFormattedConfusionTable() => MetricWriter.GetConfusionTableAsString(this, false); - return ConfusionTableCounts[actualLabelIndex][predictedLabelIndex]; + /// + /// Gets the confusion table count for the pair /. + /// + /// The index of the predicted label indicator, in the . + /// The index of the actual label indicator, in the . + /// + public double GetCountForClassPair(uint predictedClassIndicatorIndex, uint actualClassIndicatorIndex) + { + _host.Assert(predictedClassIndicatorIndex < ConfusionTableCounts.Length && actualClassIndicatorIndex < ConfusionTableCounts.Length); + return ConfusionTableCounts[(int)actualClassIndicatorIndex][(int)predictedClassIndicatorIndex]; } } } diff --git a/test/Microsoft.ML.Functional.Tests/Common.cs b/test/Microsoft.ML.Functional.Tests/Common.cs index 648f01351f..bf9b3dbe0e 100644 --- a/test/Microsoft.ML.Functional.Tests/Common.cs +++ b/test/Microsoft.ML.Functional.Tests/Common.cs @@ -183,6 +183,10 @@ public static void AssertMetrics(BinaryClassificationMetrics metrics) Assert.InRange(metrics.NegativeRecall, 0, 1); Assert.InRange(metrics.PositivePrecision, 0, 1); Assert.InRange(metrics.PositiveRecall, 0, 1); + + // Confusion matrix validations + Assert.NotNull(metrics.ConfusionMatrix); + AssertConfusionMatrix(metrics.ConfusionMatrix); } /// @@ -195,6 +199,10 @@ public static void AssertMetrics(CalibratedBinaryClassificationMetrics metrics) Assert.InRange(metrics.LogLoss, double.NegativeInfinity, 1); Assert.InRange(metrics.LogLossReduction, double.NegativeInfinity, 100); AssertMetrics(metrics as BinaryClassificationMetrics); + + // Confusion matrix validations + Assert.NotNull(metrics.ConfusionMatrix); + AssertConfusionMatrix(metrics.ConfusionMatrix); } /// @@ -219,6 +227,31 @@ public static void AssertMetrics(MulticlassClassificationMetrics metrics) Assert.InRange(metrics.MicroAccuracy, 0, 1); Assert.True(metrics.LogLoss >= 0); Assert.InRange(metrics.TopKAccuracy, 0, 1); + + // Confusion matrix validations + Assert.NotNull(metrics.ConfusionMatrix); + AssertConfusionMatrix(metrics.ConfusionMatrix); + + } + + internal static void AssertConfusionMatrix(ConfusionMatrix confusionMatrix) + { + // Confusion matrix validations + Assert.NotNull(confusionMatrix); + Assert.NotEmpty(confusionMatrix.ConfusionTableCounts); + Assert.NotEmpty(confusionMatrix.PerClassPrecision); + Assert.NotEmpty(confusionMatrix.PerClassRecall); + Assert.NotNull(confusionMatrix.PredictedClassesIndicators); + + foreach (var precision in confusionMatrix.PerClassPrecision) + Assert.InRange(precision, 0, 1); + + foreach (var recall in confusionMatrix.PerClassRecall) + Assert.InRange(recall, 0, 1); + + foreach (var value in confusionMatrix.PredictedClassesIndicators) + Assert.False(value.IsEmpty); + } /// diff --git a/test/Microsoft.ML.Functional.Tests/Evaluation.cs b/test/Microsoft.ML.Functional.Tests/Evaluation.cs index 13909716ae..2b4d0c136a 100644 --- a/test/Microsoft.ML.Functional.Tests/Evaluation.cs +++ b/test/Microsoft.ML.Functional.Tests/Evaluation.cs @@ -152,7 +152,7 @@ public void TrainAndEvaluateMulticlassClassification() .Append(mlContext.Transforms.Conversion.MapValueToKey("Label")) .AppendCacheCheckpoint(mlContext) .Append(mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy( - new SdcaMaximumEntropyMulticlassTrainer.Options { NumberOfThreads = 1})); + new SdcaMaximumEntropyMulticlassTrainer.Options { NumberOfThreads = 1 })); // Train the model. var model = pipeline.Fit(data); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs index 30fdafe90b..908aef0449 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs @@ -67,5 +67,68 @@ void PredictAndMetadata() Assert.True(deciphieredLabel == input.Label); } } + + [Fact] + void MulticlassConfusionMatrixSlotNames() + { + var mlContext = new MLContext(seed: 1); + + var dataPath = GetDataPath(TestDatasets.irisData.trainFilename); + var data = mlContext.Data.LoadFromTextFile(dataPath, separatorChar: ','); + + var pipeline = mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") + .Append(mlContext.Transforms.Conversion.MapValueToKey("Label")) + .Append(mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy( + new SdcaMaximumEntropyMulticlassTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1, })); + + var model = pipeline.Fit(data); + + // Evaluate the model. + var scoredData = model.Transform(data); + var metrics = mlContext.MulticlassClassification.Evaluate(scoredData); + + // Check that the SlotNames column is there. + Assert.NotNull(scoredData.Schema["Score"].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)); + + //Assert that the confusion matrix has the class names, in the Annotations of the Count column + Assert.Equal("Iris-setosa", metrics.ConfusionMatrix.PredictedClassesIndicators[0].ToString()); + Assert.Equal("Iris-versicolor", metrics.ConfusionMatrix.PredictedClassesIndicators[1].ToString()); + Assert.Equal("Iris-virginica", metrics.ConfusionMatrix.PredictedClassesIndicators[2].ToString()); + + var dataReader = mlContext.Data.CreateTextLoader( + columns: new[] + { + new TextLoader.Column("Label", DataKind.Single, 0), //notice the label being loaded as a float + new TextLoader.Column("Features", DataKind.Single, new[]{ new TextLoader.Range(1,4) }) + }, + hasHeader: false, + separatorChar: '\t' + ); + + var dataPath2 = GetDataPath(TestDatasets.iris.trainFilename); + var data2 = dataReader.Load(dataPath2); + + var singleTrainer = mlContext.BinaryClassification.Trainers.LightGbm(); + + // Create a training pipeline. + var pipelineUnamed = mlContext.Transforms.Conversion.MapValueToKey("Label") + .Append(mlContext.MulticlassClassification.Trainers.OneVersusAll(singleTrainer)); + + // Train the model. + var model2 = pipelineUnamed.Fit(data2); + + // Evaluate the model. + var scoredData2 = model2.Transform(data2); + var metrics2 = mlContext.MulticlassClassification.Evaluate(scoredData2); + + // Check that the SlotNames column is not there. + Assert.Null(scoredData2.Schema["Score"].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)); + + //Assert that the confusion matrix has just ints, as class indicators, in the Annotations of the Count column + Assert.Equal("0", metrics2.ConfusionMatrix.PredictedClassesIndicators[0].ToString()); + Assert.Equal("1", metrics2.ConfusionMatrix.PredictedClassesIndicators[1].ToString()); + Assert.Equal("2", metrics2.ConfusionMatrix.PredictedClassesIndicators[2].ToString()); + + } } } From 0aea0fc547c8fe88d445eb71de5b508d95064d07 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Fri, 12 Apr 2019 09:49:57 -0700 Subject: [PATCH 3/9] Changing the single learner from LightGBM to FastTree so that the tests can run in WIndows debug. --- .../Scenarios/Api/Estimators/PredictAndMetadata.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs index 908aef0449..149dd7af31 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs @@ -108,7 +108,7 @@ void MulticlassConfusionMatrixSlotNames() var dataPath2 = GetDataPath(TestDatasets.iris.trainFilename); var data2 = dataReader.Load(dataPath2); - var singleTrainer = mlContext.BinaryClassification.Trainers.LightGbm(); + var singleTrainer = mlContext.BinaryClassification.Trainers.FastTree(); // Create a training pipeline. var pipelineUnamed = mlContext.Transforms.Conversion.MapValueToKey("Label") From c435990342f7758fb5aa2936d1c24fa69d0b7aa5 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Fri, 12 Apr 2019 14:18:20 -0700 Subject: [PATCH 4/9] Addressing Tom's comments --- src/Microsoft.ML.Core/Data/AnnotationUtils.cs | 2 +- .../Evaluators/Metrics/ConfusionMatrix.cs | 61 ++++++++++++++++--- test/Microsoft.ML.Functional.Tests/Common.cs | 11 +++- 3 files changed, 62 insertions(+), 12 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/AnnotationUtils.cs b/src/Microsoft.ML.Core/Data/AnnotationUtils.cs index 17671e6154..6810fa3959 100644 --- a/src/Microsoft.ML.Core/Data/AnnotationUtils.cs +++ b/src/Microsoft.ML.Core/Data/AnnotationUtils.cs @@ -440,7 +440,7 @@ public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int co /// Label column. public static IEnumerable AnnotationsForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null) { - var cols = new List(); + var cols = new List(); if (labelColumn != null && labelColumn.Value.IsKey) { if (labelColumn.Value.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol) && diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs index ce3b67787b..5fe189aafc 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs @@ -1,4 +1,8 @@ -using System; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; using System.Collections.Generic; using System.Collections.Immutable; using Microsoft.ML.Runtime; @@ -30,32 +34,67 @@ public sealed class ConfusionMatrix /// The indicators of the predicted classes. /// It might be the classes names, or just indices of the predicted classes, if the name mapping is missing. /// - public IReadOnlyList> PredictedClassesIndicators; + public int NumberOfPredictedClasses { get; } + + /// + /// The associated with the Confusion Matrix counts. + /// It contains information about the predicted classes, if that information is available. + /// It can be the classes names, or their indices. + /// + public DataViewSchema.Annotations ClassIndicators { get; } + + /// + /// The indicators of the predicted classes. + /// It might be the classes names, or just indices of the predicted classes, if the name mapping is missing. + /// + internal IReadOnlyList> PredictedClassesIndicators; internal readonly bool Sampled; internal readonly bool Binary; private readonly IHost _host; + /// + /// The confusion matrix as a structured type, built from the counts of the confusion table idv. + /// + /// The IHost instance. + /// The values of precision per class. + /// The vales of recall per class. + /// The counts of the confusion table. The actual classes values are in the rows of the 2D array, + /// and the counts of the predicted classes are in the columns. + /// The predicted classes names, or the indexes of the classes, if the names are missing. + /// Whether the classes are sampled. + /// Whether the confusion table is the result of a binary classification. + /// The Annotations of the Count column, in the confusionTable idv. internal ConfusionMatrix(IHost host, double[] precision, double[] recall, double[][] confusionTableCounts, List> labelNames, bool sampled, bool binary, DataViewSchema.Annotations classIndicators) { _host = host; + _host.AssertNonEmpty(precision); + _host.AssertNonEmpty(recall); + _host.AssertNonEmpty(confusionTableCounts); + _host.AssertNonEmpty(labelNames); + _host.AssertNonEmpty(precision); + + _host.Assert(precision.Length == confusionTableCounts.Length); + _host.Assert(recall.Length == confusionTableCounts.Length); + _host.Assert(labelNames.Count == confusionTableCounts.Length); + PerClassPrecision = precision.ToImmutableArray(); PerClassRecall = recall.ToImmutableArray(); Sampled = sampled; Binary = binary; PredictedClassesIndicators = labelNames.AsReadOnly(); - var classNumber = confusionTableCounts.Length; - List> counts = new List>(classNumber); + NumberOfPredictedClasses = confusionTableCounts.Length; + List> counts = new List>(NumberOfPredictedClasses); - for (int i = 0; i < classNumber; i++) + for (int i = 0; i < NumberOfPredictedClasses; i++) counts.Add(ImmutableArray.Create(confusionTableCounts[i])); ConfusionTableCounts = counts.ToImmutableArray(); - + ClassIndicators = classIndicators; } /// @@ -70,10 +109,14 @@ internal ConfusionMatrix(IHost host, double[] precision, double[] recall, double /// The index of the predicted label indicator, in the . /// The index of the actual label indicator, in the . /// - public double GetCountForClassPair(uint predictedClassIndicatorIndex, uint actualClassIndicatorIndex) + public double GetCountForClassPair(int predictedClassIndicatorIndex, int actualClassIndicatorIndex) { - _host.Assert(predictedClassIndicatorIndex < ConfusionTableCounts.Length && actualClassIndicatorIndex < ConfusionTableCounts.Length); - return ConfusionTableCounts[(int)actualClassIndicatorIndex][(int)predictedClassIndicatorIndex]; + _host.CheckParam(predictedClassIndicatorIndex > -1 && predictedClassIndicatorIndex < ConfusionTableCounts.Length, + nameof(predictedClassIndicatorIndex), "Invalid index. Should be non-negative, less than the number of classes."); + _host.CheckParam(actualClassIndicatorIndex > -1 && actualClassIndicatorIndex < ConfusionTableCounts.Length, + nameof(actualClassIndicatorIndex), "Invalid index. Should be non-negative, less than the number of classes."); + + return ConfusionTableCounts[actualClassIndicatorIndex][predictedClassIndicatorIndex]; } } } diff --git a/test/Microsoft.ML.Functional.Tests/Common.cs b/test/Microsoft.ML.Functional.Tests/Common.cs index bf9b3dbe0e..0cf4383ffa 100644 --- a/test/Microsoft.ML.Functional.Tests/Common.cs +++ b/test/Microsoft.ML.Functional.Tests/Common.cs @@ -241,7 +241,7 @@ internal static void AssertConfusionMatrix(ConfusionMatrix confusionMatrix) Assert.NotEmpty(confusionMatrix.ConfusionTableCounts); Assert.NotEmpty(confusionMatrix.PerClassPrecision); Assert.NotEmpty(confusionMatrix.PerClassRecall); - Assert.NotNull(confusionMatrix.PredictedClassesIndicators); + Assert.NotNull(confusionMatrix.ClassIndicators); foreach (var precision in confusionMatrix.PerClassPrecision) Assert.InRange(precision, 0, 1); @@ -249,7 +249,14 @@ internal static void AssertConfusionMatrix(ConfusionMatrix confusionMatrix) foreach (var recall in confusionMatrix.PerClassRecall) Assert.InRange(recall, 0, 1); - foreach (var value in confusionMatrix.PredictedClassesIndicators) + + // Get the values in the annotations + var classIndicatorsBuffer = new VBuffer>(); + confusionMatrix.ClassIndicators.GetValue("SlotNames", ref classIndicatorsBuffer); + + var classIndicators = classIndicatorsBuffer.GetValues(); + + foreach (var value in classIndicators) Assert.False(value.IsEmpty); } From 7b34ac682e801fc7f05e3c212bead6066101b3b9 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Fri, 12 Apr 2019 16:15:24 -0700 Subject: [PATCH 5/9] Ivan's suggestions --- .../Evaluators/BinaryClassifierEvaluator.cs | 2 +- .../Evaluators/EvaluatorUtils.cs | 16 ++++++------- .../Metrics/BinaryClassificationMetrics.cs | 6 ++++- .../Evaluators/Metrics/ConfusionMatrix.cs | 24 +++++++++---------- .../MulticlassClassificationMetrics.cs | 7 ++++-- .../MulticlassClassificationEvaluator.cs | 2 +- test/Microsoft.ML.Functional.Tests/Common.cs | 2 +- 7 files changed, 33 insertions(+), 26 deletions(-) diff --git a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs index 4ac5d0e773..283e85a203 100644 --- a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs @@ -1383,7 +1383,7 @@ private protected override void PrintFoldResultsCore(IChannel ch, Dictionary /// Indicates whether the confusion table is for binary classification. /// Indicates how many classes to sample from the confusion table (-1 indicates no sampling) - public static string GetConfusionTable(IHost host, IDataView confusionDataView, out string weightedConfusionTable, bool binary = true, int sample = -1) + public static string GetConfusionTableAsFormattedString(IHost host, IDataView confusionDataView, out string weightedConfusionTable, bool binary = true, int sample = -1) { host.CheckValue(confusionDataView, nameof(confusionDataView)); host.CheckParam(sample == -1 || sample >= 2, nameof(sample), "Should be -1 to indicate no sampling, or at least 2"); @@ -1356,13 +1356,13 @@ public static string GetConfusionTable(IHost host, IDataView confusionDataView, var weightColumn = confusionDataView.Schema.GetColumnOrNull(MetricKinds.ColumnNames.Weight); bool isWeighted = weightColumn.HasValue; - var confusionMatrix = GetConfusionTableAsType(host, confusionDataView, binary, sample, false); + var confusionMatrix = GetConfusionTable(host, confusionDataView, binary, sample, false); var confusionTableString = GetConfusionTableAsString(confusionMatrix, false); // if there is a Weight column, return the weighted confusionMatrix as well, from this function. if (isWeighted) { - confusionMatrix = GetConfusionTableAsType(host, confusionDataView, binary, sample, false); + confusionMatrix = GetConfusionTable(host, confusionDataView, binary, sample, true); weightedConfusionTable = GetConfusionTableAsString(confusionMatrix, true); } else @@ -1371,7 +1371,7 @@ public static string GetConfusionTable(IHost host, IDataView confusionDataView, return confusionTableString; } - public static ConfusionMatrix GetConfusionTableAsType(IHost host, IDataView confusionDataView, bool binary = true, int sample = -1, bool getWeighted = false) + public static ConfusionMatrix GetConfusionTable(IHost host, IDataView confusionDataView, bool binary = true, int sample = -1, bool getWeighted = false) { host.CheckValue(confusionDataView, nameof(confusionDataView)); host.CheckParam(sample == -1 || sample >= 2, nameof(sample), "Should be -1 to indicate no sampling, or at least 2"); @@ -1584,7 +1584,7 @@ private static string GetFoldMetricsAsString(IHostEnvironment env, IDataView dat internal static string GetConfusionTableAsString(ConfusionMatrix confusionMatrix, bool isWeighted) { string prefix = isWeighted ? "Weighted " : ""; - int numLabels = confusionMatrix?.ConfusionTableCounts == null? 0: confusionMatrix.ConfusionTableCounts.Length; + int numLabels = confusionMatrix?.Counts == null? 0: confusionMatrix.Counts.Length; int colWidth = numLabels == 2 ? 8 : 5; int maxNameLen = confusionMatrix.PredictedClassesIndicators.Max(name => name.Length); @@ -1619,9 +1619,9 @@ internal static string GetConfusionTableAsString(ConfusionMatrix confusionMatrix else rowLabelFormat = string.Format("{{1,{0}}} ||", paddingLen); - var confusionTable = confusionMatrix.ConfusionTableCounts; + var confusionTable = confusionMatrix.Counts; var sb = new StringBuilder(); - if (numLabels == 2 && confusionMatrix.Binary) + if (numLabels == 2 && confusionMatrix.IsBinary) { var positiveCaps = confusionMatrix.PredictedClassesIndicators[0].ToString().ToUpper(); @@ -1636,7 +1636,7 @@ internal static string GetConfusionTableAsString(ConfusionMatrix confusionMatrix sb.AppendLine(); sb.AppendFormat("{0}Confusion table", prefix); - if (confusionMatrix.Sampled) + if (confusionMatrix.IsSampled) sb.AppendLine(" (sampled)"); else sb.AppendLine(); diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs index 24a8bccea1..92f6f2ad6b 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs @@ -74,6 +74,10 @@ public class BinaryClassificationMetrics /// public double AreaUnderPrecisionRecallCurve { get; } + /// + /// The 2 dimentional confusion matrix giving the counts of the + /// true positives, true negatives, false positives and false negatives for the two classes of data. + /// public ConfusionMatrix ConfusionMatrix { get; } private protected static T Fetch(IExceptionContext ectx, DataViewRow row, string name) @@ -97,7 +101,7 @@ internal BinaryClassificationMetrics(IHost host, DataViewRow overallResult, IDat NegativeRecall = Fetch(BinaryClassifierEvaluator.NegRecallName); F1Score = Fetch(BinaryClassifierEvaluator.F1); AreaUnderPrecisionRecallCurve = Fetch(BinaryClassifierEvaluator.AuPrc); - ConfusionMatrix = MetricWriter.GetConfusionTableAsType(host, confusionMatrix); + ConfusionMatrix = MetricWriter.GetConfusionTable(host, confusionMatrix); } [BestFriend] diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs index 5fe189aafc..f40364da05 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs @@ -28,7 +28,7 @@ public sealed class ConfusionMatrix /// The confsion matrix counts for the combinations actual class/predicted class. /// The actual classes are in the rows of the table, and the predicted classes in the columns. /// - public ImmutableArray> ConfusionTableCounts { get; } + public ImmutableArray> Counts { get; } /// /// The indicators of the predicted classes. @@ -49,8 +49,8 @@ public sealed class ConfusionMatrix /// internal IReadOnlyList> PredictedClassesIndicators; - internal readonly bool Sampled; - internal readonly bool Binary; + internal readonly bool IsSampled; + internal readonly bool IsBinary; private readonly IHost _host; @@ -63,11 +63,11 @@ public sealed class ConfusionMatrix /// The counts of the confusion table. The actual classes values are in the rows of the 2D array, /// and the counts of the predicted classes are in the columns. /// The predicted classes names, or the indexes of the classes, if the names are missing. - /// Whether the classes are sampled. - /// Whether the confusion table is the result of a binary classification. + /// Whether the classes are sampled. + /// Whether the confusion table is the result of a binary classification. /// The Annotations of the Count column, in the confusionTable idv. internal ConfusionMatrix(IHost host, double[] precision, double[] recall, double[][] confusionTableCounts, - List> labelNames, bool sampled, bool binary, DataViewSchema.Annotations classIndicators) + List> labelNames, bool isSampled, bool isBinary, DataViewSchema.Annotations classIndicators) { _host = host; @@ -83,8 +83,8 @@ internal ConfusionMatrix(IHost host, double[] precision, double[] recall, double PerClassPrecision = precision.ToImmutableArray(); PerClassRecall = recall.ToImmutableArray(); - Sampled = sampled; - Binary = binary; + IsSampled = isSampled; + IsBinary = isBinary; PredictedClassesIndicators = labelNames.AsReadOnly(); NumberOfPredictedClasses = confusionTableCounts.Length; @@ -93,7 +93,7 @@ internal ConfusionMatrix(IHost host, double[] precision, double[] recall, double for (int i = 0; i < NumberOfPredictedClasses; i++) counts.Add(ImmutableArray.Create(confusionTableCounts[i])); - ConfusionTableCounts = counts.ToImmutableArray(); + Counts = counts.ToImmutableArray(); ClassIndicators = classIndicators; } @@ -111,12 +111,12 @@ internal ConfusionMatrix(IHost host, double[] precision, double[] recall, double /// public double GetCountForClassPair(int predictedClassIndicatorIndex, int actualClassIndicatorIndex) { - _host.CheckParam(predictedClassIndicatorIndex > -1 && predictedClassIndicatorIndex < ConfusionTableCounts.Length, + _host.CheckParam(predictedClassIndicatorIndex > -1 && predictedClassIndicatorIndex < Counts.Length, nameof(predictedClassIndicatorIndex), "Invalid index. Should be non-negative, less than the number of classes."); - _host.CheckParam(actualClassIndicatorIndex > -1 && actualClassIndicatorIndex < ConfusionTableCounts.Length, + _host.CheckParam(actualClassIndicatorIndex > -1 && actualClassIndicatorIndex < Counts.Length, nameof(actualClassIndicatorIndex), "Invalid index. Should be non-negative, less than the number of classes."); - return ConfusionTableCounts[actualClassIndicatorIndex][predictedClassIndicatorIndex]; + return Counts[actualClassIndicatorIndex][predictedClassIndicatorIndex]; } } } diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs index 3b62225da8..c9d81c4e00 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs @@ -82,7 +82,10 @@ public sealed class MulticlassClassificationMetrics /// public IReadOnlyList PerClassLogLoss { get; } - // The confusion matrix. + /// + /// The confusion matrix giving the counts of the + /// predicted classes versus the actual classes. + /// public ConfusionMatrix ConfusionMatrix { get; } internal MulticlassClassificationMetrics(IHost host, DataViewRow overallResult, int topKPredictionCount, IDataView confusionMatrix) @@ -98,7 +101,7 @@ internal MulticlassClassificationMetrics(IHost host, DataViewRow overallResult, var perClassLogLoss = RowCursorUtils.Fetch>(host, overallResult, MulticlassClassificationEvaluator.PerClassLogLoss); PerClassLogLoss = perClassLogLoss.DenseValues().ToImmutableArray(); - ConfusionMatrix = MetricWriter.GetConfusionTableAsType(host, confusionMatrix, binary: false, perClassLogLoss.Length, getWeighted:false); + ConfusionMatrix = MetricWriter.GetConfusionTable(host, confusionMatrix, binary: false, perClassLogLoss.Length); } internal MulticlassClassificationMetrics(double accuracyMicro, double accuracyMacro, double logLoss, double logLossReduction, diff --git a/src/Microsoft.ML.Data/Evaluators/MulticlassClassificationEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MulticlassClassificationEvaluator.cs index dc4500a963..e616d19a55 100644 --- a/src/Microsoft.ML.Data/Evaluators/MulticlassClassificationEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MulticlassClassificationEvaluator.cs @@ -890,7 +890,7 @@ private protected override void PrintFoldResultsCore(IChannel ch, Dictionary Date: Mon, 15 Apr 2019 16:38:36 -0700 Subject: [PATCH 6/9] Pr review comments --- .../Evaluators/EvaluatorUtils.cs | 13 +++--- .../Metrics/BinaryClassificationMetrics.cs | 4 +- .../Evaluators/Metrics/ConfusionMatrix.cs | 42 +++++++++---------- .../MulticlassClassificationMetrics.cs | 2 +- test/Microsoft.ML.Functional.Tests/Common.cs | 13 +----- 5 files changed, 32 insertions(+), 42 deletions(-) diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs index c1870cbbd7..c9ce29ce00 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs @@ -1356,13 +1356,13 @@ public static string GetConfusionTableAsFormattedString(IHost host, IDataView co var weightColumn = confusionDataView.Schema.GetColumnOrNull(MetricKinds.ColumnNames.Weight); bool isWeighted = weightColumn.HasValue; - var confusionMatrix = GetConfusionTable(host, confusionDataView, binary, sample, false); + var confusionMatrix = GetConfusionMatrix(host, confusionDataView, binary, sample, false); var confusionTableString = GetConfusionTableAsString(confusionMatrix, false); // if there is a Weight column, return the weighted confusionMatrix as well, from this function. if (isWeighted) { - confusionMatrix = GetConfusionTable(host, confusionDataView, binary, sample, true); + confusionMatrix = GetConfusionMatrix(host, confusionDataView, binary, sample, true); weightedConfusionTable = GetConfusionTableAsString(confusionMatrix, true); } else @@ -1371,7 +1371,7 @@ public static string GetConfusionTableAsFormattedString(IHost host, IDataView co return confusionTableString; } - public static ConfusionMatrix GetConfusionTable(IHost host, IDataView confusionDataView, bool binary = true, int sample = -1, bool getWeighted = false) + public static ConfusionMatrix GetConfusionMatrix(IHost host, IDataView confusionDataView, bool binary = true, int sample = -1, bool getWeighted = false) { host.CheckValue(confusionDataView, nameof(confusionDataView)); host.CheckParam(sample == -1 || sample >= 2, nameof(sample), "Should be -1 to indicate no sampling, or at least 2"); @@ -1384,7 +1384,8 @@ public static ConfusionMatrix GetConfusionTable(IHost host, IDataView confusionD // Get the counts names. var countColumn = confusionDataView.Schema[MetricKinds.ColumnNames.Count]; var type = countColumn.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type as VectorDataViewType; - host.Assert(type != null && type.IsKnownSize && type.ItemType is TextDataViewType, "The Count column does not have a text vector metadata of kind SlotNames."); + //"The Count column does not have a text vector metadata of kind SlotNames." + host.Assert(type != null && type.IsKnownSize && type.ItemType is TextDataViewType); // Get the class names var labelNames = default(VBuffer>); @@ -1435,7 +1436,7 @@ public static ConfusionMatrix GetConfusionTable(IHost host, IDataView confusionD var predictedLabelNames = GetPredictedLabelNames(in labelNames, labelIndexToConfIndexMap); bool sampled = numConfusionTableLabels < labelNames.Length; - return new ConfusionMatrix(host, precision, recall, confusionTable, predictedLabelNames, sampled, binary, countColumn.Annotations); + return new ConfusionMatrix(host, precision, recall, confusionTable, predictedLabelNames, sampled, binary); } private static List> GetPredictedLabelNames(in VBuffer> labelNames, int[] labelIndexToConfIndexMap) @@ -1584,7 +1585,7 @@ private static string GetFoldMetricsAsString(IHostEnvironment env, IDataView dat internal static string GetConfusionTableAsString(ConfusionMatrix confusionMatrix, bool isWeighted) { string prefix = isWeighted ? "Weighted " : ""; - int numLabels = confusionMatrix?.Counts == null? 0: confusionMatrix.Counts.Length; + int numLabels = confusionMatrix?.Counts == null? 0: confusionMatrix.Counts.Count; int colWidth = numLabels == 2 ? 8 : 5; int maxNameLen = confusionMatrix.PredictedClassesIndicators.Max(name => name.Length); diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs index 92f6f2ad6b..81cac683f1 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs @@ -75,7 +75,7 @@ public class BinaryClassificationMetrics public double AreaUnderPrecisionRecallCurve { get; } /// - /// The 2 dimentional confusion matrix giving the counts of the + /// The confusion matrix giving the counts of the /// true positives, true negatives, false positives and false negatives for the two classes of data. /// public ConfusionMatrix ConfusionMatrix { get; } @@ -101,7 +101,7 @@ internal BinaryClassificationMetrics(IHost host, DataViewRow overallResult, IDat NegativeRecall = Fetch(BinaryClassifierEvaluator.NegRecallName); F1Score = Fetch(BinaryClassifierEvaluator.F1); AreaUnderPrecisionRecallCurve = Fetch(BinaryClassifierEvaluator.AuPrc); - ConfusionMatrix = MetricWriter.GetConfusionTable(host, confusionMatrix); + ConfusionMatrix = MetricWriter.GetConfusionMatrix(host, confusionMatrix); } [BestFriend] diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs index f40364da05..552d82b701 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.Linq; using Microsoft.ML.Runtime; namespace Microsoft.ML.Data @@ -17,18 +18,18 @@ public sealed class ConfusionMatrix /// /// The calculated value of precision for each class. /// - public ImmutableArray PerClassPrecision { get; } + public IReadOnlyList PerClassPrecision { get; } /// /// The calculated value of recall for each class. /// - public ImmutableArray PerClassRecall { get; } + public IReadOnlyList PerClassRecall { get; } /// /// The confsion matrix counts for the combinations actual class/predicted class. /// The actual classes are in the rows of the table, and the predicted classes in the columns. /// - public ImmutableArray> Counts { get; } + public IReadOnlyList> Counts { get; } /// /// The indicators of the predicted classes. @@ -36,13 +37,6 @@ public sealed class ConfusionMatrix /// public int NumberOfPredictedClasses { get; } - /// - /// The associated with the Confusion Matrix counts. - /// It contains information about the predicted classes, if that information is available. - /// It can be the classes names, or their indices. - /// - public DataViewSchema.Annotations ClassIndicators { get; } - /// /// The indicators of the predicted classes. /// It might be the classes names, or just indices of the predicted classes, if the name mapping is missing. @@ -53,9 +47,11 @@ public sealed class ConfusionMatrix internal readonly bool IsBinary; private readonly IHost _host; + private string _formattedConfusionMatrix; /// - /// The confusion matrix as a structured type, built from the counts of the confusion table idv. + /// The confusion matrix as a structured type, built from the counts of the confusion table that the or + /// the construct. /// /// The IHost instance. /// The values of precision per class. @@ -65,9 +61,8 @@ public sealed class ConfusionMatrix /// The predicted classes names, or the indexes of the classes, if the names are missing. /// Whether the classes are sampled. /// Whether the confusion table is the result of a binary classification. - /// The Annotations of the Count column, in the confusionTable idv. - internal ConfusionMatrix(IHost host, double[] precision, double[] recall, double[][] confusionTableCounts, - List> labelNames, bool isSampled, bool isBinary, DataViewSchema.Annotations classIndicators) + internal ConfusionMatrix(IHost host, double[] precision, double[] recall, double[][] confusionTableCounts, + List> labelNames, bool isSampled, bool isBinary) { _host = host; @@ -88,20 +83,25 @@ internal ConfusionMatrix(IHost host, double[] precision, double[] recall, double PredictedClassesIndicators = labelNames.AsReadOnly(); NumberOfPredictedClasses = confusionTableCounts.Length; - List> counts = new List>(NumberOfPredictedClasses); + List> counts = new List>(NumberOfPredictedClasses); for (int i = 0; i < NumberOfPredictedClasses; i++) - counts.Add(ImmutableArray.Create(confusionTableCounts[i])); + counts.Add(confusionTableCounts[i].ToList().AsReadOnly()); - Counts = counts.ToImmutableArray(); - ClassIndicators = classIndicators; + Counts = counts.AsReadOnly(); } /// /// Returns a human readable representation of the confusion table. /// /// - public string GetFormattedConfusionTable() => MetricWriter.GetConfusionTableAsString(this, false); + public string GetFormattedConfusionTable() { + + if(_formattedConfusionMatrix == null) + _formattedConfusionMatrix = MetricWriter.GetConfusionTableAsString(this, false); + + return _formattedConfusionMatrix; + } /// /// Gets the confusion table count for the pair /. @@ -111,9 +111,9 @@ internal ConfusionMatrix(IHost host, double[] precision, double[] recall, double /// public double GetCountForClassPair(int predictedClassIndicatorIndex, int actualClassIndicatorIndex) { - _host.CheckParam(predictedClassIndicatorIndex > -1 && predictedClassIndicatorIndex < Counts.Length, + _host.CheckParam(predictedClassIndicatorIndex > -1 && predictedClassIndicatorIndex < Counts.Count, nameof(predictedClassIndicatorIndex), "Invalid index. Should be non-negative, less than the number of classes."); - _host.CheckParam(actualClassIndicatorIndex > -1 && actualClassIndicatorIndex < Counts.Length, + _host.CheckParam(actualClassIndicatorIndex > -1 && actualClassIndicatorIndex < Counts.Count, nameof(actualClassIndicatorIndex), "Invalid index. Should be non-negative, less than the number of classes."); return Counts[actualClassIndicatorIndex][predictedClassIndicatorIndex]; diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs index c9d81c4e00..b91638c8f6 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs @@ -101,7 +101,7 @@ internal MulticlassClassificationMetrics(IHost host, DataViewRow overallResult, var perClassLogLoss = RowCursorUtils.Fetch>(host, overallResult, MulticlassClassificationEvaluator.PerClassLogLoss); PerClassLogLoss = perClassLogLoss.DenseValues().ToImmutableArray(); - ConfusionMatrix = MetricWriter.GetConfusionTable(host, confusionMatrix, binary: false, perClassLogLoss.Length); + ConfusionMatrix = MetricWriter.GetConfusionMatrix(host, confusionMatrix, binary: false, perClassLogLoss.Length); } internal MulticlassClassificationMetrics(double accuracyMicro, double accuracyMacro, double logLoss, double logLossReduction, diff --git a/test/Microsoft.ML.Functional.Tests/Common.cs b/test/Microsoft.ML.Functional.Tests/Common.cs index 21ce5b13b9..f9f1965b13 100644 --- a/test/Microsoft.ML.Functional.Tests/Common.cs +++ b/test/Microsoft.ML.Functional.Tests/Common.cs @@ -231,7 +231,7 @@ public static void AssertMetrics(MulticlassClassificationMetrics metrics) // Confusion matrix validations Assert.NotNull(metrics.ConfusionMatrix); AssertConfusionMatrix(metrics.ConfusionMatrix); - + } internal static void AssertConfusionMatrix(ConfusionMatrix confusionMatrix) @@ -241,7 +241,6 @@ internal static void AssertConfusionMatrix(ConfusionMatrix confusionMatrix) Assert.NotEmpty(confusionMatrix.Counts); Assert.NotEmpty(confusionMatrix.PerClassPrecision); Assert.NotEmpty(confusionMatrix.PerClassRecall); - Assert.NotNull(confusionMatrix.ClassIndicators); foreach (var precision in confusionMatrix.PerClassPrecision) Assert.InRange(precision, 0, 1); @@ -249,16 +248,6 @@ internal static void AssertConfusionMatrix(ConfusionMatrix confusionMatrix) foreach (var recall in confusionMatrix.PerClassRecall) Assert.InRange(recall, 0, 1); - - // Get the values in the annotations - var classIndicatorsBuffer = new VBuffer>(); - confusionMatrix.ClassIndicators.GetValue("SlotNames", ref classIndicatorsBuffer); - - var classIndicators = classIndicatorsBuffer.GetValues(); - - foreach (var value in classIndicators) - Assert.False(value.IsEmpty); - } /// From 889988c2d8a90eb397563b8f0fde4f72512c0a8e Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Tue, 16 Apr 2019 09:58:50 -0700 Subject: [PATCH 7/9] Asserting the host before assigning it --- src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs index 552d82b701..3b179d8857 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs @@ -64,6 +64,7 @@ public sealed class ConfusionMatrix internal ConfusionMatrix(IHost host, double[] precision, double[] recall, double[][] confusionTableCounts, List> labelNames, bool isSampled, bool isBinary) { + Contracts.AssertValue(host); _host = host; _host.AssertNonEmpty(precision); From 74090d645abec268beda63dd00601c07064f6084 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Thu, 18 Apr 2019 13:48:11 -0700 Subject: [PATCH 8/9] fixing nits --- src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs index 3b179d8857..a0750237bb 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs @@ -51,7 +51,7 @@ public sealed class ConfusionMatrix /// /// The confusion matrix as a structured type, built from the counts of the confusion table that the or - /// the construct. + /// the constructor. /// /// The IHost instance. /// The values of precision per class. @@ -96,8 +96,8 @@ internal ConfusionMatrix(IHost host, double[] precision, double[] recall, double /// Returns a human readable representation of the confusion table. /// /// - public string GetFormattedConfusionTable() { - + public string GetFormattedConfusionTable() + { if(_formattedConfusionMatrix == null) _formattedConfusionMatrix = MetricWriter.GetConfusionTableAsString(this, false); From 415937d2381ff05a94bd7c4643ab081905394aeb Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Fri, 19 Apr 2019 17:07:31 -0700 Subject: [PATCH 9/9] Artidoro's comments --- src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs | 2 +- .../Evaluators/Metrics/ConfusionMatrix.cs | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs index c9ce29ce00..9b18e8be9d 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs @@ -1359,7 +1359,7 @@ public static string GetConfusionTableAsFormattedString(IHost host, IDataView co var confusionMatrix = GetConfusionMatrix(host, confusionDataView, binary, sample, false); var confusionTableString = GetConfusionTableAsString(confusionMatrix, false); - // if there is a Weight column, return the weighted confusionMatrix as well, from this function. + // If there is a Weight column, return the weighted confusionMatrix as well, from this function. if (isWeighted) { confusionMatrix = GetConfusionMatrix(host, confusionDataView, binary, sample, true); diff --git a/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs b/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs index a0750237bb..148a3daf57 100644 --- a/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs +++ b/src/Microsoft.ML.Data/Evaluators/Metrics/ConfusionMatrix.cs @@ -26,8 +26,9 @@ public sealed class ConfusionMatrix public IReadOnlyList PerClassRecall { get; } /// - /// The confsion matrix counts for the combinations actual class/predicted class. - /// The actual classes are in the rows of the table, and the predicted classes in the columns. + /// The confusion matrix counts for the combinations actual class/predicted class. + /// The actual classes are in the rows of the table (stored in the outer ), and the predicted classes + /// in the columns(stored in the inner ). /// public IReadOnlyList> Counts { get; } @@ -35,7 +36,7 @@ public sealed class ConfusionMatrix /// The indicators of the predicted classes. /// It might be the classes names, or just indices of the predicted classes, if the name mapping is missing. /// - public int NumberOfPredictedClasses { get; } + public int NumberOfClasses { get; } /// /// The indicators of the predicted classes. @@ -83,10 +84,10 @@ internal ConfusionMatrix(IHost host, double[] precision, double[] recall, double IsBinary = isBinary; PredictedClassesIndicators = labelNames.AsReadOnly(); - NumberOfPredictedClasses = confusionTableCounts.Length; - List> counts = new List>(NumberOfPredictedClasses); + NumberOfClasses = confusionTableCounts.Length; + List> counts = new List>(NumberOfClasses); - for (int i = 0; i < NumberOfPredictedClasses; i++) + for (int i = 0; i < NumberOfClasses; i++) counts.Add(confusionTableCounts[i].ToList().AsReadOnly()); Counts = counts.AsReadOnly();