Skip to content

Commit db84060

Browse files
authored
Handle NaN optimization metric in AutoML (#5031)
* Handle all folds returning NaN optimization metric in CrossValSummaryRunner * Handle NaN in calculation of average scores and index of closest fold * Handle all metrics being NaN in finding Best Run * nit * nit * nit * Handle all NaNs in best model selection * Return average metrics instead of metrics from the fold with optimizing metric closest to average * nit * Add PerClassLogLoss and ConfusionMatrix from the fold closest to average score * feedback
1 parent 062be28 commit db84060

File tree

5 files changed

+108
-9
lines changed

5 files changed

+108
-9
lines changed

src/Microsoft.ML.AutoML/Experiment/Runners/CrossValRunner.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,11 @@ public CrossValRunner(MLContext context,
6666

6767
private static double CalcAverageScore(IEnumerable<double> scores)
6868
{
69-
if (scores.Any(s => double.IsNaN(s)))
70-
{
69+
var newScores = scores.Where(r => !double.IsNaN(r));
70+
// Return NaN iff all scores are NaN
71+
if (newScores.Count() == 0)
7172
return double.NaN;
72-
}
73-
return scores.Average();
73+
return newScores.Average();
7474
}
7575
}
7676
}

src/Microsoft.ML.AutoML/Experiment/Runners/CrossValSummaryRunner.cs

+84-5
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Collections.Generic;
77
using System.IO;
88
using System.Linq;
9+
using Microsoft.ML.Data;
910
using Microsoft.ML.Runtime;
1011

1112
namespace Microsoft.ML.AutoML
@@ -70,27 +71,105 @@ public CrossValSummaryRunner(MLContext context,
7071

7172
// Get the model from the best fold
7273
var bestFoldIndex = BestResultUtil.GetIndexOfBestScore(trainResults.Select(r => r.score), _optimizingMetricInfo.IsMaximizing);
74+
// bestFoldIndex will be -1 if the optimization metric for all folds is NaN.
75+
// In this case, return model from the first fold.
76+
bestFoldIndex = bestFoldIndex != -1 ? bestFoldIndex : 0;
7377
var bestModel = trainResults.ElementAt(bestFoldIndex).model;
7478

75-
// Get the metrics from the fold whose score is closest to avg of all fold scores
76-
var avgScore = trainResults.Average(r => r.score);
79+
// Get the average metrics across all folds
80+
var avgScore = GetAverageOfNonNaNScores(trainResults.Select(x => x.score));
7781
var indexClosestToAvg = GetIndexClosestToAverage(trainResults.Select(r => r.score), avgScore);
7882
var metricsClosestToAvg = trainResults[indexClosestToAvg].metrics;
83+
var avgMetrics = GetAverageMetrics(trainResults.Select(x => x.metrics), metricsClosestToAvg);
7984

8085
// Build result objects
81-
var suggestedPipelineRunDetail = new SuggestedPipelineRunDetail<TMetrics>(pipeline, avgScore, allRunsSucceeded, metricsClosestToAvg, bestModel, null);
86+
var suggestedPipelineRunDetail = new SuggestedPipelineRunDetail<TMetrics>(pipeline, avgScore, allRunsSucceeded, avgMetrics, bestModel, null);
8287
var runDetail = suggestedPipelineRunDetail.ToIterationResult(_preFeaturizer);
8388
return (suggestedPipelineRunDetail, runDetail);
8489
}
8590

