diff --git a/src/Microsoft.ML.PipelineInference/AutoInference.cs b/src/Microsoft.ML.PipelineInference/AutoInference.cs index a8681da559..2c94c58348 100644 --- a/src/Microsoft.ML.PipelineInference/AutoInference.cs +++ b/src/Microsoft.ML.PipelineInference/AutoInference.cs @@ -98,12 +98,13 @@ private SupportedMetric(string name, bool isMaximizing) public static SupportedMetric ByName(string name) { var fields = - typeof(SupportedMetric).GetMembers(BindingFlags.Static | BindingFlags.Public) - .Where(s => s.MemberType == MemberTypes.Field); + typeof(SupportedMetric).GetFields(BindingFlags.Static | BindingFlags.Public); + foreach (var field in fields) { - if (name.Equals(field.Name, StringComparison.OrdinalIgnoreCase)) - return (SupportedMetric)typeof(SupportedMetric).GetField(field.Name).GetValue(null); + var metric = (SupportedMetric)field.GetValue(Auc); + if (name.Equals(metric.Name, StringComparison.OrdinalIgnoreCase)) + return metric; } throw new NotSupportedException($"Metric '{name}' not supported."); } diff --git a/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs b/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs index 07f6a70295..0cd6b0c906 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System.Linq; +using System.Collections.Generic; using Newtonsoft.Json.Linq; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; @@ -411,6 +412,27 @@ public void TestPipelineNodeCloning() } } + [Fact] + public void TestSupportedMetricsByName() + { + var names = new List() + { + AutoInference.SupportedMetric.AccuracyMacro.Name, + AutoInference.SupportedMetric.AccuracyMicro.Name, + AutoInference.SupportedMetric.Auc.Name, + AutoInference.SupportedMetric.AuPrc.Name, + AutoInference.SupportedMetric.Dbi.Name, + AutoInference.SupportedMetric.F1.Name, + AutoInference.SupportedMetric.LogLossReduction.Name + }; + + foreach (var name in names) + { + var metric = AutoInference.SupportedMetric.ByName(name); + Assert.Equal(metric.Name, name); + } + } + [Fact] public void TestHyperparameterFreezing() {