|
6 | 6 | using System.Collections.Generic;
|
7 | 7 | using System.IO;
|
8 | 8 | using System.Linq;
|
| 9 | +using Microsoft.ML.Data; |
9 | 10 | using Microsoft.ML.Runtime;
|
10 | 11 |
|
11 | 12 | namespace Microsoft.ML.AutoML
|
@@ -70,27 +71,105 @@ public CrossValSummaryRunner(MLContext context,
|
70 | 71 |
|
71 | 72 | // Get the model from the best fold
|
72 | 73 | var bestFoldIndex = BestResultUtil.GetIndexOfBestScore(trainResults.Select(r => r.score), _optimizingMetricInfo.IsMaximizing);
|
| 74 | + // bestFoldIndex will be -1 if the optimization metric for all folds is NaN. |
| 75 | + // In this case, return model from the first fold. |
| 76 | + bestFoldIndex = bestFoldIndex != -1 ? bestFoldIndex : 0; |
73 | 77 | var bestModel = trainResults.ElementAt(bestFoldIndex).model;
|
74 | 78 |
|
75 |
| - // Get the metrics from the fold whose score is closest to avg of all fold scores |
76 |
| - var avgScore = trainResults.Average(r => r.score); |
| 79 | + // Get the average metrics across all folds |
| 80 | + var avgScore = GetAverageOfNonNaNScores(trainResults.Select(x => x.score)); |
77 | 81 | var indexClosestToAvg = GetIndexClosestToAverage(trainResults.Select(r => r.score), avgScore);
|
78 | 82 | var metricsClosestToAvg = trainResults[indexClosestToAvg].metrics;
|
| 83 | + var avgMetrics = GetAverageMetrics(trainResults.Select(x => x.metrics), metricsClosestToAvg); |
79 | 84 |
|
80 | 85 | // Build result objects
|
81 |
| - var suggestedPipelineRunDetail = new SuggestedPipelineRunDetail<TMetrics>(pipeline, avgScore, allRunsSucceeded, metricsClosestToAvg, bestModel, null); |
| 86 | + var suggestedPipelineRunDetail = new SuggestedPipelineRunDetail<TMetrics>(pipeline, avgScore, allRunsSucceeded, avgMetrics, bestModel, null); |
82 | 87 | var runDetail = suggestedPipelineRunDetail.ToIterationResult(_preFeaturizer);
|
83 | 88 | return (suggestedPipelineRunDetail, runDetail);
|
84 | 89 | }
|
85 | 90 |
|
/// <summary>
/// Builds one metrics object that summarizes the metrics of all cross-validation folds.
/// Every scalar metric is the average of its non-NaN per-fold values; members that are
/// not averaged here (confusion matrix, per-class log-loss) are copied from
/// <paramref name="metricsClosestToAvg"/>.
/// </summary>
/// <param name="metrics">The metrics produced by each fold.</param>
/// <param name="metricsClosestToAvg">
/// The fold metrics used as the source for the members that are not averaged.
/// </param>
/// <returns>A new <typeparamref name="TMetrics"/> instance holding the averaged metrics.</returns>
/// <exception cref="NotImplementedException">
/// <typeparamref name="TMetrics"/> is not binary classification, multiclass classification
/// or regression metrics.
/// </exception>
private static TMetrics GetAverageMetrics(IEnumerable<TMetrics> metrics, TMetrics metricsClosestToAvg)
{
    if (typeof(TMetrics) == typeof(BinaryClassificationMetrics))
    {
        // Materialize once: the fold metrics are walked once per averaged metric below.
        var foldMetrics = metrics.Cast<BinaryClassificationMetrics>().ToArray();
        var closestMetrics = metricsClosestToAvg as BinaryClassificationMetrics;
        Contracts.Assert(foldMetrics.All(x => x != null));
        Contracts.Assert(closestMetrics != null);

        var result = new BinaryClassificationMetrics(
            auc: GetAverageOfNonNaNScores(foldMetrics.Select(x => x.AreaUnderRocCurve)),
            accuracy: GetAverageOfNonNaNScores(foldMetrics.Select(x => x.Accuracy)),
            positivePrecision: GetAverageOfNonNaNScores(foldMetrics.Select(x => x.PositivePrecision)),
            positiveRecall: GetAverageOfNonNaNScores(foldMetrics.Select(x => x.PositiveRecall)),
            negativePrecision: GetAverageOfNonNaNScores(foldMetrics.Select(x => x.NegativePrecision)),
            negativeRecall: GetAverageOfNonNaNScores(foldMetrics.Select(x => x.NegativeRecall)),
            f1Score: GetAverageOfNonNaNScores(foldMetrics.Select(x => x.F1Score)),
            auprc: GetAverageOfNonNaNScores(foldMetrics.Select(x => x.AreaUnderPrecisionRecallCurve)),
            // Return ConfusionMatrix from the fold closest to average score
            confusionMatrix: closestMetrics.ConfusionMatrix);
        return result as TMetrics;
    }

    if (typeof(TMetrics) == typeof(MulticlassClassificationMetrics))
    {
        // Materialize once: the fold metrics are walked once per averaged metric below.
        var foldMetrics = metrics.Cast<MulticlassClassificationMetrics>().ToArray();
        var closestMetrics = metricsClosestToAvg as MulticlassClassificationMetrics;
        Contracts.Assert(foldMetrics.All(x => x != null));
        Contracts.Assert(closestMetrics != null);

        var result = new MulticlassClassificationMetrics(
            accuracyMicro: GetAverageOfNonNaNScores(foldMetrics.Select(x => x.MicroAccuracy)),
            accuracyMacro: GetAverageOfNonNaNScores(foldMetrics.Select(x => x.MacroAccuracy)),
            logLoss: GetAverageOfNonNaNScores(foldMetrics.Select(x => x.LogLoss)),
            logLossReduction: GetAverageOfNonNaNScores(foldMetrics.Select(x => x.LogLossReduction)),
            // TopKPredictionCount is not averaged; it is taken from the first fold.
            topKPredictionCount: foldMetrics[0].TopKPredictionCount,
            topKAccuracy: GetAverageOfNonNaNScores(foldMetrics.Select(x => x.TopKAccuracy)),
            // Return PerClassLogLoss and ConfusionMatrix from the fold closest to average score
            perClassLogLoss: closestMetrics.PerClassLogLoss.ToArray(),
            confusionMatrix: closestMetrics.ConfusionMatrix);
        return result as TMetrics;
    }

    if (typeof(TMetrics) == typeof(RegressionMetrics))
    {
        // Materialize once: the fold metrics are walked once per averaged metric below.
        var foldMetrics = metrics.Cast<RegressionMetrics>().ToArray();
        Contracts.Assert(foldMetrics.All(x => x != null));

        var result = new RegressionMetrics(
            l1: GetAverageOfNonNaNScores(foldMetrics.Select(x => x.MeanAbsoluteError)),
            l2: GetAverageOfNonNaNScores(foldMetrics.Select(x => x.MeanSquaredError)),
            rms: GetAverageOfNonNaNScores(foldMetrics.Select(x => x.RootMeanSquaredError)),
            lossFunction: GetAverageOfNonNaNScores(foldMetrics.Select(x => x.LossFunction)),
            rSquared: GetAverageOfNonNaNScores(foldMetrics.Select(x => x.RSquared)));
        return result as TMetrics;
    }

    throw new NotImplementedException($"Metric {typeof(TMetrics)} not implemented");
}
| 146 | + |
/// <summary>
/// Computes the arithmetic mean of the values in <paramref name="results"/> that are not NaN.
/// </summary>
/// <param name="results">The scores to average; NaN entries are ignored.</param>
/// <returns>
/// The mean of the non-NaN scores, or <see cref="double.NaN"/> when there is no
/// non-NaN score (the input is empty or every score is NaN).
/// </returns>
private static double GetAverageOfNonNaNScores(IEnumerable<double> results)
{
    // Materialize once: the filtered scores are needed both for the emptiness check
    // and for the average, and a deferred query would re-run the source each time.
    var nonNaNScores = results.Where(r => !double.IsNaN(r)).ToArray();
    // Return NaN iff there is no non-NaN score
    if (nonNaNScores.Length == 0)
        return double.NaN;
    // Return average of non-NaN scores otherwise
    return nonNaNScores.Average();
}
| 156 | + |
86 | 157 | private static int GetIndexClosestToAverage(IEnumerable<double> values, double average)
|
87 | 158 | {
|
| 159 | + // Average will be NaN iff all values are NaN. |
| 160 | + // Return the first index in this case. |
| 161 | + if (double.IsNaN(average)) |
| 162 | + return 0; |
| 163 | + |
88 | 164 | int avgFoldIndex = -1;
|
89 | 165 | var smallestDistFromAvg = double.PositiveInfinity;
|
90 | 166 | for (var i = 0; i < values.Count(); i++)
|
91 | 167 | {
|
92 |
| - var distFromAvg = Math.Abs(values.ElementAt(i) - average); |
93 |
| - if (distFromAvg < smallestDistFromAvg || smallestDistFromAvg == double.PositiveInfinity) |
| 168 | + var value = values.ElementAt(i); |
| 169 | + if (double.IsNaN(value)) |
| 170 | + continue; |
| 171 | + var distFromAvg = Math.Abs(value - average); |
| 172 | + if (distFromAvg < smallestDistFromAvg) |
94 | 173 | {
|
95 | 174 | smallestDistFromAvg = distFromAvg;
|
96 | 175 | avgFoldIndex = i;
|
|
0 commit comments