91+
private static TMetrics GetAverageMetrics(IEnumerable<TMetrics> metrics, TMetrics metricsClosestToAvg)
92+
{
93+
if (typeof(TMetrics) == typeof(BinaryClassificationMetrics))
94+
{
95+
var newMetrics = metrics.Select(x => x as BinaryClassificationMetrics);
96+
Contracts.Assert(newMetrics != null);
97+
98+
var result = new BinaryClassificationMetrics(
99+
auc: GetAverageOfNonNaNScores(newMetrics.Select(x => x.AreaUnderRocCurve)),
100+
accuracy: GetAverageOfNonNaNScores(newMetrics.Select(x => x.Accuracy)),
101+
positivePrecision: GetAverageOfNonNaNScores(newMetrics.Select(x => x.PositivePrecision)),
102+
positiveRecall: GetAverageOfNonNaNScores(newMetrics.Select(x => x.PositiveRecall)),
103+
negativePrecision: GetAverageOfNonNaNScores(newMetrics.Select(x => x.NegativePrecision)),
104+
negativeRecall: GetAverageOfNonNaNScores(newMetrics.Select(x => x.NegativeRecall)),
105+
f1Score: GetAverageOfNonNaNScores(newMetrics.Select(x => x.F1Score)),
106+
auprc: GetAverageOfNonNaNScores(newMetrics.Select(x => x.AreaUnderPrecisionRecallCurve)),
107+
// Return ConfusionMatrix from the fold closest to average score
108+
confusionMatrix: (metricsClosestToAvg as BinaryClassificationMetrics).ConfusionMatrix);
109+
return result as TMetrics;
110+
}
111+
112+
if (typeof(TMetrics) == typeof(MulticlassClassificationMetrics))
113+
{
114+
var newMetrics = metrics.Select(x => x as MulticlassClassificationMetrics);
115+
Contracts.Assert(newMetrics != null);
116+
117+
var result = new MulticlassClassificationMetrics(
118+
accuracyMicro: GetAverageOfNonNaNScores(newMetrics.Select(x => x.MicroAccuracy)),
119+
accuracyMacro: GetAverageOfNonNaNScores(newMetrics.Select(x => x.MacroAccuracy)),
120+
logLoss: GetAverageOfNonNaNScores(newMetrics.Select(x => x.LogLoss)),
121+
logLossReduction: GetAverageOfNonNaNScores(newMetrics.Select(x => x.LogLossReduction)),
122+
topKPredictionCount: newMetrics.ElementAt(0).TopKPredictionCount,
123+
topKAccuracy: GetAverageOfNonNaNScores(newMetrics.Select(x => x.TopKAccuracy)),
124+
// Return PerClassLogLoss and ConfusionMatrix from the fold closest to average score
125+
perClassLogLoss: (metricsClosestToAvg as MulticlassClassificationMetrics).PerClassLogLoss.ToArray(),
126+
confusionMatrix: (metricsClosestToAvg as MulticlassClassificationMetrics).ConfusionMatrix);
127+
return result as TMetrics;
128+
}
129+
130+
if (typeof(TMetrics) == typeof(RegressionMetrics))
131+
{
132+
var newMetrics = metrics.Select(x => x as RegressionMetrics);
133+
Contracts.Assert(newMetrics != null);
134+
135+
var result = new RegressionMetrics(
136+
l1: GetAverageOfNonNaNScores(newMetrics.Select(x => x.MeanAbsoluteError)),
137+
l2: GetAverageOfNonNaNScores(newMetrics.Select(x => x.MeanSquaredError)),
138+
rms: GetAverageOfNonNaNScores(newMetrics.Select(x => x.RootMeanSquaredError)),
139+
lossFunction: GetAverageOfNonNaNScores(newMetrics.Select(x => x.LossFunction)),
140+
rSquared: GetAverageOfNonNaNScores(newMetrics.Select(x => x.RSquared)));
141+
return result as TMetrics;
142+
}
143+
144+
throw new NotImplementedException($"Metric {typeof(TMetrics)} not implemented");
145+
}
146+
147+
private static double GetAverageOfNonNaNScores(IEnumerable<double> results)
148+
{
149+
var newResults = results.Where(r => !double.IsNaN(r));
150+
// Return NaN iff all scores are NaN
151+
if (newResults.Count() == 0)
152+
return double.NaN;
153+
// Return average of non-NaN scores otherwise
154+
return newResults.Average(r => r);
155+
}
156+
86157
private static int GetIndexClosestToAverage(IEnumerable<double> values, double average)
87158
{
159+
// Average will be NaN iff all values are NaN.
160+
// Return the first index in this case.
161+
if (double.IsNaN(average))
162+
return 0;
163+
88164
int avgFoldIndex = -1;
89165
var smallestDistFromAvg = double.PositiveInfinity;
90166
for (var i = 0; i < values.Count(); i++)
91167
{
92-
var distFromAvg = Math.Abs(values.ElementAt(i) - average);
93-
if (distFromAvg < smallestDistFromAvg || smallestDistFromAvg == double.PositiveInfinity)
168+
var value = values.ElementAt(i);
169+
if (double.IsNaN(value))
170+
continue;
171+
var distFromAvg = Math.Abs(value - average);
172+
if (distFromAvg < smallestDistFromAvg)
94173
{
95174
smallestDistFromAvg = distFromAvg;
96175
avgFoldIndex = i;

src/Microsoft.ML.AutoML/Utils/BestResultUtil.cs

+6
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ public static RunDetail<TMetrics> GetBestRun<TMetrics>(IEnumerable<RunDetail<TMe
4141
if (!results.Any()) { return null; }
4242
var scores = results.Select(r => metricsAgent.GetScore(r.ValidationMetrics));
4343
var indexOfBestScore = GetIndexOfBestScore(scores, isMetricMaximizing);
44+
// indexOfBestScore will be -1 if the optimization metric for all models is NaN.
45+
// In this case, return the first model.
46+
indexOfBestScore = indexOfBestScore != -1 ? indexOfBestScore : 0;
4447
return results.ElementAt(indexOfBestScore);
4548
}
4649

@@ -51,6 +54,9 @@ public static CrossValidationRunDetail<TMetrics> GetBestRun<TMetrics>(IEnumerabl
5154
if (!results.Any()) { return null; }
5255
var scores = results.Select(r => r.Results.Average(x => metricsAgent.GetScore(x.ValidationMetrics)));
5356
var indexOfBestScore = GetIndexOfBestScore(scores, isMetricMaximizing);
57+
// indexOfBestScore will be -1 if the optimization metric for all models is NaN.
58+
// In this case, return the first model.
59+
indexOfBestScore = indexOfBestScore != -1 ? indexOfBestScore : 0;
5460
return results.ElementAt(indexOfBestScore);
5561
}
5662

src/Microsoft.ML.Data/Evaluators/Metrics/BinaryClassificationMetrics.cs

+7
Original file line numberDiff line numberDiff line change
@@ -122,5 +122,12 @@ internal BinaryClassificationMetrics(double auc, double accuracy, double positiv
122122
F1Score = f1Score;
123123
AreaUnderPrecisionRecallCurve = auprc;
124124
}
125+
126+
internal BinaryClassificationMetrics(double auc, double accuracy, double positivePrecision, double positiveRecall,
127+
double negativePrecision, double negativeRecall, double f1Score, double auprc, ConfusionMatrix confusionMatrix)
128+
: this(auc, accuracy, positivePrecision, positiveRecall, negativePrecision, negativeRecall, f1Score, auprc)
129+
{
130+
ConfusionMatrix = confusionMatrix;
131+
}
125132
}
126133
}

src/Microsoft.ML.Data/Evaluators/Metrics/MulticlassClassificationMetrics.cs

+7
Original file line numberDiff line numberDiff line change
@@ -134,5 +134,12 @@ internal MulticlassClassificationMetrics(double accuracyMicro, double accuracyMa
134134
TopKAccuracy = topKAccuracy;
135135
PerClassLogLoss = perClassLogLoss.ToImmutableArray();
136136
}
137+
138+
internal MulticlassClassificationMetrics(double accuracyMicro, double accuracyMacro, double logLoss, double logLossReduction,
139+
int topKPredictionCount, double topKAccuracy, double[] perClassLogLoss, ConfusionMatrix confusionMatrix)
140+
: this(accuracyMicro, accuracyMacro, logLoss, logLossReduction, topKPredictionCount, topKAccuracy, perClassLogLoss)
141+
{
142+
ConfusionMatrix = confusionMatrix;
143+
}
137144
}
138145
}

0 commit comments

Comments
 (0)