Skip to content

Exposing the confusion matrix #3250

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/Data/AnnotationUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int co
public static IEnumerable<SchemaShape.Column> AnnotationsForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null)
{
var cols = new List<SchemaShape.Column>();
if (labelColumn.HasValue && labelColumn.Value.IsKey)
if (labelColumn != null && labelColumn.Value.IsKey)
{
if (labelColumn.Value.Annotations.TryFindColumn(Kinds.KeyValues, out var metaCol) &&
metaCol.Kind == SchemaShape.Column.VectorKind.Vector)
Expand Down
16 changes: 11 additions & 5 deletions src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -815,16 +815,18 @@ public CalibratedBinaryClassificationMetrics Evaluate(IDataView data, string lab
var resultDict = ((IEvaluator)this).Evaluate(roles);
Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
var overall = resultDict[MetricKinds.OverallMetrics];
var confusionMatrix = resultDict[MetricKinds.ConfusionMatrix];

CalibratedBinaryClassificationMetrics result;
using (var cursor = overall.GetRowCursorForAllColumns())
{
var moved = cursor.MoveNext();
Host.Assert(moved);
result = new CalibratedBinaryClassificationMetrics(Host, cursor);
result = new CalibratedBinaryClassificationMetrics(Host, cursor, confusionMatrix);
moved = cursor.MoveNext();
Host.Assert(!moved);
}

return result;
}

Expand Down Expand Up @@ -879,13 +881,14 @@ public CalibratedBinaryClassificationMetrics EvaluateWithPRCurve(
}
}
prCurve = prCurveResult;
var confusionMatrix = resultDict[MetricKinds.ConfusionMatrix];

CalibratedBinaryClassificationMetrics result;
using (var cursor = overall.GetRowCursorForAllColumns())
{
var moved = cursor.MoveNext();
Host.Assert(moved);
result = new CalibratedBinaryClassificationMetrics(Host, cursor);
result = new CalibratedBinaryClassificationMetrics(Host, cursor, confusionMatrix);
moved = cursor.MoveNext();
Host.Assert(!moved);
}
Expand Down Expand Up @@ -939,16 +942,18 @@ public BinaryClassificationMetrics Evaluate(IDataView data, string label, string
var resultDict = ((IEvaluator)this).Evaluate(roles);
Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
var overall = resultDict[MetricKinds.OverallMetrics];
var confusionMatrix = resultDict[MetricKinds.ConfusionMatrix];

BinaryClassificationMetrics result;
using (var cursor = overall.GetRowCursorForAllColumns())
{
var moved = cursor.MoveNext();
Host.Assert(moved);
result = new BinaryClassificationMetrics(Host, cursor);
result = new BinaryClassificationMetrics(Host, cursor, confusionMatrix);
moved = cursor.MoveNext();
Host.Assert(!moved);
}

return result;
}

Expand Down Expand Up @@ -985,6 +990,7 @@ public BinaryClassificationMetrics EvaluateWithPRCurve(
var prCurveView = resultDict[MetricKinds.PrCurve];
Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
var overall = resultDict[MetricKinds.OverallMetrics];
var confusionMatrix = resultDict[MetricKinds.ConfusionMatrix];

var prCurveResult = new List<BinaryPrecisionRecallDataPoint>();
using (var cursor = prCurveView.GetRowCursorForAllColumns())
Expand All @@ -1007,7 +1013,7 @@ public BinaryClassificationMetrics EvaluateWithPRCurve(
{
var moved = cursor.MoveNext();
Host.Assert(moved);
result = new BinaryClassificationMetrics(Host, cursor);
result = new BinaryClassificationMetrics(Host, cursor, confusionMatrix);
moved = cursor.MoveNext();
Host.Assert(!moved);
}
Expand Down Expand Up @@ -1377,7 +1383,7 @@ private protected override void PrintFoldResultsCore(IChannel ch, Dictionary<str
fold = ColumnSelectingTransformer.CreateKeep(Host, fold, colsToKeep.ToArray());

string weightedConf;
var unweightedConf = MetricWriter.GetConfusionTable(Host, conf, out weightedConf);
var unweightedConf = MetricWriter.GetConfusionTableAsFormattedString(Host, conf, out weightedConf);
string weightedFold;
var unweightedFold = MetricWriter.GetPerFoldResults(Host, fold, out weightedFold);
ch.Assert(string.IsNullOrEmpty(weightedConf) == string.IsNullOrEmpty(weightedFold));
Expand Down
111 changes: 69 additions & 42 deletions src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1348,20 +1348,49 @@ internal static class MetricWriter
/// is assigned the string representation of the weighted confusion table. Otherwise it is assigned null.</param>
/// <param name="binary">Indicates whether the confusion table is for binary classification.</param>
/// <param name="sample">Indicates how many classes to sample from the confusion table (-1 indicates no sampling)</param>
public static string GetConfusionTable(IHost host, IDataView confusionDataView, out string weightedConfusionTable, bool binary = true, int sample = -1)
public static string GetConfusionTableAsFormattedString(IHost host, IDataView confusionDataView, out string weightedConfusionTable, bool binary = true, int sample = -1)
{
host.CheckValue(confusionDataView, nameof(confusionDataView));
host.CheckParam(sample == -1 || sample >= 2, nameof(sample), "Should be -1 to indicate no sampling, or at least 2");

// Get the class names.
int countCol;
host.Check(confusionDataView.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.Count, out countCol), "Did not find the count column");
var type = confusionDataView.Schema[countCol].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type as VectorDataViewType;
host.Check(type != null && type.IsKnownSize && type.ItemType is TextDataViewType, "The Count column does not have a text vector metadata of kind SlotNames.");
var weightColumn = confusionDataView.Schema.GetColumnOrNull(MetricKinds.ColumnNames.Weight);
bool isWeighted = weightColumn.HasValue;

var confusionMatrix = GetConfusionMatrix(host, confusionDataView, binary, sample, false);
var confusionTableString = GetConfusionTableAsString(confusionMatrix, false);

// If there is a Weight column, return the weighted confusionMatrix as well, from this function.
if (isWeighted)
{
confusionMatrix = GetConfusionMatrix(host, confusionDataView, binary, sample, true);
weightedConfusionTable = GetConfusionTableAsString(confusionMatrix, true);
}
else
weightedConfusionTable = null;

return confusionTableString;
}

public static ConfusionMatrix GetConfusionMatrix(IHost host, IDataView confusionDataView, bool binary = true, int sample = -1, bool getWeighted = false)
{
host.CheckValue(confusionDataView, nameof(confusionDataView));
host.CheckParam(sample == -1 || sample >= 2, nameof(sample), "Should be -1 to indicate no sampling, or at least 2");

// check that there is a Weight column, if isWeighted parameter is set to true.
var weightColumn = confusionDataView.Schema.GetColumnOrNull(MetricKinds.ColumnNames.Weight);
if (getWeighted)
host.CheckParam(weightColumn.HasValue, nameof(getWeighted), "There is no Weight column in the confusionMatrix data view.");
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Apr 15, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

confusionMatrix data view [](start = 110, length = 25)

nameof(confusionDataView) ? #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can't use string interpolation in CheckParam; the message needs to be a string literal.


In reply to: 275462838 [](ancestors = 275462838)


Copy link
Contributor

@TomFinley TomFinley Apr 9, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, is this something we anticipate might actually happen? We control the creation of this matrix, as far as I am aware, and this is a purely internal utility.

There are two possible answers to that question: "no", in which case this should have been an assert (with the message as a comment, not a string!), or "yes", in which case this is something the user will potentially have to deal with, and a message that the user might have a prayer of understanding would be required. (I do not view this message as actionable or even diagnostic in any way from a user's perspective.)

Same for the other check below... #Closed

// Get the counts names.
var countColumn = confusionDataView.Schema[MetricKinds.ColumnNames.Count];
var type = countColumn.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type as VectorDataViewType;
//"The Count column does not have a text vector metadata of kind SlotNames."
host.Assert(type != null && type.IsKnownSize && type.ItemType is TextDataViewType);

// Get the class names
var labelNames = default(VBuffer<ReadOnlyMemory<char>>);
confusionDataView.Schema[countCol].Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref labelNames);
host.Check(labelNames.IsDense, "Slot names vector must be dense");
countColumn.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref labelNames);
host.Assert(labelNames.IsDense, "Slot names vector must be dense");

int numConfusionTableLabels = sample < 0 ? labelNames.Length : Math.Min(labelNames.Length, sample);

Expand All @@ -1387,32 +1416,32 @@ public static string GetConfusionTable(IHost host, IDataView confusionDataView,

double[] precisionSums;
double[] recallSums;
var confusionTable = GetConfusionTableAsArray(confusionDataView, countCol, labelNames.Length,
labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums);
double[][] confusionTable;

var predictedLabelNames = GetPredictedLabelNames(in labelNames, labelIndexToConfIndexMap);
var confusionTableString = GetConfusionTableAsString(confusionTable, recallSums, precisionSums,
predictedLabelNames,
sampled: numConfusionTableLabels < labelNames.Length, binary: binary);
if (getWeighted)
confusionTable = GetConfusionTableAsArray(confusionDataView, weightColumn.Value.Index, labelNames.Length,
labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums);
else
confusionTable = GetConfusionTableAsArray(confusionDataView, countColumn.Index, labelNames.Length,
labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums);

int weightIndex;
if (confusionDataView.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.Weight, out weightIndex))
double[] precision = new double[numConfusionTableLabels];
double[] recall = new double[numConfusionTableLabels];
for (int i = 0; i < numConfusionTableLabels; i++)
{
confusionTable = GetConfusionTableAsArray(confusionDataView, weightIndex, labelNames.Length,
labelIndexToConfIndexMap, numConfusionTableLabels, out precisionSums, out recallSums);
weightedConfusionTable = GetConfusionTableAsString(confusionTable, recallSums, precisionSums,
predictedLabelNames,
sampled: numConfusionTableLabels < labelNames.Length, prefix: "Weighted ", binary: binary);
recall[i] = recallSums[i] > 0 ? confusionTable[i][i] / recallSums[i] : 0;
precision[i] = precisionSums[i] > 0 ? confusionTable[i][i] / precisionSums[i] : 0;
}
else
weightedConfusionTable = null;

return confusionTableString;
var predictedLabelNames = GetPredictedLabelNames(in labelNames, labelIndexToConfIndexMap);
bool sampled = numConfusionTableLabels < labelNames.Length;

return new ConfusionMatrix(host, precision, recall, confusionTable, predictedLabelNames, sampled, binary);
}

private static List<ReadOnlyMemory<char>> GetPredictedLabelNames(in VBuffer<ReadOnlyMemory<char>> labelNames, int[] labelIndexToConfIndexMap)
{
List<ReadOnlyMemory<char>> result = new List<ReadOnlyMemory<char>>();
List <ReadOnlyMemory<char>> result = new List<ReadOnlyMemory<char>>();
var values = labelNames.GetValues();
for (int i = 0; i < values.Length; i++)
{
Expand Down Expand Up @@ -1553,13 +1582,13 @@ private static string GetFoldMetricsAsString(IHostEnvironment env, IDataView dat
}

// Get a string representation of a confusion table.
private static string GetConfusionTableAsString(double[][] confusionTable, double[] rowSums, double[] columnSums,
List<ReadOnlyMemory<char>> predictedLabelNames, string prefix = "", bool sampled = false, bool binary = true)
internal static string GetConfusionTableAsString(ConfusionMatrix confusionMatrix, bool isWeighted)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetConfusionTableAsString [](start = 31, length = 25)

Mix of ConfusionMatrix and ConfusionTable makes me
...
http://i.imgur.com/lgSDkvC.gif
...
Confuse!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is also confusionDataView


In reply to: 275081264 [](ancestors = 275081264)

{
int numLabels = Utils.Size(confusionTable);
string prefix = isWeighted ? "Weighted " : "";
int numLabels = confusionMatrix?.Counts == null? 0: confusionMatrix.Counts.Count;

int colWidth = numLabels == 2 ? 8 : 5;
int maxNameLen = predictedLabelNames.Max(name => name.Length);
int maxNameLen = confusionMatrix.PredictedClassesIndicators.Max(name => name.Length);
// If the names are too long to fit in the column header, we back off to using class indices
// in the header. This will also require putting the indices in the row, but it's better than
// the alternative of having ambiguous abbreviated column headers, or having a table potentially
Expand All @@ -1572,7 +1601,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
{
// The row label will also include the index, so a user can easily match against the header.
// In such a case, a label like "Foo" would be presented as something like "5. Foo".
rowDigitLen = Math.Max(predictedLabelNames.Count - 1, 0).ToString().Length;
rowDigitLen = Math.Max(confusionMatrix.PredictedClassesIndicators.Count - 1, 0).ToString().Length;
Contracts.Assert(rowDigitLen >= 1);
rowLabelLen += rowDigitLen + 2;
}
Expand All @@ -1591,10 +1620,11 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
else
rowLabelFormat = string.Format("{{1,{0}}} ||", paddingLen);

var confusionTable = confusionMatrix.Counts;
var sb = new StringBuilder();
if (numLabels == 2 && binary)
if (numLabels == 2 && confusionMatrix.IsBinary)
{
var positiveCaps = predictedLabelNames[0].ToString().ToUpper();
var positiveCaps = confusionMatrix.PredictedClassesIndicators[0].ToString().ToUpper();

var numTruePos = confusionTable[0][0];
var numFalseNeg = confusionTable[0][1];
Expand All @@ -1607,7 +1637,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl

sb.AppendLine();
sb.AppendFormat("{0}Confusion table", prefix);
if (sampled)
if (confusionMatrix.IsSampled)
sb.AppendLine(" (sampled)");
else
sb.AppendLine();
Expand All @@ -1619,7 +1649,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
sb.AppendFormat("PREDICTED {0}||", pad);
string format = string.Format(" {{{0},{1}}} |", useNumbersInHeader ? 0 : 1, colWidth);
for (int i = 0; i < numLabels; i++)
sb.AppendFormat(format, i, predictedLabelNames[i]);
sb.AppendFormat(format, i, confusionMatrix.PredictedClassesIndicators[i]);
sb.AppendLine(" Recall");
sb.AppendFormat("TRUTH {0}||", pad);
for (int i = 0; i < numLabels; i++)
Expand All @@ -1631,11 +1661,10 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
string.IsNullOrWhiteSpace(prefix) ? "N0" : "F1");
for (int i = 0; i < numLabels; i++)
{
sb.AppendFormat(rowLabelFormat, i, predictedLabelNames[i]);
sb.AppendFormat(rowLabelFormat, i, confusionMatrix.PredictedClassesIndicators[i]);
for (int j = 0; j < numLabels; j++)
sb.AppendFormat(format2, confusionTable[i][j]);
Double recall = rowSums[i] > 0 ? confusionTable[i][i] / rowSums[i] : 0;
sb.AppendFormat(" {0,5:F4}", recall);
sb.AppendFormat(" {0,5:F4}", confusionMatrix.PerClassRecall[i]);
sb.AppendLine();
}
sb.AppendFormat(" {0}||", pad);
Expand All @@ -1645,10 +1674,8 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
sb.AppendFormat("Precision {0}||", pad);
format = string.Format("{{0,{0}:N4}} |", colWidth + 1);
for (int i = 0; i < numLabels; i++)
{
Double precision = columnSums[i] > 0 ? confusionTable[i][i] / columnSums[i] : 0;
sb.AppendFormat(format, precision);
}
sb.AppendFormat(format, confusionMatrix.PerClassPrecision[i]);

sb.AppendLine();
return sb.ToString();
}
Expand Down Expand Up @@ -1701,7 +1728,7 @@ public static void PrintWarnings(IChannel ch, Dictionary<string, IDataView> metr
if (metrics.TryGetValue(MetricKinds.Warnings, out warnings))
{
var warningTextColumn = warnings.Schema.GetColumnOrNull(MetricKinds.ColumnNames.WarningText);
if (warningTextColumn !=null && warningTextColumn.HasValue && warningTextColumn.Value.Type is TextDataViewType)
if (warningTextColumn != null && warningTextColumn.HasValue && warningTextColumn.Value.Type is TextDataViewType)
{
using (var cursor = warnings.GetRowCursor(warnings.Schema[MetricKinds.ColumnNames.WarningText]))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ public class BinaryClassificationMetrics
/// </remarks>
public double AreaUnderPrecisionRecallCurve { get; }

/// <summary>
/// The <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a> giving the counts of the
/// true positives, true negatives, false positives and false negatives for the two classes of data.
/// </summary>
public ConfusionMatrix ConfusionMatrix { get; }
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Apr 12, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ConfusionMatrix [](start = 31, length = 15)

need documentation #Closed


private protected static T Fetch<T>(IExceptionContext ectx, DataViewRow row, string name)
{
var column = row.Schema.GetColumnOrNull(name);
Expand All @@ -84,9 +90,9 @@ private protected static T Fetch<T>(IExceptionContext ectx, DataViewRow row, str
return val;
}

internal BinaryClassificationMetrics(IExceptionContext ectx, DataViewRow overallResult)
internal BinaryClassificationMetrics(IHost host, DataViewRow overallResult, IDataView confusionMatrix)
{
double Fetch(string name) => Fetch<double>(ectx, overallResult, name);
double Fetch(string name) => Fetch<double>(host, overallResult, name);
AreaUnderRocCurve = Fetch(BinaryClassifierEvaluator.Auc);
Accuracy = Fetch(BinaryClassifierEvaluator.Accuracy);
PositivePrecision = Fetch(BinaryClassifierEvaluator.PosPrecName);
Expand All @@ -95,6 +101,7 @@ internal BinaryClassificationMetrics(IExceptionContext ectx, DataViewRow overall
NegativeRecall = Fetch(BinaryClassifierEvaluator.NegRecallName);
F1Score = Fetch(BinaryClassifierEvaluator.F1);
AreaUnderPrecisionRecallCurve = Fetch(BinaryClassifierEvaluator.AuPrc);
ConfusionMatrix = MetricWriter.GetConfusionMatrix(host, confusionMatrix);
}

[BestFriend]
Expand Down
Loading