-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Exposing the confusion matrix #3250
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ab146fc
076c04f
0aea0fc
c435990
7b34ac6
6efb8d8
889988c
74090d6
415937d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1348,20 +1348,49 @@ internal static class MetricWriter | |
/// is assigned the string representation of the weighted confusion table. Otherwise it is assigned null.</param> | ||
/// <param name="binary">Indicates whether the confusion table is for binary classification.</param> | ||
/// <param name="sample">Indicates how many classes to sample from the confusion table (-1 indicates no sampling)</param> | ||
public static string GetConfusionTable(IHost host, IDataView confusionDataView, out string weightedConfusionTable, bool binary = true, int sample = -1) | ||
public static string GetConfusionTableAsFormattedString(IHost host, IDataView confusionDataView, out string weightedConfusionTable, bool binary = true, int sample = -1) | ||
{ | ||
host.CheckValue(confusionDataView, nameof(confusionDataView)); | ||
host.CheckParam(sample == -1 || sample >= 2, nameof(sample), "Should be -1 to indicate no sampling, or at least 2"); | ||
|
||
// Get the class names. | ||
int countCol; | ||
host.Check(confusionDataView.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.Count, out countCol), "Did not find the count column"); | ||
var type = confusionDataView.Schema[countCol].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type as VectorDataViewType; | ||
host.Check(type != null && type.IsKnownSize && type.ItemType is TextDataViewType, "The Count column does not have a text vector metadata of kind SlotNames."); | ||
var weightColumn = confusionDataView.Schema.GetColumnOrNull(MetricKinds.ColumnNames.Weight); | ||
bool isWeighted = weightColumn.HasValue; | ||
|
||
var confusionMatrix = GetConfusionMatrix(host, confusionDataView, binary, sample, false); | ||
var confusionTableString = GetConfusionTableAsString(confusionMatrix, false); | ||
|
||
// If there is a Weight column, return the weighted confusionMatrix as well, from this function. | ||
if (isWeighted) | ||
{ | ||
confusionMatrix = GetConfusionMatrix(host, confusionDataView, binary, sample, true); | ||
weightedConfusionTable = GetConfusionTableAsString(confusionMatrix, true); | ||
} | ||
else | ||
weightedConfusionTable = null; | ||
|
||
return confusionTableString; | ||
} | ||
|
||
public static ConfusionMatrix GetConfusionMatrix(IHost host, IDataView confusionDataView, bool binary = true, int sample = -1, bool getWeighted = false) | ||
{ | ||
host.CheckValue(confusionDataView, nameof(confusionDataView)); | ||
host.CheckParam(sample == -1 || sample >= 2, nameof(sample), "Should be -1 to indicate no sampling, or at least 2"); | ||
|
||
// Check that there is a Weight column if the getWeighted parameter is set to true. | ||
var weightColumn = confusionDataView.Schema.GetColumnOrNull(MetricKinds.ColumnNames.Weight); | ||
if (getWeighted) | ||
host.CheckParam(weightColumn.HasValue, nameof(getWeighted), "There is no Weight column in the confusionMatrix data view."); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sorry, is this something we anticipate might actually happen? We control the creation of this matrix, as far as I am aware, and this is a purely internal utility. There are two possible answers to that question: "no", in which case this should have been an assert (with the message as a comment, not a string!), or "yes", in which case this is something the user will potentially have to deal with, and a message that the user might have a prayer of understanding would be required. (I do not view this message as actionable or even diagnostic in any way from a user's perspective.) The same goes for the other check below... #Closed |
||
// Get the counts names. | ||
var countColumn = confusionDataView.Schema[MetricKinds.ColumnNames.Count]; | ||
var type = countColumn.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type as VectorDataViewType; | ||
//"The Count column does not have a text vector metadata of kind SlotNames." | ||
host.Assert(type != null && type.IsKnownSize && type.ItemType is TextDataViewType); | ||
|
||
// Get the class names | ||
var labelNames = default(VBuffer<ReadOnlyMemory<char>>); | ||
confusionDataView.Schema[countCol].Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref labelNames); | ||
host.Check(labelNames.IsDense, "Slot names vector must be dense"); | ||
countColumn.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref labelNames); | ||
host.Assert(labelNames.IsDense, "Slot names vector must be dense"); | ||
|
||
int numConfusionTableLabels = sample < 0 ? labelNames.Length : Math.Min(labelNames.Length, sample); | ||
|
||
|
@@ -1387,32 +1416,32 @@ public static string GetConfusionTable(IHost host, IDataView confusionDataView, | |
|
||
double[] precisionSums; | ||
double[] recallSums; | ||
var confusionTable = GetConfusionTableAsArray(confusionDataView, countCol, labelNames.Length, | ||
labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums); | ||
double[][] confusionTable; | ||
|
||
var predictedLabelNames = GetPredictedLabelNames(in labelNames, labelIndexToConfIndexMap); | ||
var confusionTableString = GetConfusionTableAsString(confusionTable, recallSums, precisionSums, | ||
predictedLabelNames, | ||
sampled: numConfusionTableLabels < labelNames.Length, binary: binary); | ||
if (getWeighted) | ||
confusionTable = GetConfusionTableAsArray(confusionDataView, weightColumn.Value.Index, labelNames.Length, | ||
labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums); | ||
else | ||
confusionTable = GetConfusionTableAsArray(confusionDataView, countColumn.Index, labelNames.Length, | ||
labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums); | ||
|
||
int weightIndex; | ||
if (confusionDataView.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.Weight, out weightIndex)) | ||
double[] precision = new double[numConfusionTableLabels]; | ||
double[] recall = new double[numConfusionTableLabels]; | ||
for (int i = 0; i < numConfusionTableLabels; i++) | ||
{ | ||
confusionTable = GetConfusionTableAsArray(confusionDataView, weightIndex, labelNames.Length, | ||
labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums); | ||
weightedConfusionTable = GetConfusionTableAsString(confusionTable, recallSums, precisionSums, | ||
predictedLabelNames, | ||
sampled: numConfusionTableLabels < labelNames.Length, prefix: "Weighted ", binary: binary); | ||
recall[i] = recallSums[i] > 0 ? confusionTable[i][i] / recallSums[i] : 0; | ||
precision[i] = precisionSums[i] > 0 ? confusionTable[i][i] / precisionSums[i] : 0; | ||
} | ||
else | ||
weightedConfusionTable = null; | ||
|
||
return confusionTableString; | ||
var predictedLabelNames = GetPredictedLabelNames(in labelNames, labelIndexToConfIndexMap); | ||
bool sampled = numConfusionTableLabels < labelNames.Length; | ||
|
||
return new ConfusionMatrix(host, precision, recall, confusionTable, predictedLabelNames, sampled, binary); | ||
} | ||
|
||
private static List<ReadOnlyMemory<char>> GetPredictedLabelNames(in VBuffer<ReadOnlyMemory<char>> labelNames, int[] labelIndexToConfIndexMap) | ||
{ | ||
List<ReadOnlyMemory<char>> result = new List<ReadOnlyMemory<char>>(); | ||
List <ReadOnlyMemory<char>> result = new List<ReadOnlyMemory<char>>(); | ||
var values = labelNames.GetValues(); | ||
for (int i = 0; i < values.Length; i++) | ||
{ | ||
|
@@ -1553,13 +1582,13 @@ private static string GetFoldMetricsAsString(IHostEnvironment env, IDataView dat | |
} | ||
|
||
// Get a string representation of a confusion table. | ||
private static string GetConfusionTableAsString(double[][] confusionTable, double[] rowSums, double[] columnSums, | ||
List<ReadOnlyMemory<char>> predictedLabelNames, string prefix = "", bool sampled = false, bool binary = true) | ||
internal static string GetConfusionTableAsString(ConfusionMatrix confusionMatrix, bool isWeighted) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Mix of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
{ | ||
int numLabels = Utils.Size(confusionTable); | ||
string prefix = isWeighted ? "Weighted " : ""; | ||
int numLabels = confusionMatrix?.Counts == null? 0: confusionMatrix.Counts.Count; | ||
|
||
int colWidth = numLabels == 2 ? 8 : 5; | ||
int maxNameLen = predictedLabelNames.Max(name => name.Length); | ||
int maxNameLen = confusionMatrix.PredictedClassesIndicators.Max(name => name.Length); | ||
// If the names are too long to fit in the column header, we back off to using class indices | ||
// in the header. This will also require putting the indices in the row, but it's better than | ||
// the alternative of having ambiguous abbreviated column headers, or having a table potentially | ||
|
@@ -1572,7 +1601,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl | |
{ | ||
// The row label will also include the index, so a user can easily match against the header. | ||
// In such a case, a label like "Foo" would be presented as something like "5. Foo". | ||
rowDigitLen = Math.Max(predictedLabelNames.Count - 1, 0).ToString().Length; | ||
rowDigitLen = Math.Max(confusionMatrix.PredictedClassesIndicators.Count - 1, 0).ToString().Length; | ||
Contracts.Assert(rowDigitLen >= 1); | ||
rowLabelLen += rowDigitLen + 2; | ||
} | ||
|
@@ -1591,10 +1620,11 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl | |
else | ||
rowLabelFormat = string.Format("{{1,{0}}} ||", paddingLen); | ||
|
||
var confusionTable = confusionMatrix.Counts; | ||
var sb = new StringBuilder(); | ||
if (numLabels == 2 && binary) | ||
if (numLabels == 2 && confusionMatrix.IsBinary) | ||
{ | ||
var positiveCaps = predictedLabelNames[0].ToString().ToUpper(); | ||
var positiveCaps = confusionMatrix.PredictedClassesIndicators[0].ToString().ToUpper(); | ||
|
||
var numTruePos = confusionTable[0][0]; | ||
var numFalseNeg = confusionTable[0][1]; | ||
|
@@ -1607,7 +1637,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl | |
|
||
sb.AppendLine(); | ||
sb.AppendFormat("{0}Confusion table", prefix); | ||
if (sampled) | ||
if (confusionMatrix.IsSampled) | ||
sb.AppendLine(" (sampled)"); | ||
else | ||
sb.AppendLine(); | ||
|
@@ -1619,7 +1649,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl | |
sb.AppendFormat("PREDICTED {0}||", pad); | ||
string format = string.Format(" {{{0},{1}}} |", useNumbersInHeader ? 0 : 1, colWidth); | ||
for (int i = 0; i < numLabels; i++) | ||
sb.AppendFormat(format, i, predictedLabelNames[i]); | ||
sb.AppendFormat(format, i, confusionMatrix.PredictedClassesIndicators[i]); | ||
sb.AppendLine(" Recall"); | ||
sb.AppendFormat("TRUTH {0}||", pad); | ||
for (int i = 0; i < numLabels; i++) | ||
|
@@ -1631,11 +1661,10 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl | |
string.IsNullOrWhiteSpace(prefix) ? "N0" : "F1"); | ||
for (int i = 0; i < numLabels; i++) | ||
{ | ||
sb.AppendFormat(rowLabelFormat, i, predictedLabelNames[i]); | ||
sb.AppendFormat(rowLabelFormat, i, confusionMatrix.PredictedClassesIndicators[i]); | ||
for (int j = 0; j < numLabels; j++) | ||
sb.AppendFormat(format2, confusionTable[i][j]); | ||
Double recall = rowSums[i] > 0 ? confusionTable[i][i] / rowSums[i] : 0; | ||
sb.AppendFormat(" {0,5:F4}", recall); | ||
sb.AppendFormat(" {0,5:F4}", confusionMatrix.PerClassRecall[i]); | ||
sb.AppendLine(); | ||
} | ||
sb.AppendFormat(" {0}||", pad); | ||
|
@@ -1645,10 +1674,8 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl | |
sb.AppendFormat("Precision {0}||", pad); | ||
format = string.Format("{{0,{0}:N4}} |", colWidth + 1); | ||
for (int i = 0; i < numLabels; i++) | ||
{ | ||
Double precision = columnSums[i] > 0 ? confusionTable[i][i] / columnSums[i] : 0; | ||
sb.AppendFormat(format, precision); | ||
} | ||
sb.AppendFormat(format, confusionMatrix.PerClassPrecision[i]); | ||
|
||
sb.AppendLine(); | ||
return sb.ToString(); | ||
} | ||
|
@@ -1701,7 +1728,7 @@ public static void PrintWarnings(IChannel ch, Dictionary<string, IDataView> metr | |
if (metrics.TryGetValue(MetricKinds.Warnings, out warnings)) | ||
{ | ||
var warningTextColumn = warnings.Schema.GetColumnOrNull(MetricKinds.ColumnNames.WarningText); | ||
if (warningTextColumn !=null && warningTextColumn.HasValue && warningTextColumn.Value.Type is TextDataViewType) | ||
if (warningTextColumn != null && warningTextColumn.HasValue && warningTextColumn.Value.Type is TextDataViewType) | ||
{ | ||
using (var cursor = warnings.GetRowCursor(warnings.Schema[MetricKinds.ColumnNames.WarningText])) | ||
{ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -74,6 +74,12 @@ public class BinaryClassificationMetrics | |
/// </remarks> | ||
public double AreaUnderPrecisionRecallCurve { get; } | ||
|
||
/// <summary> | ||
/// The <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a> giving the counts of the | ||
/// true positives, true negatives, false positives and false negatives for the two classes of data. | ||
/// </summary> | ||
public ConfusionMatrix ConfusionMatrix { get; } | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
This needs documentation. #Closed |
||
|
||
private protected static T Fetch<T>(IExceptionContext ectx, DataViewRow row, string name) | ||
{ | ||
var column = row.Schema.GetColumnOrNull(name); | ||
|
@@ -84,9 +90,9 @@ private protected static T Fetch<T>(IExceptionContext ectx, DataViewRow row, str | |
return val; | ||
} | ||
|
||
internal BinaryClassificationMetrics(IExceptionContext ectx, DataViewRow overallResult) | ||
internal BinaryClassificationMetrics(IHost host, DataViewRow overallResult, IDataView confusionMatrix) | ||
{ | ||
double Fetch(string name) => Fetch<double>(ectx, overallResult, name); | ||
double Fetch(string name) => Fetch<double>(host, overallResult, name); | ||
AreaUnderRocCurve = Fetch(BinaryClassifierEvaluator.Auc); | ||
Accuracy = Fetch(BinaryClassifierEvaluator.Accuracy); | ||
PositivePrecision = Fetch(BinaryClassifierEvaluator.PosPrecName); | ||
|
@@ -95,6 +101,7 @@ internal BinaryClassificationMetrics(IExceptionContext ectx, DataViewRow overall | |
NegativeRecall = Fetch(BinaryClassifierEvaluator.NegRecallName); | ||
F1Score = Fetch(BinaryClassifierEvaluator.F1); | ||
AreaUnderPrecisionRecallCurve = Fetch(BinaryClassifierEvaluator.AuPrc); | ||
ConfusionMatrix = MetricWriter.GetConfusionMatrix(host, confusionMatrix); | ||
} | ||
|
||
[BestFriend] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nameof(confusionDataView)? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can't use string interpolation in CheckParam; the message needs to be a string literal.
In reply to: 275462838 [](ancestors = 275462838)