Add API to get Precision-Recall Curve data #3039

Merged: 9 commits, Mar 23, 2019
Conversation

@ganik ganik commented Mar 20, 2019

fixes #2645

/// <param name="probability">The name of the probability column in <paramref name="data"/>, the calibrated version of <paramref name="score"/>.</param>
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param>
/// <returns>The precision recall curve data points.</returns>
public IDataView GetPrecisionRecallCurve(IDataView data, string label, string score, string probability, string predictedLabel)
@ganik ganik (Member Author) Mar 20, 2019

Unit tests & sample is tbd #Resolved

@ganik (Member Author)

sample is done, tests tbd



@ganik (Member Author)

test is added



@ganik ganik changed the title Add GetPrecisionRecallCurve API [WIP] Add GetPrecisionRecallCurve API Mar 20, 2019
/// <param name="probability">The name of the probability column in <paramref name="data"/>, the calibrated version of <paramref name="score"/>.</param>
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param>
/// <returns>The precision recall curve data points.</returns>
public IDataView GetPrecisionRecallCurve(IDataView data, string label, string score, string probability, string predictedLabel)
@Ivanidzo4ka Ivanidzo4ka (Contributor) Mar 21, 2019

Do you plan to add mlContext.BinaryClassification.GetPrecisionRecallCurve? #Resolved

@ganik (Member Author)

it should be added



codecov bot commented Mar 21, 2019

Codecov Report

Merging #3039 into master will increase coverage by 0.04%.
The diff coverage is 62.94%.

@@            Coverage Diff             @@
##           master    #3039      +/-   ##
==========================================
+ Coverage   72.48%   72.52%   +0.04%     
==========================================
  Files         802      807       +5     
  Lines      144039   144476     +437     
  Branches    16179    16190      +11     
==========================================
+ Hits       104404   104784     +380     
- Misses      35221    35287      +66     
+ Partials     4414     4405       -9
Flag Coverage Δ
#Debug 72.52% <62.94%> (+0.04%) ⬆️
#production 68.14% <56.84%> (+0.02%) ⬆️
#test 88.77% <100%> (+0.07%) ⬆️
Impacted Files Coverage Δ
src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs 73.57% <ø> (ø) ⬆️
...est/Microsoft.ML.StaticPipelineTesting/Training.cs 99.3% <100%> (+0.01%) ⬆️
...crosoft.ML.StaticPipe/EvaluatorStaticExtensions.cs 86.15% <47.05%> (-13.85%) ⬇️
...ft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs 77.19% <56.52%> (-2.09%) ⬇️
...aluators/Metrics/BinaryPrecisionRecallDataPoint.cs 75% <75%> (ø)
.../Microsoft.ML.Data/Model/ModelOperationsCatalog.cs 91.73% <0%> (-5.54%) ⬇️
src/Microsoft.ML.Data/TrainCatalog.cs 82.91% <0%> (-1.2%) ⬇️
...soft.ML.Data/DataLoadSave/DataOperationsCatalog.cs 72.92% <0%> (-0.32%) ⬇️
...enarios/Api/Estimators/TrainSaveModelAndPredict.cs 95.23% <0%> (-0.22%) ⬇️
...s/Api/CookbookSamples/CookbookSamplesDynamicApi.cs 93.49% <0%> (-0.05%) ⬇️
... and 39 more

/// <param name="probability">The name of the probability column in <paramref name="data"/>, the calibrated version of <paramref name="score"/>.</param>
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param>
/// <returns>The precision recall curve data points.</returns>
public IDataView GetPrecisionRecallCurve(IDataView data, string label, string score, string probability, string predictedLabel)
@TomFinley TomFinley (Contributor) Mar 21, 2019

Hi @ganik, thanks for working on this! I view the API in its current form as having a few potential improvements...

The most obvious thing is this return value... an IDataView. Probably not a very good transmission format for metrics results.

It may be informative to note how we surface other metrics results in our public API. You'll note that we do have IEvaluator (which is an internal interface) return metrics via IDataView (because, it needs to communicate them in some common format, and also because that's one of the types that entry-points uses), but note what we do with them for the sake of our public API.

internal BinaryClassificationMetrics(IExceptionContext ectx, DataViewRow overallResult)
{
double Fetch(string name) => Fetch<double>(ectx, overallResult, name);
AreaUnderRocCurve = Fetch(BinaryClassifierEvaluator.Auc);
Accuracy = Fetch(BinaryClassifierEvaluator.Accuracy);
PositivePrecision = Fetch(BinaryClassifierEvaluator.PosPrecName);
PositiveRecall = Fetch(BinaryClassifierEvaluator.PosRecallName);
NegativePrecision = Fetch(BinaryClassifierEvaluator.NegPrecName);
NegativeRecall = Fetch(BinaryClassifierEvaluator.NegRecallName);
F1Score = Fetch(BinaryClassifierEvaluator.F1);
AreaUnderPrecisionRecallCurve = Fetch(BinaryClassifierEvaluator.AuPrc);
}

We actually take the data out, and store them in an actual strongly typed structure, with intellisense, and whatnot.

We could have done what you did here -- just take the IDataView, even though it was one whose structure we knew exactly, called it a day, and said, "here, good luck figuring out what to do with this" -- but we didn't do that, because that would be a needlessly terrible experience. We wanted an easy to use, discoverable API for something as important as metrics, and there was an obvious way to provide it, just as there is with this PR. So I was expecting, and I kind of think people might appreciate, something more closely resembling that: something with actual types, since unlike in some scenarios (general dataflow) the types and column names are well defined.

Note, here is where this data view is composed, for your own reference, so you can see what goes in it.

if (scores.Count > 0)
{
var dvBldr = new ArrayDataViewBuilder(Host);
if (hasStrats)
{
dvBldr.AddColumn(MetricKinds.ColumnNames.StratCol, GetKeyValueGetter(dictionaries), (ulong)dictionaries.Length, prStratCol.ToArray());
dvBldr.AddColumn(MetricKinds.ColumnNames.StratVal, TextDataViewType.Instance, prStratVal.ToArray());
}
dvBldr.AddColumn(Threshold, NumberDataViewType.Single, scores.ToArray());
dvBldr.AddColumn(Precision, NumberDataViewType.Double, precision.ToArray());
dvBldr.AddColumn(Recall, NumberDataViewType.Double, recall.ToArray());
dvBldr.AddColumn(FalsePositiveRate, NumberDataViewType.Double, fpr.ToArray());
if (weightedPrecision.Count > 0)
{
dvBldr.AddColumn("Weighted " + Precision, NumberDataViewType.Double, weightedPrecision.ToArray());
dvBldr.AddColumn("Weighted " + Recall, NumberDataViewType.Double, weightedRecall.ToArray());
dvBldr.AddColumn("Weighted " + FalsePositiveRate, NumberDataViewType.Double, weightedFpr.ToArray());
}
result.Add(PrCurve, dvBldr.GetDataView());
}

So, a few columns that are always there. We'll need to figure out what's up with these "strat" columns, that is, "stratification" columns (we don't want to expose this as such re: @rogancarr), so perhaps we skip that one, or figure out some other structure that's appropriate. (We'd probably group them, somehow, I don't quite know offhand what that's for. CV?) Anyway, you'll want to figure out what's up with that.

You could imagine some alternate structure, whereby this "stuff" is put into a strongly typed structure. In the short term you could copy the results of the IDataView out into this structure -- that's a little inefficient but will suffice. Could be as simple as an array of a struct containing the relevant PR points.

The second problem with the API is a bit less fundamental, and probably easier to fix. You will also note that you are discarding all the other results of evaluation, for no particular reason. What I had suggested in the original issue, and I still think is a good idea, is EvaluateWithPRCurve that still produces the metric results (since they're already right there, being calculated alongside the PR curve!). So, maybe return the metrics, but also have the PR curve structure, whatever it is (not just an IDataView!!), as an out parameter. Or reverse the two, I don't know. #Resolved
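The strongly typed per-threshold structure suggested above can be sketched outside ML.NET. A minimal Python illustration (all names here are hypothetical, not the ML.NET API) of computing PR data points -- threshold, precision, recall, and false positive rate -- from calibrated probabilities:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class PrDataPoint:
    # Hypothetical analogue of a strongly typed PR curve data point.
    threshold: float
    precision: float
    recall: float
    false_positive_rate: float


def pr_curve(labels: List[bool], probabilities: List[float]) -> List[PrDataPoint]:
    """Sweep each distinct probability as a decision threshold and record
    the resulting confusion-matrix rates (a naive O(n^2) illustration)."""
    points = []
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    for t in sorted(set(probabilities), reverse=True):
        tp = sum(1 for y, p in zip(labels, probabilities) if y and p >= t)
        fp = sum(1 for y, p in zip(labels, probabilities) if not y and p >= t)
        precision = tp / (tp + fp) if tp + fp else 1.0
        recall = tp / n_pos if n_pos else 0.0
        fpr = fp / n_neg if n_neg else 0.0
        points.append(PrDataPoint(t, precision, recall, fpr))
    return points
```

With labels `[True, False, True, False]` and probabilities `[0.9, 0.8, 0.6, 0.3]`, the first point (threshold 0.9) has precision 1.0 and recall 0.5; each row of such a list maps naturally onto one strongly typed data point.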

@ganik ganik (Member Author) Mar 21, 2019

Awesome, thank you for the great suggestions as usual, @TomFinley #Resolved

Contributor

I like this idea: EvaluateWithPRCurve

One thing that Gani and I talked about was also being able to return the data for an ROC curve too. The original issue (and related v1 scenario) was specifically about a PR curve, but I think it should be on the table to return an ROC curve too. These are the two major curves that people may want to plot.

Also, we may want to think about being able to extend this design to any sort of curve someone may want to plot, as different fields of application have slightly different target metrics.



@ganik (Member Author)

this should be done. For the ROC curve, I suggest doing it post V1.



Contributor

How much extra work is an ROC Curve on top of getting the PR Curve? It's weird to have one but not the other from the API.



@ganik ganik (Member Author) Mar 22, 2019

actually, my bad -- BinaryPrecisionRecallDataPoint contains FPR as well, and Recall is actually TPR. This should be enough for ROC



/// <param name="probability">The name of the probability column in <paramref name="data"/>, the calibrated version of <paramref name="score"/>.</param>
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param>
/// <returns>The precision recall curve data points.</returns>
public IDataView GetPrecisionRecallCurve(IDataView data, string label, string score, string probability, string predictedLabel)
@yaeldekel yaeldekel Mar 21, 2019

predictedLabel

Does the BinaryClassifierEvaluator need that column to evaluate? (same question for the API below). #ByDesign

@ganik (Member Author)

Yes, it is needed -- for example, for the accuracy metrics.



@TomFinley TomFinley (Contributor) Mar 22, 2019

Incidentally @yaeldekel is correct, the original command line didn't need it -- it would use it if required, but it would not insist on it. That we have it for the API is somewhat unfortunate, but not that big of a deal. In future we could in fact improve this story: we could have this overload, but also another overload that doesn't take it, and takes a threshold for it. Thereby capturing the existing functionality of this C# API, while retaining the "flexibility" we used to achieve in our command line.

That however is beyond the scope of @ganik's work here, and not something we should add now. #ByDesign
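The threshold-taking overload described here is easy to picture. A hedged Python sketch (hypothetical names, not actual ML.NET surface) of deriving the predicted label from the calibrated probability, so callers would not have to supply a predicted-label column:

```python
def predicted_labels(probabilities, threshold=0.5):
    """Derive predicted labels by thresholding calibrated probabilities
    (illustrative stand-in for the suggested overload, not ML.NET code)."""
    return [p >= threshold for p in probabilities]
```

For example, `predicted_labels([0.9, 0.3, 0.6])` yields `[True, False, True]`, while raising the threshold to 0.7 flips the last prediction to `False`.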

@ganik (Member Author)

thank you for the explanation.



/// <param name="probability">The name of the probability column in <paramref name="data"/>, the calibrated version of <paramref name="score"/>.</param>
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param>
/// <returns>The precision recall curve data points.</returns>
public IDataView GetPrecisionRecallCurve(IDataView data, string label, string score, string probability, string predictedLabel)
@rogancarr rogancarr (Contributor) Mar 21, 2019

string label, string score, string probability, string predictedLabel

Do these column names have default values? If so, let's specify them here as defaults. #ByDesign


These are the internal methods that are called by the public Evaluate API. The public API does have default values for the column names.



@ganik (Member Author)

as Yael says



@@ -98,7 +100,16 @@ public static void SdcaBinaryClassification()
Console.WriteLine($"Negative Precision: {metrics.NegativePrecision}"); // 0.87
Console.WriteLine($"Negative Recall: {metrics.NegativeRecall}"); // 0.91
Console.WriteLine($"Positive Precision: {metrics.PositivePrecision}"); // 0.65
Console.WriteLine($"Positive Recall: {metrics.PositiveRecall}"); // 0.55
Console.WriteLine($"Positive Recall: {metrics.PositiveRecall}"); // 0.55
@glebuk glebuk (Contributor) Mar 22, 2019


extra white space? #Closed

Console.WriteLine($"Positive Recall: {metrics.PositiveRecall}"); // 0.55
Console.WriteLine($"Positive Recall: {metrics.PositiveRecall}"); // 0.55

var prData = prCurve.GetEnumerator();
@glebuk glebuk (Contributor) Mar 22, 2019

prCurve

why not just foreach (var data in prCurve)? #Closed

/// <summary>
/// This class represents one data point on Precision-Recall curve.
/// </summary>
public sealed class PrecisionRecallMetrics
@glebuk glebuk (Contributor) Mar 22, 2019

PrecisionRecallMetrics

This appears to be specific to binary PR, yet the name is generic (you use BinaryClassifierEvaluator).
Moreover, the summary does not match the name. For example, if this is a point, it should be named as such. We have more PR metrics besides these:
If it is for binary, don't you need to also supply positive/negative recall/precision metrics?
If it is multiclass, you would need a PR matrix AFAIK #Closed

@ganik (Member Author)

changed the name to reflect that it is not a metrics class but a data point for binary classification



this BinaryClassificationCatalog catalog,
DataView<T> data,
Func<T, Scalar<bool>> label,
Func<T, (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel)> pred,
@glebuk glebuk (Contributor) Mar 22, 2019

predictedLabel

is this always a bool? don't we sometimes predict ints or keys?
Consider renaming the parameter to a full word such as prediction #Closed

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a twin method for public static CalibratedBinaryClassificationMetrics Evaluate(..) on line 28. It does everything the same way, plus spits out the PR curve. The scope of this PR is to add the PR curve output; changing parameter types or names would mean refactoring the original Evaluate(..) method, which is out of scope for this PR



DataView<T> data,
Func<T, Scalar<bool>> label,
Func<T, (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel)> pred,
out IEnumerable<PrecisionRecallMetrics> prCurve)
@glebuk glebuk (Contributor) Mar 22, 2019

prCurve

I find the out param awkward. Can we extend CalibratedBinaryClassificationMetrics instead? #Closed

@ganik (Member Author)

this would mean unsealing the CalibratedBinaryClassificationMetrics class; it is currently sealed



string predName = indexer.Get(predCol);

var eval = new BinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments() { NumRocExamples = 100000 });
return eval.EvaluateWithPRCurve(data.AsDynamic, labelName, scoreName, probName, predName, out prCurve);
@glebuk glebuk (Contributor) Mar 22, 2019

super similar to another new method. Refactor? #Closed

@ganik (Member Author)

this would mean refactoring both Evaluate(..) methods on lines 28 & 102, as these new methods are based off them. I don't want to make bigger changes at this point and just want to follow the established pattern here to minimize risk



List<PrecisionRecallMetrics> prCurveResult = new List<PrecisionRecallMetrics>();
using (var cursor = prCurveView.GetRowCursorForAllColumns())
{
while (cursor.MoveNext())
@TomFinley TomFinley (Contributor) Mar 22, 2019

Just to put in the PR what we had talked about, it is important that we get the getters before the tight loop, and we use the getters inside the tight loop to retrieve the values, since constructing the delegates is extremely expensive but using them is very cheap. (See here for more architectural info on why this is so.) #Resolved
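The pattern described here -- construct the getter once, invoke it per row -- is not ML.NET-specific. A small Python analogy (made-up names, not the IDataView API) of hoisting expensive accessor construction out of the tight loop:

```python
def make_getter(column_index):
    # Stand-in for ML.NET's cursor.GetGetter<T>(column): building the
    # accessor is the expensive step; invoking it per row is cheap.
    def getter(row):
        return row[column_index]
    return getter


def read_column(rows, column_index):
    get = make_getter(column_index)     # construct once, before the loop
    values = []
    for row in rows:                    # tight loop: only cheap invocations
        values.append(get(row))
    return values
```

Calling `make_getter` inside the loop instead would rebuild the closure on every row, which is the delegate-construction cost the comment warns about.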

@ganik (Member Author)

thx



@ganik ganik changed the title [WIP] Add GetPrecisionRecallCurve API Add GetPrecisionRecallCurve API Mar 22, 2019
@ganik ganik changed the title Add GetPrecisionRecallCurve API Add API to get Precision-Recall Curve data Mar 22, 2019
@glebuk glebuk (Contributor) left a comment

:shipit:

/// </summary>
public double FalsePositiveRate { get; }

internal BinaryPrecisionRecallDataPoint(ValueGetter<float> thresholdGetter, ValueGetter<double> precisionGetter, ValueGetter<double> recallGetter, ValueGetter<double> fprGetter)
@TomFinley TomFinley (Contributor) Mar 23, 2019

ValueGetter&lt;double&gt; precisionGetter

I would recommend that all logic pursuant to the getters remain in the same block of code as the cursor (anything else is just a pointless increase in complexity), but this at least avoids the main problem, so I guess we can fix it later. #ByDesign

@ganik (Member Author)

sure, we can do it in this or in a separate PR.



/// <summary>
/// Gets the true positive rate for the current threshold.
/// </summary>
public double TruePositiveRate => Recall;
@TomFinley TomFinley (Contributor) Mar 23, 2019

public double TruePositiveRate => Recall;

So, what is the point of this? I see a class where one property exists merely to point to another, and I think: what is up? #ByDesign

@ganik ganik (Member Author) Mar 23, 2019

This is to make life easier for users plotting the ROC curve. The ROC curve is plotted as TPR vs FPR. We do have FPR, but Rogan and I spent some time looking for TPR in order to get the ROC data. It turns out Recall is synonymous with TPR. So, to spare users the same confusion and to reduce the number of potential inquiries in the future, we decided to have the TPR field here.


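Because Recall doubles as the true positive rate, every PR data point already yields an ROC point. A minimal sketch (hypothetical tuple layout, not ML.NET types) of recovering the ROC curve from the same per-threshold data:

```python
def roc_from_pr(points):
    """points: iterable of (threshold, precision, recall, fpr) tuples.
    Recall is the TPR, so the ROC curve is just the (FPR, TPR) pair
    taken from each threshold's data point."""
    return [(fpr, recall) for _, _, recall, fpr in points]
```

For instance, two points at thresholds 0.9 and 0.3 with (recall, fpr) of (0.5, 0.0) and (1.0, 1.0) map directly to ROC points (0.0, 0.5) and (1.0, 1.0).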

var precisionColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.Precision);
var recallColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.Recall);
var fprColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.FalsePositiveRate);
Host.Assert(thresholdColumn != null);
@TomFinley TomFinley (Contributor) Mar 23, 2019

Host.Assert

In future, consider AssertValue to make this a bit more elegant. #ByDesign

@ganik (Member Author)

tried it, but it didn't work here because DataViewSchema.Column is not a reference type; it's a struct



@@ -18,6 +18,7 @@
using Microsoft.ML.Trainers.Recommender;
using Xunit;
using Xunit.Abstractions;
using System.Collections.Generic;
@TomFinley TomFinley (Contributor) Mar 23, 2019

We generally like to keep our namespaces sorted. #Resolved

@TomFinley TomFinley (Contributor) left a comment

Thanks for adding this @ganik! This is one of the last core scenarios we wanted to unlock.

@ganik ganik merged commit 122c319 into dotnet:master Mar 23, 2019
@ganik ganik deleted the ganik/pr1 branch April 1, 2019 20:56
@ghost ghost locked as resolved and limited conversation to collaborators Mar 23, 2022
Successfully merging this pull request may close these issues.

Cannot produce a Precision-Recall Curve from our Binary Classification Evaluator