Return average metrics in AutoML CrossValSummaryRunner #5042
Labels: AutoML.NET (automating various steps of the machine learning process), P2 (priority for triage: needs to be fixed at some point)
Related to #4663
See also #5031 (comment)
`CrossValSummaryRunner` in AutoML is invoked when the dataset has fewer than 15k rows. It runs 10-fold cross-validation and computes the average of the optimization metric across the folds. It then reports all the metrics from the fold whose optimization metric is closest to this average (see machinelearning/src/Microsoft.ML.AutoML/Experiment/Runners/CrossValSummaryRunner.cs, lines 76 to 78 in 214926f).
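To make the current behavior concrete, here is a minimal Python sketch of the fold-selection logic described above. It is not the ML.NET C# code: the function name `select_closest_fold` and the per-fold dict shape are hypothetical, chosen only for illustration.

```python
from statistics import mean


def select_closest_fold(fold_metrics, optimization_metric):
    """Return the index of the fold whose optimization metric is closest
    to the average of that metric across all folds.

    fold_metrics: hypothetical shape, a list with one dict per fold
    mapping metric name -> value.
    """
    scores = [m[optimization_metric] for m in fold_metrics]
    average = mean(scores)
    # Pick the fold with the smallest absolute distance to the average;
    # min() returns the earliest fold on ties.
    return min(range(len(scores)), key=lambda i: abs(scores[i] - average))


folds = [
    {"MicroAccuracy": 0.90, "LogLoss": 0.30},
    {"MicroAccuracy": 0.80, "LogLoss": 0.45},
    {"MicroAccuracy": 0.86, "LogLoss": 0.35},
]
best = select_closest_fold(folds, "MicroAccuracy")
print(best, folds[best])
```

Only the optimization metric is averaged here; every other metric (`LogLoss` in this toy data) is copied from the selected fold, which is the behavior this issue proposes to change.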
A better approach would be to compute the average of every metric, instantiate a new metrics object with these averages, and return it in the run details. This could reuse the code from #5031 that averages the non-NaN metric values. Two points need to be worked out and need more discussion:
1. `PerClassLogLoss`: For multiclass classification, the ordering of the class labels may differ across the 10 folds, so the `PerClassLogLoss` metric from each fold can assign different indices to the same class label. A standardized ordering would need to be defined and the average computed per class accordingly.
2. `ConfusionMatrix`: For multiclass classification, the same ordering problem needs solving. In addition, for both binary and multiclass classification, which confusion matrix to return needs discussion. The distribution of class labels differs across the folds, so what exactly is the "average" of a confusion matrix? Should a confusion matrix be returned at all? Or should we return the confusion matrix from the fold whose optimization metric is closest to the average (the current behavior)? If we go this route, the confusion matrix property in the metrics classes will need to be made internally settable, because no constructor takes the confusion matrix as an argument.
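The averaging proposal and the `PerClassLogLoss` alignment problem can be sketched in Python as follows. This is a minimal illustration under stated assumptions, not the ML.NET implementation: the helper names and data shapes are hypothetical, and each fold is assumed to report its class labels alongside its per-class log loss. Scalar metrics are averaged over the folds where they are not NaN (in the spirit of #5031), and per-class values are keyed by label name rather than by index. Sorting the labels is just one possible standardized ordering. The confusion matrix is left out because how to aggregate it is still an open question.

```python
import math
from collections import defaultdict


def average_non_nan(values):
    """Mean of the values that are not NaN; NaN if every value is NaN."""
    finite = [v for v in values if not math.isnan(v)]
    return sum(finite) / len(finite) if finite else math.nan


def average_scalar_metrics(fold_metrics):
    """Average each scalar metric across folds, skipping NaN entries.

    Assumes (hypothetically) that every fold dict has the same metric names.
    """
    names = fold_metrics[0].keys()
    return {name: average_non_nan([m[name] for m in fold_metrics]) for name in names}


def average_per_class_log_loss(fold_results):
    """Average per-class log loss by label, not by position.

    fold_results: list of (labels, per_class_log_loss) pairs, where labels[i]
    names the class whose loss is per_class_log_loss[i] in that fold.
    Returns (ordered_labels, averaged_losses), with labels sorted as one
    possible standardized ordering.
    """
    by_label = defaultdict(list)
    for labels, losses in fold_results:
        for label, loss in zip(labels, losses):
            by_label[label].append(loss)
    ordered = sorted(by_label)
    return ordered, [average_non_nan(by_label[label]) for label in ordered]


folds = [
    {"MicroAccuracy": 0.90, "LogLossReduction": math.nan},
    {"MicroAccuracy": 0.80, "LogLossReduction": 0.40},
]
print(average_scalar_metrics(folds))

per_class = [
    (["cat", "dog", "bird"], [0.2, 0.4, 0.9]),
    (["dog", "bird", "cat"], [0.6, 0.7, 0.4]),
]
print(average_per_class_log_loss(per_class))
```

Keying by label also handles a class that is absent from some folds: it is averaged only over the folds in which it appears. Whether that is the desired semantics is part of the discussion above.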