Add API to get Precision-Recall Curve data #3039
Conversation
/// <param name="probability">The name of the probability column in <paramref name="data"/>, the calibrated version of <paramref name="score"/>.</param> | ||
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param> | ||
/// <returns>The precision recall curve data points.</returns> | ||
public IDataView GetPrecisionRecallCurve(IDataView data, string label, string score, string probability, string predictedLabel) |
Unit tests & sample are TBD #Resolved
/// <param name="probability">The name of the probability column in <paramref name="data"/>, the calibrated version of <paramref name="score"/>.</param> | ||
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param> | ||
/// <returns>The precision recall curve data points.</returns> | ||
public IDataView GetPrecisionRecallCurve(IDataView data, string label, string score, string probability, string predictedLabel) |
Do you plan to add mlContext.BinaryClassification.GetPrecisionRecallCurve? #Resolved
Codecov Report
@@ Coverage Diff @@
## master #3039 +/- ##
==========================================
+ Coverage 72.48% 72.52% +0.04%
==========================================
Files 802 807 +5
Lines 144039 144476 +437
Branches 16179 16190 +11
==========================================
+ Hits 104404 104784 +380
- Misses 35221 35287 +66
+ Partials 4414 4405 -9
/// <param name="probability">The name of the probability column in <paramref name="data"/>, the calibrated version of <paramref name="score"/>.</param> | ||
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param> | ||
/// <returns>The precision recall curve data points.</returns> | ||
public IDataView GetPrecisionRecallCurve(IDataView data, string label, string score, string probability, string predictedLabel) |
Hi @ganik, thanks for working on this! I view the API in its current form as having a few potential improvements...
The most obvious thing is this return value: an IDataView. Probably not a very good transmission format for metrics results.
It may be informative to note how we surface other metrics results in our public API. You'll note that we do have IEvaluator (which is an internal interface) return metrics via IDataView (because it needs to communicate them in some common format, and also because that's one of the types that entry-points uses), but note what we do with them for the sake of our public API.
machinelearning/src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs
Lines 87 to 98 in 5b22420
internal BinaryClassificationMetrics(IExceptionContext ectx, DataViewRow overallResult)
{
    double Fetch(string name) => Fetch<double>(ectx, overallResult, name);
    AreaUnderRocCurve = Fetch(BinaryClassifierEvaluator.Auc);
    Accuracy = Fetch(BinaryClassifierEvaluator.Accuracy);
    PositivePrecision = Fetch(BinaryClassifierEvaluator.PosPrecName);
    PositiveRecall = Fetch(BinaryClassifierEvaluator.PosRecallName);
    NegativePrecision = Fetch(BinaryClassifierEvaluator.NegPrecName);
    NegativeRecall = Fetch(BinaryClassifierEvaluator.NegRecallName);
    F1Score = Fetch(BinaryClassifierEvaluator.F1);
    AreaUnderPrecisionRecallCurve = Fetch(BinaryClassifierEvaluator.AuPrc);
}
We actually take the data out, and store them in an actual strongly typed structure, with intellisense and whatnot.
We could have done what you did here -- just sort of take the IDataView, even though it was one whose structure we knew exactly, called it a day, and said, "here, good luck figuring out what to do with this" -- but we didn't do that, because that would be a needlessly terrible experience. We wanted an easy to use, discoverable API for something as important as metrics, and there was an obvious way to do it, just as there is with this PR. So, I was expecting, and I kind of think people might appreciate, something more closely resembling that: something with actual types, since unlike in some scenarios (general dataflow) the types and column names are well defined.
Note, here is where this data view is composed, for your own reference, so you can see what goes in it.
machinelearning/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
Lines 373 to 392 in 5b22420
if (scores.Count > 0)
{
    var dvBldr = new ArrayDataViewBuilder(Host);
    if (hasStrats)
    {
        dvBldr.AddColumn(MetricKinds.ColumnNames.StratCol, GetKeyValueGetter(dictionaries), (ulong)dictionaries.Length, prStratCol.ToArray());
        dvBldr.AddColumn(MetricKinds.ColumnNames.StratVal, TextDataViewType.Instance, prStratVal.ToArray());
    }
    dvBldr.AddColumn(Threshold, NumberDataViewType.Single, scores.ToArray());
    dvBldr.AddColumn(Precision, NumberDataViewType.Double, precision.ToArray());
    dvBldr.AddColumn(Recall, NumberDataViewType.Double, recall.ToArray());
    dvBldr.AddColumn(FalsePositiveRate, NumberDataViewType.Double, fpr.ToArray());
    if (weightedPrecision.Count > 0)
    {
        dvBldr.AddColumn("Weighted " + Precision, NumberDataViewType.Double, weightedPrecision.ToArray());
        dvBldr.AddColumn("Weighted " + Recall, NumberDataViewType.Double, weightedRecall.ToArray());
        dvBldr.AddColumn("Weighted " + FalsePositiveRate, NumberDataViewType.Double, weightedFpr.ToArray());
    }
    result.Add(PrCurve, dvBldr.GetDataView());
}
So, a few columns that are always there. We'll need to figure out what's up with these "strat" columns, that is, "stratification" columns (we don't want to expose this as such re: @rogancarr), so perhaps we skip that one, or figure out some other structure that's appropriate. (We'd probably group them, somehow, I don't quite know offhand what that's for. CV?) Anyway, you'll want to figure out what's up with that.
You could imagine some alternate structure, whereby this "stuff" is put into a strongly typed structure. In the short term you could copy the results of the IDataView out into this structure -- that's a little inefficient but will suffice. Could be as simple as an array of a struct containing the relevant PR points.
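To make that suggestion concrete, here is a rough sketch of the kind of strongly typed structure being described; the type and member names are illustrative assumptions, not the actual types added by this PR:

// Illustrative sketch only: a strongly typed PR-curve point that could be
// populated by copying rows out of the evaluator's PrCurve IDataView.
public sealed class PrecisionRecallPoint
{
    public float Threshold { get; }
    public double Precision { get; }
    public double Recall { get; }
    public double FalsePositiveRate { get; }

    internal PrecisionRecallPoint(float threshold, double precision, double recall, double falsePositiveRate)
    {
        Threshold = threshold;
        Precision = precision;
        Recall = recall;
        FalsePositiveRate = falsePositiveRate;
    }
}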
The second problem with the API is a bit less fundamental, and probably easier to fix. You will also note that you are discarding all the other results of evaluation, for no particular reason. What I had suggested in the original issue, and I still think is a good idea, is an EvaluateWithPRCurve that still produces the metric results (since they're already right there and being calculated alongside the PR curves!). So, maybe return the metrics, but also have the PR curve structure, whatever it is (not just an IDataView!!), as an out parameter. Or reverse the two, I don't know. #Resolved
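For illustration only, the suggested shape might look roughly like this; the extension-method form, parameter names, and defaults are assumptions rather than the final API, and PrecisionRecallPoint is the hypothetical type sketched above:

using System;
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;

public static class BinaryEvaluationSketch
{
    // Sketch only: return the usual calibrated metrics and hand the curve
    // back through an out parameter, as suggested above.
    public static CalibratedBinaryClassificationMetrics EvaluateWithPRCurve(
        this BinaryClassificationCatalog catalog,
        IDataView data,
        out IReadOnlyList<PrecisionRecallPoint> prCurve,
        string labelColumnName = "Label",
        string scoreColumnName = "Score",
        string probabilityColumnName = "Probability",
        string predictedLabelColumnName = "PredictedLabel")
    {
        throw new NotImplementedException("Signature sketch only.");
    }
}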
Awesome, thank you for the great suggestions as usual, @TomFinley #Resolved
I like this idea: EvaluateWithPRCurve
One thing that Gani and I talked about was also being able to return the data for an ROC curve too. The original issue (and related v1 scenario) was specifically about a PR curve, but I think it should be on the table to return an ROC curve too. These are the two major curves that people may want to plot.
Also, we may want to think about being able to extend this design to any sort of curve someone may want to plot, as different fields of application have slightly different target metrics.
This should be done. I suggest doing the ROC curve post-V1.
How much extra work is an ROC Curve on top of getting the PR Curve? It's weird to have one but not the other from the API.
Actually, my bad: BinaryPrecisionRecallDataPoint contains FPR as well, and Recall is actually TPR. This should be enough for ROC.
/// <param name="probability">The name of the probability column in <paramref name="data"/>, the calibrated version of <paramref name="score"/>.</param> | ||
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param> | ||
/// <returns>The precision recall curve data points.</returns> | ||
public IDataView GetPrecisionRecallCurve(IDataView data, string label, string score, string probability, string predictedLabel) |
predictedLabel
Does the BinaryClassifierEvaluator need that column to evaluate? (same question for the API below). #ByDesign
Incidentally @yaeldekel is correct, the original command line didn't need it -- it would use it if required, but it would not insist on it. That we have it for the API is somewhat unfortunate, but not that big of a deal. In future we could in fact improve this story: we could have this overload, but also another overload that doesn't take it, and takes a threshold for it. Thereby capturing the existing functionality of this C# API, while retaining the "flexibility" we used to achieve in our command line.
That however is beyond the scope of @ganik's work here, and not something we should add now. #ByDesign
/// <param name="probability">The name of the probability column in <paramref name="data"/>, the calibrated version of <paramref name="score"/>.</param> | ||
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param> | ||
/// <returns>The precision recall curve data points.</returns> | ||
public IDataView GetPrecisionRecallCurve(IDataView data, string label, string score, string probability, string predictedLabel) |
string label, string score, string probability, string predictedLabel
Do these column names have default values? If so, let's specify them here as defaults. #ByDesign
These are the internal methods that are called by the public Evaluate API. The public API does have default values for the column names.
@@ -98,7 +100,16 @@ public static void SdcaBinaryClassification()
Console.WriteLine($"Negative Precision: {metrics.NegativePrecision}"); // 0.87
Console.WriteLine($"Negative Recall: {metrics.NegativeRecall}"); // 0.91
Console.WriteLine($"Positive Precision: {metrics.PositivePrecision}"); // 0.65
Console.WriteLine($"Positive Recall: {metrics.PositiveRecall}"); // 0.55
Console.WriteLine($"Positive Recall: {metrics.PositiveRecall}"); // 0.55
extra white space? #Closed
Console.WriteLine($"Positive Recall: {metrics.PositiveRecall}"); // 0.55 | ||
Console.WriteLine($"Positive Recall: {metrics.PositiveRecall}"); // 0.55 | ||
|
||
var prData = prCurve.GetEnumerator(); |
prCurve
why not just foreach (var data in prCurve)? #Closed
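For reference, a minimal sketch of that usage, assuming prCurve is the enumerable of curve data points returned by the new API:

// Iterate the curve points directly instead of going through GetEnumerator().
foreach (var point in prCurve)
{
    Console.WriteLine($"Threshold: {point.Threshold}, Precision: {point.Precision}, Recall: {point.Recall}");
}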
/// <summary>
/// This class represents one data point on Precision-Recall curve.
/// </summary>
public sealed class PrecisionRecallMetrics
PrecisionRecallMetrics
This appears to be specific to binary PR, yet the name is generic (you use the binary evaluator).
Moreover, the summary does not match the name. For example, if this is a point, it should be named as such. We have more PR metrics besides those:
If it is for binary, don't you need to also supply positive/negative {recall/precision} metrics?
If it is multiclass, you would need a PR matrix AFAIK. #Closed
Changed the name to reflect that it's not a metrics class but a data point for binary classification.
this BinaryClassificationCatalog catalog,
DataView<T> data,
Func<T, Scalar<bool>> label,
Func<T, (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel)> pred,
predictedLabel
is this always a bool? don't we sometimes predict ints or keys?
Consider renaming the parameter to a full word such as prediction. #Closed
This is a twin method for the public static CalibratedBinaryClassificationMetrics Evaluate(..) on line 28. It does everything the same way, plus spits out the PR curve. The scope of this PR is to add the PR curve output; changing parameter types or names would mean refactoring the original Evaluate(..) method, which is out of scope for this PR.
DataView<T> data,
Func<T, Scalar<bool>> label,
Func<T, (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel)> pred,
out IEnumerable<PrecisionRecallMetrics> prCurve)
prCurve
I find the out param awkward. Can we extend CalibratedBinaryClassificationMetrics instead? #Closed
This would mean unsealing the CalibratedBinaryClassificationMetrics class; it's currently sealed.
string predName = indexer.Get(predCol);

var eval = new BinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments() { NumRocExamples = 100000 });
return eval.EvaluateWithPRCurve(data.AsDynamic, labelName, scoreName, probName, predName, out prCurve);
Super similar to another new method. Refactor? #Closed
This would mean refactoring both Evaluate(..) methods on lines 28 & 102, as these new methods are based off them. I don't want to make bigger changes at this point and just want to follow the established pattern here to minimize risk.
List<PrecisionRecallMetrics> prCurveResult = new List<PrecisionRecallMetrics>();
using (var cursor = prCurveView.GetRowCursorForAllColumns())
{
    while (cursor.MoveNext())
Just to put in the PR what we had talked about: it is important that we get the getters before the tight loop and use them inside the tight loop to retrieve the values, since constructing the delegates is extremely expensive but using them is very cheap. (See here for more architectural info on why this is so.) #Resolved
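For context, a minimal sketch of the pattern being described, reusing the column constants from the evaluator output quoted earlier (surrounding method and error handling omitted):

// Acquire the getters once, before the tight loop; constructing these
// delegates is expensive, while invoking them per row is cheap.
var precisionColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.Precision).Value;
var recallColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.Recall).Value;

using (var cursor = prCurveView.GetRowCursorForAllColumns())
{
    var precisionGetter = cursor.GetGetter<double>(precisionColumn);
    var recallGetter = cursor.GetGetter<double>(recallColumn);

    double precision = 0;
    double recall = 0;
    while (cursor.MoveNext())
    {
        // Invoke the pre-fetched getters inside the loop.
        precisionGetter(ref precision);
        recallGetter(ref recall);
        // ... build the strongly typed data point from the fetched values.
    }
}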
/// </summary>
public double FalsePositiveRate { get; }

internal BinaryPrecisionRecallDataPoint(ValueGetter<float> thresholdGetter, ValueGetter<double> precisionGetter, ValueGetter<double> recallGetter, ValueGetter<double> fprGetter)
ValueGetter<double> precisionGetter
I would recommend that all logic pursuant to the getters remain in the same block of code as the cursor (anything else is just a pointless increase of complexity), but this at least avoids the main problem, so I guess we can fix it later. #ByDesign
/// <summary>
/// Gets the true positive rate for the current threshold.
/// </summary>
public double TruePositiveRate => Recall;
public double TruePositiveRate => Recall;
So, what is the point of this? I see a class where one property exists merely to point to another. What is up? #ByDesign
This is to make life easier for users plotting an ROC curve. The ROC curve is plotted as TPR vs. FPR. We do have FPR, but I spent some time with Rogan looking for TPR in order to get the ROC data; it turns out Recall is synonymous with TPR. So, to spare users the same confusion and to reduce the number of potential inquiries in the future, I decided to have a TPR field in here.
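As a small usage sketch of what this enables (prCurve being the returned sequence of data points; the actual plotting is left to whatever charting library is in use):

// An ROC curve is TPR plotted against FPR, both of which are now available
// directly on each data point.
var fpr = new List<double>();
var tpr = new List<double>();
foreach (var point in prCurve)
{
    fpr.Add(point.FalsePositiveRate);
    tpr.Add(point.TruePositiveRate);
}
// Hand the (fpr, tpr) pairs to a plotting library to draw the ROC curve.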
var precisionColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.Precision);
var recallColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.Recall);
var fprColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.FalsePositiveRate);
Host.Assert(thresholdColumn != null);
Host.Assert
In future, consider AssertValue to make this a bit more elegant. #ByDesign
I tried it; it didn't work here because DataViewSchema.Column is not a reference type, it's a struct.
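For reference, a small sketch (not the PR's exact code) of why this is: GetColumnOrNull returns a nullable struct, so the assertion is really against the nullable wrapper rather than a reference.

// DataViewSchema.Column is a struct, so GetColumnOrNull returns Nullable<DataViewSchema.Column>.
DataViewSchema.Column? precisionColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.Precision);
// The `!= null` check above uses the lifted comparison on the nullable, i.e. it checks HasValue.
Host.Assert(precisionColumn.HasValue);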
@@ -18,6 +18,7 @@
using Microsoft.ML.Trainers.Recommender;
using Xunit;
using Xunit.Abstractions;
using System.Collections.Generic;
We generally like to keep our namespaces sorted. #Resolved
Thanks for adding this @ganik! This is one of the last core scenarios we wanted to unlock.
fixes #2645