From 52aba237887051a301c504984639971f5abc7096 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Fri, 1 Jun 2018 13:51:16 -0700 Subject: [PATCH 1/6] Add label/grou/weight column name arguments to CV and train-test macros --- .../DataView/AppendRowsDataView.cs | 2 + .../EntryPoints/EntryPointNode.cs | 62 +++++- .../EntryPoints/InputBuilder.cs | 203 +++++++++--------- src/Microsoft.ML/CSharpApi.cs | 39 +++- .../EntryPoints/CrossValidationMacro.cs | 30 ++- .../Runtime/EntryPoints/MacroUtils.cs | 20 +- .../Runtime/EntryPoints/TrainTestMacro.cs | 20 +- .../UnitTests/TestCSharpApi.cs | 113 ++++++++++ 8 files changed, 354 insertions(+), 135 deletions(-) diff --git a/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs b/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs index e7b1ff6741..6633e2535f 100644 --- a/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs +++ b/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs @@ -190,6 +190,8 @@ protected Delegate CreateGetter(int col) public ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col), "The column must be active against the defined predicate."); + if (!(Getters[col] is ValueGetter)) + throw Ch.Except($"Invalid TValue in GetGetter: '{typeof(TValue)}'"); return Getters[col] as ValueGetter; } diff --git a/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs b/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs index 2230329c78..71eb7fe82f 100644 --- a/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs +++ b/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs @@ -473,9 +473,9 @@ public float Cost } } - public EntryPointNode(IHostEnvironment env, ModuleCatalog moduleCatalog, RunContext context, + private EntryPointNode(IHostEnvironment env, IChannel ch, ModuleCatalog moduleCatalog, RunContext context, string id, string entryPointName, JObject inputs, JObject outputs, bool checkpoint = false, - string stageId = "", float cost = float.NaN) + string stageId = "", float cost = float.NaN, string label = null, string group = null, string weight = null) { Contracts.AssertValue(env); env.AssertNonEmpty(id); @@ -497,6 +497,7 @@ public EntryPointNode(IHostEnvironment env, ModuleCatalog moduleCatalog, RunCont _inputMap = new Dictionary(); _inputBindingMap = new Dictionary>(); _inputBuilder = new InputBuilder(_host, _entryPoint.InputType, moduleCatalog); + // REVIEW: This logic should move out of Node eventually and be delegated to // a class that can nest to handle Components with variables. if (inputs != null) @@ -508,6 +509,43 @@ public EntryPointNode(IHostEnvironment env, ModuleCatalog moduleCatalog, RunCont if (missing.Length > 0) throw _host.Except($"The following required inputs were not provided: {String.Join(", ", missing)}"); + var inputInstance = _inputBuilder.GetInstance(); + var warning = "Different {0} column specified in trainer and in macro: '{1}', '{2}'." + + " Using column '{2}'. To column use '{1}' instead, please specify this name in" + + "the trainer node arguments."; + if (!string.IsNullOrEmpty(label) && inputInstance is LearnerInputBaseWithLabel) + { + var labelInputInstance = inputInstance as LearnerInputBaseWithLabel; + if (label != labelInputInstance.LabelColumn) + ch.Warning(warning, "label", label, labelInputInstance.LabelColumn); + else + labelInputInstance.LabelColumn = label; + } + if (!string.IsNullOrEmpty(group) && inputInstance is LearnerInputBaseWithGroupId) + { + var groupInputInstance = inputInstance as LearnerInputBaseWithGroupId; + if (group != groupInputInstance.GroupIdColumn) + ch.Warning(warning, "group Id", group, groupInputInstance.GroupIdColumn); + else + groupInputInstance.GroupIdColumn = group; + } + if (!string.IsNullOrEmpty(weight) && inputInstance is LearnerInputBaseWithWeight) + { + var weightInputInstance = inputInstance as LearnerInputBaseWithWeight; + if (weight != weightInputInstance.WeightColumn) + ch.Warning(warning, "weight", weight, weightInputInstance.WeightColumn); + else + weightInputInstance.WeightColumn = weight; + } + if (!string.IsNullOrEmpty(weight) && inputInstance is UnsupervisedLearnerInputBaseWithWeight) + { + var weightInputInstance = inputInstance as UnsupervisedLearnerInputBaseWithWeight; + if (weight != weightInputInstance.WeightColumn) + ch.Warning(warning, "weight", weight, weightInputInstance.WeightColumn); + else + weightInputInstance.WeightColumn = weight; + } + // Validate outputs. _outputHelper = new OutputHelper(_host, _entryPoint.OutputType); _outputMap = new Dictionary(); @@ -550,10 +588,15 @@ public static EntryPointNode Create( var inputBuilder = new InputBuilder(env, info.InputType, catalog); var outputHelper = new OutputHelper(env, info.OutputType); - var entryPointNode = new EntryPointNode(env, catalog, context, context.GenerateId(entryPointName), entryPointName, - inputBuilder.GetJsonObject(arguments, inputBindingMap, inputMap), - outputHelper.GetJsonObject(outputMap), checkpoint, stageId, cost); - return entryPointNode; + using (var ch = env.Start("Create EntryPointNode")) + { + var entryPointNode = new EntryPointNode(env, ch, catalog, context, context.GenerateId(entryPointName), entryPointName, + inputBuilder.GetJsonObject(arguments, inputBindingMap, inputMap), + outputHelper.GetJsonObject(outputMap), checkpoint, stageId, cost); + + ch.Done(); + return entryPointNode; + } } public static EntryPointNode Create( @@ -850,7 +893,8 @@ private object BuildParameterValue(List bindings) throw _host.ExceptNotImpl("Unsupported ParameterBinding"); } - public static List ValidateNodes(IHostEnvironment env, RunContext context, JArray nodes, ModuleCatalog moduleCatalog) + public static List ValidateNodes(IHostEnvironment env, RunContext context, JArray nodes, + ModuleCatalog moduleCatalog, string label = null, string group = null, string weight = null) { Contracts.AssertValue(env); env.AssertValue(context); @@ -890,8 +934,10 @@ public static List ValidateNodes(IHostEnvironment env, RunContex ch.Warning("Node '{0}' has unexpected fields that are ignored: {1}", id, string.Join(", ", unexpectedFields.Select(x => x.Name))); } - result.Add(new EntryPointNode(env, moduleCatalog, context, id, name, inputs, outputs, checkpoint, stageId, cost)); + result.Add(new EntryPointNode(env, ch, moduleCatalog, context, id, name, inputs, outputs, checkpoint, stageId, cost, label, group, weight)); } + + ch.Done(); } return result; } diff --git a/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs b/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs index 939d49893b..e27d0d4400 100644 --- a/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs +++ b/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs @@ -9,7 +9,6 @@ using System.Reflection; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Data; using Newtonsoft.Json.Linq; namespace Microsoft.ML.Runtime.EntryPoints.JsonUtils @@ -405,7 +404,11 @@ private static object ParseJsonValue(IExceptionContext ectx, Type type, Attribut return null; if (type.IsGenericType && (type.GetGenericTypeDefinition() == typeof(Optional<>) || type.GetGenericTypeDefinition() == typeof(Nullable<>))) + { + if (type.GetGenericTypeDefinition() == typeof(Optional<>) && value.HasValues) + value = value.Values().FirstOrDefault(); type = type.GetGenericArguments()[0]; + } if (type.IsGenericType && (type.GetGenericTypeDefinition() == typeof(Var<>))) { @@ -426,81 +429,81 @@ private static object ParseJsonValue(IExceptionContext ectx, Type type, Attribut { switch (dt) { - case TlcModule.DataKind.Bool: - return value.Value(); - case TlcModule.DataKind.String: - return value.Value(); - case TlcModule.DataKind.Char: - return value.Value(); - case TlcModule.DataKind.Enum: - if (!Enum.IsDefined(type, value.Value())) - throw ectx.Except($"Requested value '{value.Value()}' is not a member of the Enum type '{type.Name}'"); - return Enum.Parse(type, value.Value()); - case TlcModule.DataKind.Float: - if (type == typeof(double)) - return value.Value(); - else if (type == typeof(float)) - return value.Value(); - else - { - ectx.Assert(false); - throw ectx.ExceptNotSupp(); - } - case TlcModule.DataKind.Array: - var ja = value as JArray; - ectx.Check(ja != null, "Expected array value"); - Func makeArray = MakeArray; - return Utils.MarshalInvoke(makeArray, type.GetElementType(), ectx, ja, attributes, catalog); - case TlcModule.DataKind.Int: - if (type == typeof(long)) - return value.Value(); - if (type == typeof(int)) - return value.Value(); - ectx.Assert(false); - throw ectx.ExceptNotSupp(); - case TlcModule.DataKind.UInt: - if (type == typeof(ulong)) - return value.Value(); - if (type == typeof(uint)) - return value.Value(); + case TlcModule.DataKind.Bool: + return value.Value(); + case TlcModule.DataKind.String: + return value.Value(); + case TlcModule.DataKind.Char: + return value.Value(); + case TlcModule.DataKind.Enum: + if (!Enum.IsDefined(type, value.Value())) + throw ectx.Except($"Requested value '{value.Value()}' is not a member of the Enum type '{type.Name}'"); + return Enum.Parse(type, value.Value()); + case TlcModule.DataKind.Float: + if (type == typeof(double)) + return value.Value(); + else if (type == typeof(float)) + return value.Value(); + else + { ectx.Assert(false); throw ectx.ExceptNotSupp(); - case TlcModule.DataKind.Dictionary: - ectx.Check(value is JObject, "Expected object value"); - Func makeDict = MakeDictionary; - return Utils.MarshalInvoke(makeDict, type.GetGenericArguments()[1], ectx, (JObject)value, attributes, catalog); - case TlcModule.DataKind.Component: - var jo = value as JObject; - ectx.Check(jo != null, "Expected object value"); - // REVIEW: consider accepting strings alone. - var jName = jo[FieldNames.Name]; - ectx.Check(jName != null, "Field '" + FieldNames.Name + "' is required for component."); - ectx.Check(jName is JValue, "Expected '" + FieldNames.Name + "' field to be a string."); - var name = jName.Value(); - ectx.Check(jo[FieldNames.Settings] == null || jo[FieldNames.Settings] is JObject, - "Expected '" + FieldNames.Settings + "' field to be an object"); - return GetComponentJson(ectx, type, name, jo[FieldNames.Settings] as JObject, catalog); - default: - var settings = value as JObject; - ectx.Check(settings != null, "Expected object value"); - var inputBuilder = new InputBuilder(ectx, type, catalog); - - if (inputBuilder._fields.Length == 0) - throw ectx.Except($"Unsupported input type: {dt}"); - - if (settings != null) + } + case TlcModule.DataKind.Array: + var ja = value as JArray; + ectx.Check(ja != null, "Expected array value"); + Func makeArray = MakeArray; + return Utils.MarshalInvoke(makeArray, type.GetElementType(), ectx, ja, attributes, catalog); + case TlcModule.DataKind.Int: + if (type == typeof(long)) + return value.Value(); + if (type == typeof(int)) + return value.Value(); + ectx.Assert(false); + throw ectx.ExceptNotSupp(); + case TlcModule.DataKind.UInt: + if (type == typeof(ulong)) + return value.Value(); + if (type == typeof(uint)) + return value.Value(); + ectx.Assert(false); + throw ectx.ExceptNotSupp(); + case TlcModule.DataKind.Dictionary: + ectx.Check(value is JObject, "Expected object value"); + Func makeDict = MakeDictionary; + return Utils.MarshalInvoke(makeDict, type.GetGenericArguments()[1], ectx, (JObject)value, attributes, catalog); + case TlcModule.DataKind.Component: + var jo = value as JObject; + ectx.Check(jo != null, "Expected object value"); + // REVIEW: consider accepting strings alone. + var jName = jo[FieldNames.Name]; + ectx.Check(jName != null, "Field '" + FieldNames.Name + "' is required for component."); + ectx.Check(jName is JValue, "Expected '" + FieldNames.Name + "' field to be a string."); + var name = jName.Value(); + ectx.Check(jo[FieldNames.Settings] == null || jo[FieldNames.Settings] is JObject, + "Expected '" + FieldNames.Settings + "' field to be an object"); + return GetComponentJson(ectx, type, name, jo[FieldNames.Settings] as JObject, catalog); + default: + var settings = value as JObject; + ectx.Check(settings != null, "Expected object value"); + var inputBuilder = new InputBuilder(ectx, type, catalog); + + if (inputBuilder._fields.Length == 0) + throw ectx.Except($"Unsupported input type: {dt}"); + + if (settings != null) + { + foreach (var pair in settings) { - foreach (var pair in settings) - { - if (!inputBuilder.TrySetValueJson(pair.Key, pair.Value)) - throw ectx.Except($"Unexpected value for component '{type}', field '{pair.Key}': '{pair.Value}'"); - } + if (!inputBuilder.TrySetValueJson(pair.Key, pair.Value)) + throw ectx.Except($"Unexpected value for component '{type}', field '{pair.Key}': '{pair.Value}'"); } + } - var missing = inputBuilder.GetMissingValues().ToArray(); - if (missing.Length > 0) - throw ectx.Except($"The following required inputs were not provided for component '{type}': {string.Join(", ", missing)}"); - return inputBuilder.GetInstance(); + var missing = inputBuilder.GetMissingValues().ToArray(); + if (missing.Length > 0) + throw ectx.Except($"The following required inputs were not provided for component '{type}': {string.Join(", ", missing)}"); + return inputBuilder.GetInstance(); } } catch (FormatException ex) @@ -833,35 +836,35 @@ public static class SweepableDiscreteParam public static class PipelineSweeperSupportedMetrics { public new static string ToString() => "SupportedMetric"; - public const string Auc = BinaryClassifierEvaluator.Auc; - public const string AccuracyMicro = Data.MultiClassClassifierEvaluator.AccuracyMicro; - public const string AccuracyMacro = MultiClassClassifierEvaluator.AccuracyMacro; - public const string F1 = BinaryClassifierEvaluator.F1; - public const string AuPrc = BinaryClassifierEvaluator.AuPrc; - public const string TopKAccuracy = MultiClassClassifierEvaluator.TopKAccuracy; - public const string L1 = RegressionLossEvaluatorBase.L1; - public const string L2 = RegressionLossEvaluatorBase.L2; - public const string Rms = RegressionLossEvaluatorBase.Rms; - public const string LossFn = RegressionLossEvaluatorBase.Loss; - public const string RSquared = RegressionLossEvaluatorBase.RSquared; - public const string LogLoss = BinaryClassifierEvaluator.LogLoss; - public const string LogLossReduction = BinaryClassifierEvaluator.LogLossReduction; - public const string Ndcg = RankerEvaluator.Ndcg; - public const string Dcg = RankerEvaluator.Dcg; - public const string PositivePrecision = BinaryClassifierEvaluator.PosPrecName; - public const string PositiveRecall = BinaryClassifierEvaluator.PosRecallName; - public const string NegativePrecision = BinaryClassifierEvaluator.NegPrecName; - public const string NegativeRecall = BinaryClassifierEvaluator.NegRecallName; - public const string DrAtK = AnomalyDetectionEvaluator.OverallMetrics.DrAtK; - public const string DrAtPFpr = AnomalyDetectionEvaluator.OverallMetrics.DrAtPFpr; - public const string DrAtNumPos = AnomalyDetectionEvaluator.OverallMetrics.DrAtNumPos; - public const string NumAnomalies = AnomalyDetectionEvaluator.OverallMetrics.NumAnomalies; - public const string ThreshAtK = AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK; - public const string ThreshAtP = AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP; - public const string ThreshAtNumPos = AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos; - public const string Nmi = ClusteringEvaluator.Nmi; - public const string AvgMinScore = ClusteringEvaluator.AvgMinScore; - public const string Dbi = ClusteringEvaluator.Dbi; + public const string Auc = "AUC"; + public const string AccuracyMicro = "AccuracyMicro"; + public const string AccuracyMacro = "AccuracyMacro"; + public const string F1 = "F1"; + public const string AuPrc = "AUPRC"; + public const string TopKAccuracy = "TopKAccuracy"; + public const string L1 = "L1"; + public const string L2 = "L2"; + public const string Rms = "RMS"; + public const string LossFn = "LossFn"; + public const string RSquared = "RSquared"; + public const string LogLoss = "LogLoss"; + public const string LogLossReduction = "LogLossReduction"; + public const string Ndcg = "NDCG"; + public const string Dcg = "DCG"; + public const string PositivePrecision = "PositivePrecision"; + public const string PositiveRecall = "PositiveRecall"; + public const string NegativePrecision = "NegativePrecision"; + public const string NegativeRecall = "NegativeRecall"; + public const string DrAtK = "DrAtK"; + public const string DrAtPFpr = "DrAtPFpr"; + public const string DrAtNumPos = "DrAtNumPos"; + public const string NumAnomalies = "NumAnomalies"; + public const string ThreshAtK = "ThreshAtK"; + public const string ThreshAtP = "ThreshAtP"; + public const string ThreshAtNumPos = "ThreshAtNumPos"; + public const string Nmi = "NMI"; + public const string AvgMinScore = "AvgMinScore"; + public const string Dbi = "DBI"; } } } diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index 058e8bafe3..fc1e85c493 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -2270,6 +2270,21 @@ public sealed partial class CrossValidator /// public Models.MacroUtilsTrainerKinds Kind { get; set; } = Models.MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer; + /// + /// Column to use for labels + /// + public string LabelColumn { get; set; } = "Label"; + + /// + /// Column to use for example weight + /// + public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + + /// + /// Column to use for grouping + /// + public Microsoft.ML.Runtime.EntryPoints.Optional GroupColumn { get; set; } + public sealed class Output { @@ -3456,6 +3471,21 @@ public sealed partial class TrainTestEvaluator /// public bool IncludeTrainingMetrics { get; set; } = false; + /// + /// Column to use for labels + /// + public string LabelColumn { get; set; } = "Label"; + + /// + /// Column to use for example weight + /// + public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + + /// + /// Column to use for grouping + /// + public Microsoft.ML.Runtime.EntryPoints.Optional GroupColumn { get; set; } + public sealed class Output { @@ -6173,7 +6203,7 @@ public enum KMeansPlusPlusTrainerInitAlgorithm /// /// K-means is a popular clustering algorithm. With K-means, the data is clustered into a specified number of clusters in order to minimize the within-cluster sum of squares. K-means++ improves upon K-means by using a better method for choosing the initial cluster centers. /// - public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem + public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IUnsupervisedTrainerWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem { @@ -6208,6 +6238,11 @@ public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.Entry /// public int? NumThreads { get; set; } + /// + /// Column to use for example weight + /// + public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + /// /// The data to be used for training /// @@ -7024,7 +7059,7 @@ namespace Trainers /// /// Train an PCA Anomaly model. /// - public sealed partial class PcaAnomalyDetector : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem + public sealed partial class PcaAnomalyDetector : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IUnsupervisedTrainerWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem { diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs index 325b8a093a..00cc584ab8 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs @@ -77,6 +77,15 @@ public sealed class Arguments // (and the same for the TrainTest macro). I currently do not know how to do this, so this should be revisited in the future. [Argument(ArgumentType.Required, HelpText = "Specifies the trainer kind, which determines the evaluator to be used.", SortOrder = 8)] public MacroUtils.TrainerKinds Kind = MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer; + + [Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 10)] + public string LabelColumn = DefaultColumnNames.Label; + + [Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 11)] + public Optional WeightColumn = Optional.Implicit(DefaultColumnNames.Weight); + + [Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 12)] + public Optional GroupColumn = Optional.Implicit(DefaultColumnNames.GroupId); } // REVIEW: This output would be much better as an array of CommonOutputs.ClassificationEvaluateOutput, @@ -121,6 +130,12 @@ public sealed class CombineMetricsInput [Argument(ArgumentType.AtMostOnce, HelpText = "The label column name", ShortName = "Label", SortOrder = 5)] public string LabelColumn = DefaultColumnNames.Label; + [Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 6)] + public Optional WeightColumn = Optional.Implicit(DefaultColumnNames.Weight); + + [Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 12)] + public Optional GroupColumn = Optional.Implicit(DefaultColumnNames.GroupId); + [Argument(ArgumentType.Required, HelpText = "Specifies the trainer kind, which determines the evaluator to be used.", SortOrder = 6)] public MacroUtils.TrainerKinds Kind = MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer; } @@ -188,7 +203,10 @@ public static CommonOutputs.MacroOutput CrossValidate( var args = new TrainTestMacro.Arguments { Nodes = new JArray(graph.Select(n => n.ToJson()).ToArray()), - TransformModel = null + TransformModel = null, + LabelColumn = input.LabelColumn, + GroupColumn = input.GroupColumn.IsExplicit ? input.GroupColumn : Optional.Implicit(DefaultColumnNames.GroupId), + WeightColumn = input.WeightColumn.IsExplicit ? input.WeightColumn : Optional.Implicit(DefaultColumnNames.Weight) }; if (transformModelVarName != null) @@ -356,6 +374,7 @@ public static CommonOutputs.MacroOutput CrossValidate( var combineArgs = new CombineMetricsInput(); combineArgs.Kind = input.Kind; + combineArgs.LabelColumn = input.LabelColumn; // Set the input bindings for the CombineMetrics entry point. var combineInputBindingMap = new Dictionary>(); @@ -383,10 +402,13 @@ public static CommonOutputs.MacroOutput CrossValidate( var combineInstanceMetric = new Var(); combineInstanceMetric.VarName = node.GetOutputVariableName(nameof(Output.PerInstanceMetrics)); combineOutputMap.Add(nameof(Output.PerInstanceMetrics), combineInstanceMetric.VarName); - var combineConfusionMatrix = new Var(); - combineConfusionMatrix.VarName = node.GetOutputVariableName(nameof(Output.ConfusionMatrix)); - combineOutputMap.Add(nameof(TrainTestMacro.Output.ConfusionMatrix), combineConfusionMatrix.VarName); + if (confusionMatricesOutput != null) + { + var combineConfusionMatrix = new Var(); + combineConfusionMatrix.VarName = node.GetOutputVariableName(nameof(Output.ConfusionMatrix)); + combineOutputMap.Add(nameof(TrainTestMacro.Output.ConfusionMatrix), combineConfusionMatrix.VarName); + } subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes(), node.Catalog)); subGraphNodes.Add(EntryPointNode.Create(env, "Models.CrossValidationResultsCombiner", combineArgs, node.Catalog, node.Context, combineInputBindingMap, combineInputMap, combineOutputMap)); return new CommonOutputs.MacroOutput() { Nodes = subGraphNodes }; diff --git a/src/Microsoft.ML/Runtime/EntryPoints/MacroUtils.cs b/src/Microsoft.ML/Runtime/EntryPoints/MacroUtils.cs index aa5e8abc16..7e7f6acbac 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/MacroUtils.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/MacroUtils.cs @@ -30,9 +30,8 @@ public sealed class EvaluatorSettings { public string LabelColumn { get; set; } public string NameColumn { get; set; } - public string ScoreColumn { get; set; } - public string[] StratColumn { get; set; } public string WeightColumn { get; set; } + public string GroupColumn { get; set; } public string FeatureColumn { get; set; } public EvaluatorSettings() @@ -61,8 +60,6 @@ private static Dictionary { LabelColumn = settings.LabelColumn, NameColumn = settings.NameColumn, - ScoreColumn = settings.ScoreColumn, - StratColumn = settings.StratColumn, WeightColumn = settings.WeightColumn }, EvaluatorOutput = () => new Models.BinaryClassificationEvaluator.Output() @@ -77,8 +74,6 @@ private static Dictionary { LabelColumn = settings.LabelColumn, NameColumn = settings.NameColumn, - ScoreColumn = settings.ScoreColumn, - StratColumn = settings.StratColumn, WeightColumn = settings.WeightColumn }, EvaluatorOutput = () => new Models.ClassificationEvaluator.Output() @@ -93,9 +88,8 @@ private static Dictionary { LabelColumn = settings.LabelColumn, NameColumn = settings.NameColumn, - ScoreColumn = settings.ScoreColumn, - StratColumn = settings.StratColumn, - WeightColumn = settings.WeightColumn + WeightColumn = settings.WeightColumn, + GroupIdColumn = settings.GroupColumn }, EvaluatorOutput = () => new Models.RankerEvaluator.Output() } @@ -109,8 +103,6 @@ private static Dictionary { LabelColumn = settings.LabelColumn, NameColumn = settings.NameColumn, - ScoreColumn = settings.ScoreColumn, - StratColumn = settings.StratColumn, WeightColumn = settings.WeightColumn }, EvaluatorOutput = () => new Models.RegressionEvaluator.Output() @@ -125,8 +117,6 @@ private static Dictionary { LabelColumn = settings.LabelColumn, NameColumn = settings.NameColumn, - ScoreColumn = settings.ScoreColumn, - StratColumn = settings.StratColumn, WeightColumn = settings.WeightColumn, }, EvaluatorOutput = () => new Models.MultiOutputRegressionEvaluator.Output() @@ -141,8 +131,6 @@ private static Dictionary { LabelColumn = settings.LabelColumn, NameColumn = settings.NameColumn, - ScoreColumn = settings.ScoreColumn, - StratColumn = settings.StratColumn, WeightColumn = settings.WeightColumn }, EvaluatorOutput = () => new Models.AnomalyDetectionEvaluator.Output() @@ -157,8 +145,6 @@ private static Dictionary { LabelColumn = settings.LabelColumn, NameColumn = settings.NameColumn, - ScoreColumn = settings.ScoreColumn, - StratColumn = settings.StratColumn, WeightColumn = settings.WeightColumn, FeatureColumn = settings.FeatureColumn }, diff --git a/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs index c3c06b1031..3e2ba94615 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs @@ -62,6 +62,15 @@ public sealed class Arguments [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates whether to include and output training dataset metrics.", SortOrder = 9)] public Boolean IncludeTrainingMetrics = false; + + [Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 10)] + public string LabelColumn = DefaultColumnNames.Label; + + [Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 11)] + public Optional WeightColumn = Optional.Implicit(DefaultColumnNames.Weight); + + [Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 12)] + public Optional GroupColumn = Optional.Implicit(DefaultColumnNames.GroupId); } public sealed class Output @@ -110,7 +119,8 @@ public static CommonOutputs.MacroOutput TrainTest( // Parse the subgraph. var subGraphRunContext = new RunContext(env); - var subGraphNodes = EntryPointNode.ValidateNodes(env, subGraphRunContext, input.Nodes, node.Catalog); + var subGraphNodes = EntryPointNode.ValidateNodes(env, subGraphRunContext, input.Nodes, node.Catalog, input.LabelColumn, + input.GroupColumn.IsExplicit ? input.GroupColumn.Value : null, input.WeightColumn.IsExplicit ? input.WeightColumn.Value : null); // Change the subgraph to use the training data as input. var varName = input.Inputs.Data.VarName; @@ -206,11 +216,13 @@ public static CommonOutputs.MacroOutput TrainTest( // Do not double-add previous nodes. exp.Reset(); - // REVIEW: we need to extract the proper label column name here to pass to the evaluators. - // This is where you would add code to do it. + + // REVIEW: add similar support for NameColumn and FeatureColumn. var settings = new MacroUtils.EvaluatorSettings { - LabelColumn = DefaultColumnNames.Label + LabelColumn = input.LabelColumn, + WeightColumn = input.WeightColumn.IsExplicit ? input.WeightColumn.Value : null, + GroupColumn = input.GroupColumn.IsExplicit ? input.GroupColumn.Value : null }; string outVariableName; diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index 4c1969c2d9..f502dee76b 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -593,5 +593,118 @@ public void TestCrossValidationMacroWithStratification() } } } + + [Fact] + public void TestCrossValidationMacroWithNonDefaultNames() + { + string dataPath = GetDataPath(@"adult.tiny.with-schema.txt"); + using (var env = new TlcEnvironment(42)) + { + var subGraph = env.CreateExperiment(); + + var textToKey = new ML.Transforms.TextToKeyConverter(); + textToKey.Column = new[] { new ML.Transforms.TermTransformColumn() { Name = "Label1", Source = "Label" } }; + var textToKeyOutput = subGraph.Add(textToKey); + + var hash = new ML.Transforms.HashConverter(); + hash.Column = new[] { new ML.Transforms.HashJoinTransformColumn() { Name = "GroupId1", Source = "Workclass" } }; + hash.Data = textToKeyOutput.OutputData; + var hashOutput = subGraph.Add(hash); + + var learnerInput = new Trainers.FastTreeRanker + { + TrainingData = hashOutput.OutputData, + NumThreads = 1, + LabelColumn = "Label1", + GroupIdColumn = "GroupId1" + }; + var learnerOutput = subGraph.Add(learnerInput); + + var modelCombine = new ML.Transforms.ManyHeterogeneousModelCombiner + { + TransformModels = new ArrayVar(textToKeyOutput.Model, hashOutput.Model), + PredictorModel = learnerOutput.PredictorModel + }; + var modelCombineOutput = subGraph.Add(modelCombine); + + var experiment = env.CreateExperiment(); + var importInput = new ML.Data.TextLoader(dataPath); + importInput.Arguments.HasHeader = true; + importInput.Arguments.Column = new TextLoaderColumn[] + { + new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } }, + new TextLoaderColumn { Name = "Workclass", Source = new[] { new TextLoaderRange(1) }, Type = DataKind.Text }, + new TextLoaderColumn { Name = "Features", Source = new[] { new TextLoaderRange(9, 14) } } + }; + var importOutput = experiment.Add(importInput); + + var crossValidate = new Models.CrossValidator + { + Data = importOutput.Data, + Nodes = subGraph, + TransformModel = null, + LabelColumn = "Label1", + GroupColumn = "GroupId1", + Kind = Models.MacroUtilsTrainerKinds.SignatureRankerTrainer + }; + crossValidate.Inputs.Data = textToKey.Data; + crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel; + var crossValidateOutput = experiment.Add(crossValidate); + experiment.Compile(); + experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); + experiment.Run(); + var data = experiment.GetOutput(crossValidateOutput.OverallMetrics); + + var schema = data.Schema; + var b = schema.TryGetColumnIndex("NDCG", out int metricCol); + Assert.True(b); + b = schema.TryGetColumnIndex("Fold Index", out int foldCol); + Assert.True(b); + using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol)) + { + var getter = cursor.GetGetter>(metricCol); + var foldGetter = cursor.GetGetter(foldCol); + DvText fold = default; + + // Get the verage. + b = cursor.MoveNext(); + Assert.True(b); + var avg = default(VBuffer); + getter(ref avg); + foldGetter(ref fold); + Assert.True(fold.EqualsStr("Average")); + + // Get the standard deviation. + b = cursor.MoveNext(); + Assert.True(b); + var stdev = default(VBuffer); + getter(ref stdev); + foldGetter(ref fold); + Assert.True(fold.EqualsStr("Standard Deviation")); + Assert.Equal(10.87, stdev.Values[0], 3); + Assert.Equal(6.804, stdev.Values[1], 3); + Assert.Equal(7.568, stdev.Values[2], 3); + + var sumBldr = new BufferBuilder(R8Adder.Instance); + sumBldr.Reset(avg.Length, true); + var val = default(VBuffer); + for (int f = 0; f < 2; f++) + { + b = cursor.MoveNext(); + Assert.True(b); + getter(ref val); + foldGetter(ref fold); + sumBldr.AddFeatures(0, ref val); + Assert.True(fold.EqualsStr("Fold " + f)); + } + var sum = default(VBuffer); + sumBldr.GetResult(ref sum); + for (int i = 0; i < avg.Length; i++) + Assert.Equal(avg.Values[i], sum.Values[i] / 2); + b = cursor.MoveNext(); + Assert.False(b); + } + } + } } } From 7826c715388d7aeae3a4a7c8054eb591603351cf Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Mon, 4 Jun 2018 08:49:16 -0700 Subject: [PATCH 2/6] Fix unit test. --- .../EntryPoints/CrossValidationMacro.cs | 41 +++++++++++-------- .../UnitTests/TestCSharpApi.cs | 6 +-- 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs index 00cc584ab8..0c2950a87b 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs @@ -375,6 +375,8 @@ public static CommonOutputs.MacroOutput CrossValidate( var combineArgs = new CombineMetricsInput(); combineArgs.Kind = input.Kind; combineArgs.LabelColumn = input.LabelColumn; + combineArgs.WeightColumn = input.WeightColumn; + combineArgs.GroupColumn = input.GroupColumn; // Set the input bindings for the CombineMetrics entry point. var combineInputBindingMap = new Dictionary>(); @@ -420,7 +422,12 @@ public static CombinedOutput CombineMetrics(IHostEnvironment env, CombineMetrics var eval = GetEvaluator(env, input.Kind); var perInst = EvaluateUtils.ConcatenatePerInstanceDataViews(env, eval, true, true, input.PerInstanceMetrics.Select( - idv => RoleMappedData.Create(idv, RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, input.LabelColumn))).ToArray(), + idv => RoleMappedData.CreateOpt(idv, new[] + { + RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, input.LabelColumn), + RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Weight, input.WeightColumn.Value), + RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Group, input.GroupColumn.Value) + })).ToArray(), out var variableSizeVectorColumnNames); var warnings = input.Warnings != null ? new List(input.Warnings) : new List(); @@ -477,22 +484,22 @@ private static IMamlEvaluator GetEvaluator(IHostEnvironment env, MacroUtils.Trai { switch (kind) { - case MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer: - return new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments()); - case MacroUtils.TrainerKinds.SignatureMultiClassClassifierTrainer: - return new MultiClassMamlEvaluator(env, new MultiClassMamlEvaluator.Arguments()); - case MacroUtils.TrainerKinds.SignatureRegressorTrainer: - return new RegressionMamlEvaluator(env, new RegressionMamlEvaluator.Arguments()); - case MacroUtils.TrainerKinds.SignatureRankerTrainer: - return new RankerMamlEvaluator(env, new RankerMamlEvaluator.Arguments()); - case MacroUtils.TrainerKinds.SignatureAnomalyDetectorTrainer: - return new AnomalyDetectionMamlEvaluator(env, new AnomalyDetectionMamlEvaluator.Arguments()); - case MacroUtils.TrainerKinds.SignatureClusteringTrainer: - return new ClusteringMamlEvaluator(env, new ClusteringMamlEvaluator.Arguments()); - case MacroUtils.TrainerKinds.SignatureMultiOutputRegressorTrainer: - return new MultiOutputRegressionMamlEvaluator(env, new MultiOutputRegressionMamlEvaluator.Arguments()); - default: - throw env.ExceptParam(nameof(kind), $"Trainer kind {kind} does not have an evaluator"); + case MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer: + return new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureMultiClassClassifierTrainer: + return new MultiClassMamlEvaluator(env, new MultiClassMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureRegressorTrainer: + return new RegressionMamlEvaluator(env, new RegressionMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureRankerTrainer: + return new RankerMamlEvaluator(env, new RankerMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureAnomalyDetectorTrainer: + return new AnomalyDetectionMamlEvaluator(env, new AnomalyDetectionMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureClusteringTrainer: + return new ClusteringMamlEvaluator(env, new ClusteringMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureMultiOutputRegressorTrainer: + return new MultiOutputRegressionMamlEvaluator(env, new MultiOutputRegressionMamlEvaluator.Arguments()); + default: + throw env.ExceptParam(nameof(kind), $"Trainer kind {kind} does not have an evaluator"); } } } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index f502dee76b..029ab670fd 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -681,9 +681,9 @@ public void TestCrossValidationMacroWithNonDefaultNames() getter(ref stdev); foldGetter(ref fold); Assert.True(fold.EqualsStr("Standard Deviation")); - Assert.Equal(10.87, stdev.Values[0], 3); - Assert.Equal(6.804, stdev.Values[1], 3); - Assert.Equal(7.568, stdev.Values[2], 3); + Assert.Equal(2.462, stdev.Values[0], 3); + Assert.Equal(2.763, stdev.Values[1], 3); + Assert.Equal(3.273, stdev.Values[2], 3); var sumBldr = new BufferBuilder(R8Adder.Instance); sumBldr.Reset(avg.Length, true); From 43a8bbefcc1df16a50123669a166943b33481215 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Mon, 4 Jun 2018 08:59:51 -0700 Subject: [PATCH 3/6] Merge. --- .../EntryPoints/InputBuilder.cs | 60 +++++++++---------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs b/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs index e27d0d4400..bd4b35dad1 100644 --- a/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs +++ b/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs @@ -4,10 +4,10 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Reflection; using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Utilities; using Newtonsoft.Json.Linq; @@ -836,35 +836,35 @@ public static class SweepableDiscreteParam public static class PipelineSweeperSupportedMetrics { public new static string ToString() => "SupportedMetric"; - public const string Auc = "AUC"; - public const string AccuracyMicro = "AccuracyMicro"; - public const string AccuracyMacro = "AccuracyMacro"; - public const string F1 = "F1"; - public const string AuPrc = "AUPRC"; - public const string TopKAccuracy = "TopKAccuracy"; - public const string L1 = "L1"; - public const string L2 = "L2"; - public const string Rms = "RMS"; - public const string LossFn = "LossFn"; - public const string RSquared = "RSquared"; - public const string LogLoss = "LogLoss"; - public const string LogLossReduction = "LogLossReduction"; - public const string Ndcg = "NDCG"; - public const string Dcg = "DCG"; - public const string PositivePrecision = "PositivePrecision"; - public const string PositiveRecall = "PositiveRecall"; - public const string NegativePrecision = "NegativePrecision"; - public const string NegativeRecall = "NegativeRecall"; - public const string DrAtK = "DrAtK"; - public const string DrAtPFpr = "DrAtPFpr"; - public const string DrAtNumPos = "DrAtNumPos"; - public const string NumAnomalies = "NumAnomalies"; - public const string ThreshAtK = "ThreshAtK"; - public const string ThreshAtP = "ThreshAtP"; - public const string ThreshAtNumPos = "ThreshAtNumPos"; - public const string Nmi = "NMI"; - public const string AvgMinScore = "AvgMinScore"; - public const string Dbi = "DBI"; + public const string Auc = BinaryClassifierEvaluator.Auc; + public const string AccuracyMicro = Data.MultiClassClassifierEvaluator.AccuracyMicro; + public const string AccuracyMacro = MultiClassClassifierEvaluator.AccuracyMacro; + public const string F1 = BinaryClassifierEvaluator.F1; + public const string AuPrc = BinaryClassifierEvaluator.AuPrc; + public const string TopKAccuracy = MultiClassClassifierEvaluator.TopKAccuracy; + public const string L1 = RegressionLossEvaluatorBase.L1; + public const string L2 = RegressionLossEvaluatorBase.L2; + public const string Rms = RegressionLossEvaluatorBase.Rms; + public const string LossFn = RegressionLossEvaluatorBase.Loss; + public const string RSquared = RegressionLossEvaluatorBase.RSquared; + public const string LogLoss = BinaryClassifierEvaluator.LogLoss; + public const string LogLossReduction = BinaryClassifierEvaluator.LogLossReduction; + public const string Ndcg = RankerEvaluator.Ndcg; + public const string Dcg = RankerEvaluator.Dcg; + public const string PositivePrecision = BinaryClassifierEvaluator.PosPrecName; + public const string PositiveRecall = BinaryClassifierEvaluator.PosRecallName; + public const string NegativePrecision = BinaryClassifierEvaluator.NegPrecName; + public const string NegativeRecall = BinaryClassifierEvaluator.NegRecallName; + public const string DrAtK = AnomalyDetectionEvaluator.OverallMetrics.DrAtK; + public const string DrAtPFpr = AnomalyDetectionEvaluator.OverallMetrics.DrAtPFpr; + public const string DrAtNumPos = AnomalyDetectionEvaluator.OverallMetrics.DrAtNumPos; + public const string NumAnomalies = AnomalyDetectionEvaluator.OverallMetrics.NumAnomalies; + public const string ThreshAtK = AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK; + public const string ThreshAtP = AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP; + public const string ThreshAtNumPos = AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos; + public const string Nmi = ClusteringEvaluator.Nmi; + public const string AvgMinScore = ClusteringEvaluator.AvgMinScore; + public const string Dbi = ClusteringEvaluator.Dbi; } } } From 40ae3c2a53160ad14b781e1f23ab77f00decf41e Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Mon, 4 Jun 2018 09:10:52 -0700 Subject: [PATCH 4/6] Update CSharp API. --- src/Microsoft.ML/CSharpApi.cs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index fc1e85c493..e35fb2ff81 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -2165,6 +2165,16 @@ public sealed partial class CrossValidationResultsCombiner /// public string LabelColumn { get; set; } = "Label"; + /// + /// Column to use for example weight + /// + public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + + /// + /// Column to use for grouping + /// + public Microsoft.ML.Runtime.EntryPoints.Optional GroupColumn { get; set; } + /// /// Specifies the trainer kind, which determines the evaluator to be used. /// From 4e62e1416e12d7b333e9c08fde6eff7e3be46f6e Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Mon, 4 Jun 2018 09:38:52 -0700 Subject: [PATCH 5/6] Fix EntryPointCatalog test. --- .../Common/EntryPoints/core_manifest.json | 96 +++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 7e06c77821..e237902f6f 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -1333,6 +1333,18 @@ "IsNullable": false, "Default": "Label" }, + { + "Name": "WeightColumn", + "Type": "String", + "Desc": "Column to use for example weight", + "Aliases": [ + "weight" + ], + "Required": false, + "SortOrder": 6.0, + "IsNullable": false, + "Default": "Weight" + }, { "Name": "Kind", "Type": { @@ -1352,6 +1364,18 @@ "SortOrder": 6.0, "IsNullable": false, "Default": "SignatureBinaryClassifierTrainer" + }, + { + "Name": "GroupColumn", + "Type": "String", + "Desc": "Column to use for grouping", + "Aliases": [ + "group" + ], + "Required": false, + "SortOrder": 12.0, + "IsNullable": false, + "Default": "GroupId" } ], "Outputs": [ @@ -1504,6 +1528,42 @@ "SortOrder": 8.0, "IsNullable": false, "Default": "SignatureBinaryClassifierTrainer" + }, + { + "Name": "LabelColumn", + "Type": "String", + "Desc": "Column to use for labels", + "Aliases": [ + "lab" + ], + "Required": false, + "SortOrder": 10.0, + "IsNullable": false, + "Default": "Label" + }, + { + "Name": "WeightColumn", + "Type": "String", + "Desc": "Column to use for example weight", + "Aliases": [ + "weight" + ], + "Required": false, + "SortOrder": 11.0, + "IsNullable": false, + "Default": "Weight" + }, + { + "Name": "GroupColumn", + "Type": "String", + "Desc": "Column to use for grouping", + "Aliases": [ + "group" + ], + "Required": false, + "SortOrder": 12.0, + "IsNullable": false, + "Default": "GroupId" } ], "Outputs": [ @@ -3107,6 +3167,42 @@ "SortOrder": 9.0, "IsNullable": false, "Default": false + }, + { + "Name": "LabelColumn", + "Type": "String", + "Desc": "Column to use for labels", + "Aliases": [ + "lab" + ], + "Required": false, + "SortOrder": 10.0, + "IsNullable": false, + "Default": "Label" + }, + { + "Name": "WeightColumn", + "Type": "String", + "Desc": "Column to use for example weight", + "Aliases": [ + "weight" + ], + "Required": false, + "SortOrder": 11.0, + "IsNullable": false, + "Default": "Weight" + }, + { + "Name": "GroupColumn", + "Type": "String", + "Desc": "Column to use for grouping", + "Aliases": [ + "group" + ], + "Required": false, + "SortOrder": 12.0, + "IsNullable": false, + "Default": "GroupId" } ], "Outputs": [ From c862d369d535d6e0a668bb03836e40a3a4751caf Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Mon, 4 Jun 2018 17:07:16 -0700 Subject: [PATCH 6/6] Address PR comments. --- .../EntryPoints/EntryPointNode.cs | 54 ++++--- .../EntryPoints/InputBuilder.cs | 140 +++++++++--------- .../EntryPoints/CrossValidationMacro.cs | 37 +++-- .../UnitTests/TestCSharpApi.cs | 95 ++++++++---- 4 files changed, 183 insertions(+), 143 deletions(-) diff --git a/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs b/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs index 71eb7fe82f..1ff3daee02 100644 --- a/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs +++ b/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs @@ -513,37 +513,45 @@ private EntryPointNode(IHostEnvironment env, IChannel ch, ModuleCatalog moduleCa var warning = "Different {0} column specified in trainer and in macro: '{1}', '{2}'." + " Using column '{2}'. To column use '{1}' instead, please specify this name in" + "the trainer node arguments."; - if (!string.IsNullOrEmpty(label) && inputInstance is LearnerInputBaseWithLabel) + if (!string.IsNullOrEmpty(label) && Utils.Size(_entryPoint.InputKinds) > 0 && + _entryPoint.InputKinds.Contains(typeof(CommonInputs.ITrainerInputWithLabel))) { - var labelInputInstance = inputInstance as LearnerInputBaseWithLabel; - if (label != labelInputInstance.LabelColumn) - ch.Warning(warning, "label", label, labelInputInstance.LabelColumn); + var labelColField = _inputBuilder.GetFieldNameOrNull("LabelColumn"); + ch.AssertNonEmpty(labelColField); + var labelColFieldType = _inputBuilder.GetFieldTypeOrNull(labelColField); + ch.Assert(labelColFieldType == typeof(string)); + var inputLabel = inputInstance.GetType().GetField(labelColField).GetValue(inputInstance); + if (label != (string)inputLabel) + ch.Warning(warning, "label", label, inputLabel); else - labelInputInstance.LabelColumn = label; + _inputBuilder.TrySetValue(labelColField, label); } - if (!string.IsNullOrEmpty(group) && inputInstance is LearnerInputBaseWithGroupId) + if (!string.IsNullOrEmpty(group) && Utils.Size(_entryPoint.InputKinds) > 0 && + _entryPoint.InputKinds.Contains(typeof(CommonInputs.ITrainerInputWithGroupId))) { - var groupInputInstance = inputInstance as LearnerInputBaseWithGroupId; - if (group != groupInputInstance.GroupIdColumn) - ch.Warning(warning, "group Id", group, groupInputInstance.GroupIdColumn); + var groupColField = _inputBuilder.GetFieldNameOrNull("GroupIdColumn"); + ch.AssertNonEmpty(groupColField); + var groupColFieldType = _inputBuilder.GetFieldTypeOrNull(groupColField); + ch.Assert(groupColFieldType == typeof(string)); + var inputGroup = inputInstance.GetType().GetField(groupColField).GetValue(inputInstance); + if (group != (Optional)inputGroup) + ch.Warning(warning, "group Id", label, inputGroup); else - groupInputInstance.GroupIdColumn = group; + _inputBuilder.TrySetValue(groupColField, label); } - if (!string.IsNullOrEmpty(weight) && inputInstance is LearnerInputBaseWithWeight) + if (!string.IsNullOrEmpty(weight) && Utils.Size(_entryPoint.InputKinds) > 0 && + (_entryPoint.InputKinds.Contains(typeof(CommonInputs.ITrainerInputWithWeight)) || + _entryPoint.InputKinds.Contains(typeof(CommonInputs.IUnsupervisedTrainerWithWeight)))) { - var weightInputInstance = inputInstance as LearnerInputBaseWithWeight; - if (weight != weightInputInstance.WeightColumn) - ch.Warning(warning, "weight", weight, weightInputInstance.WeightColumn); + var weightColField = _inputBuilder.GetFieldNameOrNull("WeightColumn"); + ch.AssertNonEmpty(weightColField); + var weightColFieldType = _inputBuilder.GetFieldTypeOrNull(weightColField); + ch.Assert(weightColFieldType == typeof(string)); + var inputWeight = inputInstance.GetType().GetField(weightColField).GetValue(inputInstance); + if (weight != (Optional)inputWeight) + ch.Warning(warning, "weight", label, inputWeight); else - weightInputInstance.WeightColumn = weight; - } - if (!string.IsNullOrEmpty(weight) && inputInstance is UnsupervisedLearnerInputBaseWithWeight) - { - var weightInputInstance = inputInstance as UnsupervisedLearnerInputBaseWithWeight; - if (weight != weightInputInstance.WeightColumn) - ch.Warning(warning, "weight", weight, weightInputInstance.WeightColumn); - else - weightInputInstance.WeightColumn = weight; + _inputBuilder.TrySetValue(weightColField, label); } // Validate outputs. diff --git a/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs b/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs index bd4b35dad1..e5afd8dbb5 100644 --- a/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs +++ b/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs @@ -429,81 +429,81 @@ private static object ParseJsonValue(IExceptionContext ectx, Type type, Attribut { switch (dt) { - case TlcModule.DataKind.Bool: - return value.Value(); - case TlcModule.DataKind.String: - return value.Value(); - case TlcModule.DataKind.Char: - return value.Value(); - case TlcModule.DataKind.Enum: - if (!Enum.IsDefined(type, value.Value())) - throw ectx.Except($"Requested value '{value.Value()}' is not a member of the Enum type '{type.Name}'"); - return Enum.Parse(type, value.Value()); - case TlcModule.DataKind.Float: - if (type == typeof(double)) - return value.Value(); - else if (type == typeof(float)) - return value.Value(); - else - { + case TlcModule.DataKind.Bool: + return value.Value(); + case TlcModule.DataKind.String: + return value.Value(); + case TlcModule.DataKind.Char: + return value.Value(); + case TlcModule.DataKind.Enum: + if (!Enum.IsDefined(type, value.Value())) + throw ectx.Except($"Requested value '{value.Value()}' is not a member of the Enum type '{type.Name}'"); + return Enum.Parse(type, value.Value()); + case TlcModule.DataKind.Float: + if (type == typeof(double)) + return value.Value(); + else if (type == typeof(float)) + return value.Value(); + else + { + ectx.Assert(false); + throw ectx.ExceptNotSupp(); + } + case TlcModule.DataKind.Array: + var ja = value as JArray; + ectx.Check(ja != null, "Expected array value"); + Func makeArray = MakeArray; + return Utils.MarshalInvoke(makeArray, type.GetElementType(), ectx, ja, attributes, catalog); + case TlcModule.DataKind.Int: + if (type == typeof(long)) + return value.Value(); + if (type == typeof(int)) + return value.Value(); ectx.Assert(false); throw ectx.ExceptNotSupp(); - } - case TlcModule.DataKind.Array: - var ja = value as JArray; - ectx.Check(ja != null, "Expected array value"); - Func makeArray = MakeArray; - return Utils.MarshalInvoke(makeArray, type.GetElementType(), ectx, ja, attributes, catalog); - case TlcModule.DataKind.Int: - if (type == typeof(long)) - return value.Value(); - if (type == typeof(int)) - return value.Value(); - ectx.Assert(false); - throw ectx.ExceptNotSupp(); - case TlcModule.DataKind.UInt: - if (type == typeof(ulong)) - return value.Value(); - if (type == typeof(uint)) - return value.Value(); - ectx.Assert(false); - throw ectx.ExceptNotSupp(); - case TlcModule.DataKind.Dictionary: - ectx.Check(value is JObject, "Expected object value"); - Func makeDict = MakeDictionary; - return Utils.MarshalInvoke(makeDict, type.GetGenericArguments()[1], ectx, (JObject)value, attributes, catalog); - case TlcModule.DataKind.Component: - var jo = value as JObject; - ectx.Check(jo != null, "Expected object value"); - // REVIEW: consider accepting strings alone. - var jName = jo[FieldNames.Name]; - ectx.Check(jName != null, "Field '" + FieldNames.Name + "' is required for component."); - ectx.Check(jName is JValue, "Expected '" + FieldNames.Name + "' field to be a string."); - var name = jName.Value(); - ectx.Check(jo[FieldNames.Settings] == null || jo[FieldNames.Settings] is JObject, - "Expected '" + FieldNames.Settings + "' field to be an object"); - return GetComponentJson(ectx, type, name, jo[FieldNames.Settings] as JObject, catalog); - default: - var settings = value as JObject; - ectx.Check(settings != null, "Expected object value"); - var inputBuilder = new InputBuilder(ectx, type, catalog); - - if (inputBuilder._fields.Length == 0) - throw ectx.Except($"Unsupported input type: {dt}"); - - if (settings != null) - { - foreach (var pair in settings) + case TlcModule.DataKind.UInt: + if (type == typeof(ulong)) + return value.Value(); + if (type == typeof(uint)) + return value.Value(); + ectx.Assert(false); + throw ectx.ExceptNotSupp(); + case TlcModule.DataKind.Dictionary: + ectx.Check(value is JObject, "Expected object value"); + Func makeDict = MakeDictionary; + return Utils.MarshalInvoke(makeDict, type.GetGenericArguments()[1], ectx, (JObject)value, attributes, catalog); + case TlcModule.DataKind.Component: + var jo = value as JObject; + ectx.Check(jo != null, "Expected object value"); + // REVIEW: consider accepting strings alone. + var jName = jo[FieldNames.Name]; + ectx.Check(jName != null, "Field '" + FieldNames.Name + "' is required for component."); + ectx.Check(jName is JValue, "Expected '" + FieldNames.Name + "' field to be a string."); + var name = jName.Value(); + ectx.Check(jo[FieldNames.Settings] == null || jo[FieldNames.Settings] is JObject, + "Expected '" + FieldNames.Settings + "' field to be an object"); + return GetComponentJson(ectx, type, name, jo[FieldNames.Settings] as JObject, catalog); + default: + var settings = value as JObject; + ectx.Check(settings != null, "Expected object value"); + var inputBuilder = new InputBuilder(ectx, type, catalog); + + if (inputBuilder._fields.Length == 0) + throw ectx.Except($"Unsupported input type: {dt}"); + + if (settings != null) { - if (!inputBuilder.TrySetValueJson(pair.Key, pair.Value)) - throw ectx.Except($"Unexpected value for component '{type}', field '{pair.Key}': '{pair.Value}'"); + foreach (var pair in settings) + { + if (!inputBuilder.TrySetValueJson(pair.Key, pair.Value)) + throw ectx.Except($"Unexpected value for component '{type}', field '{pair.Key}': '{pair.Value}'"); + } } - } - var missing = inputBuilder.GetMissingValues().ToArray(); - if (missing.Length > 0) - throw ectx.Except($"The following required inputs were not provided for component '{type}': {string.Join(", ", missing)}"); - return inputBuilder.GetInstance(); + var missing = inputBuilder.GetMissingValues().ToArray(); + if (missing.Length > 0) + throw ectx.Except($"The following required inputs were not provided for component '{type}': {string.Join(", ", missing)}"); + return inputBuilder.GetInstance(); } } catch (FormatException ex) diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs index 0c2950a87b..e711b87117 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs @@ -205,8 +205,8 @@ public static CommonOutputs.MacroOutput CrossValidate( Nodes = new JArray(graph.Select(n => n.ToJson()).ToArray()), TransformModel = null, LabelColumn = input.LabelColumn, - GroupColumn = input.GroupColumn.IsExplicit ? input.GroupColumn : Optional.Implicit(DefaultColumnNames.GroupId), - WeightColumn = input.WeightColumn.IsExplicit ? input.WeightColumn : Optional.Implicit(DefaultColumnNames.Weight) + GroupColumn = input.GroupColumn, + WeightColumn = input.WeightColumn }; if (transformModelVarName != null) @@ -409,7 +409,6 @@ public static CommonOutputs.MacroOutput CrossValidate( var combineConfusionMatrix = new Var(); combineConfusionMatrix.VarName = node.GetOutputVariableName(nameof(Output.ConfusionMatrix)); combineOutputMap.Add(nameof(TrainTestMacro.Output.ConfusionMatrix), combineConfusionMatrix.VarName); - } subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes(), node.Catalog)); subGraphNodes.Add(EntryPointNode.Create(env, "Models.CrossValidationResultsCombiner", combineArgs, node.Catalog, node.Context, combineInputBindingMap, combineInputMap, combineOutputMap)); @@ -484,22 +483,22 @@ private static IMamlEvaluator GetEvaluator(IHostEnvironment env, MacroUtils.Trai { switch (kind) { - case MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer: - return new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments()); - case MacroUtils.TrainerKinds.SignatureMultiClassClassifierTrainer: - return new MultiClassMamlEvaluator(env, new MultiClassMamlEvaluator.Arguments()); - case MacroUtils.TrainerKinds.SignatureRegressorTrainer: - return new RegressionMamlEvaluator(env, new RegressionMamlEvaluator.Arguments()); - case MacroUtils.TrainerKinds.SignatureRankerTrainer: - return new RankerMamlEvaluator(env, new RankerMamlEvaluator.Arguments()); - case MacroUtils.TrainerKinds.SignatureAnomalyDetectorTrainer: - return new AnomalyDetectionMamlEvaluator(env, new AnomalyDetectionMamlEvaluator.Arguments()); - case MacroUtils.TrainerKinds.SignatureClusteringTrainer: - return new ClusteringMamlEvaluator(env, new ClusteringMamlEvaluator.Arguments()); - case MacroUtils.TrainerKinds.SignatureMultiOutputRegressorTrainer: - return new MultiOutputRegressionMamlEvaluator(env, new MultiOutputRegressionMamlEvaluator.Arguments()); - default: - throw env.ExceptParam(nameof(kind), $"Trainer kind {kind} does not have an evaluator"); + case MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer: + return new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureMultiClassClassifierTrainer: + return new MultiClassMamlEvaluator(env, new MultiClassMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureRegressorTrainer: + return new RegressionMamlEvaluator(env, new RegressionMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureRankerTrainer: + return new RankerMamlEvaluator(env, new RankerMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureAnomalyDetectorTrainer: + return new AnomalyDetectionMamlEvaluator(env, new AnomalyDetectionMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureClusteringTrainer: + return new ClusteringMamlEvaluator(env, new ClusteringMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureMultiOutputRegressorTrainer: + return new MultiOutputRegressionMamlEvaluator(env, new MultiOutputRegressionMamlEvaluator.Arguments()); + default: + throw env.ExceptParam(nameof(kind), $"Trainer kind {kind} does not have an evaluator"); } } } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index 029ab670fd..bbcee502d7 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -270,16 +270,22 @@ public void TestCrossValidationMacro() var nop = new ML.Transforms.NoOperation(); var nopOutput = subGraph.Add(nop); - var learnerInput = new ML.Trainers.StochasticDualCoordinateAscentRegressor + var generate = new ML.Transforms.RandomNumberGenerator(); + generate.Column = new[] { new ML.Transforms.GenerateNumberTransformColumn() { Name = "Weight1" } }; + generate.Data = nopOutput.OutputData; + var generateOutput = subGraph.Add(generate); + + var learnerInput = new ML.Trainers.PoissonRegressor { - TrainingData = nopOutput.OutputData, - NumThreads = 1 + TrainingData = generateOutput.OutputData, + NumThreads = 1, + WeightColumn = "Weight1" }; var learnerOutput = subGraph.Add(learnerInput); var modelCombine = new ML.Transforms.ManyHeterogeneousModelCombiner { - TransformModels = new ArrayVar(nopOutput.Model), + TransformModels = new ArrayVar(nopOutput.Model, generateOutput.Model), PredictorModel = learnerOutput.PredictorModel }; var modelCombineOutput = subGraph.Add(modelCombine); @@ -316,7 +322,8 @@ public void TestCrossValidationMacro() Data = importOutput.Data, Nodes = subGraph, Kind = ML.Models.MacroUtilsTrainerKinds.SignatureRegressorTrainer, - TransformModel = null + TransformModel = null, + WeightColumn = "Weight1" }; crossValidate.Inputs.Data = nop.Data; crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel; @@ -332,40 +339,66 @@ public void TestCrossValidationMacro() Assert.True(b); b = schema.TryGetColumnIndex("Fold Index", out int foldCol); Assert.True(b); - using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol)) + b = schema.TryGetColumnIndex("IsWeighted", out int isWeightedCol); + using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol || col == isWeightedCol)) { var getter = cursor.GetGetter(metricCol); var foldGetter = cursor.GetGetter(foldCol); + var isWeightedGetter = cursor.GetGetter(isWeightedCol); DvText fold = default; + DvBool isWeighted = default; - // Get the verage. - b = cursor.MoveNext(); - Assert.True(b); double avg = 0; - getter(ref avg); - foldGetter(ref fold); - Assert.True(fold.EqualsStr("Average")); - - // Get the standard deviation. - b = cursor.MoveNext(); - Assert.True(b); - double stdev = 0; - getter(ref stdev); - foldGetter(ref fold); - Assert.True(fold.EqualsStr("Standard Deviation")); - Assert.Equal(0.0013, stdev, 4); - - double sum = 0; - double val = 0; - for (int f = 0; f < 2; f++) + double weightedAvg = 0; + for (int w = 0; w < 2; w++) { + // Get the average. b = cursor.MoveNext(); Assert.True(b); - getter(ref val); + if (w == 1) + getter(ref weightedAvg); + else + getter(ref avg); foldGetter(ref fold); - sum += val; - Assert.True(fold.EqualsStr("Fold " + f)); + Assert.True(fold.EqualsStr("Average")); + isWeightedGetter(ref isWeighted); + Assert.True(isWeighted.IsTrue == (w == 1)); + + // Get the standard deviation. + b = cursor.MoveNext(); + Assert.True(b); + double stdev = 0; + getter(ref stdev); + foldGetter(ref fold); + Assert.True(fold.EqualsStr("Standard Deviation")); + if (w == 1) + Assert.Equal(0.002827, stdev, 6); + else + Assert.Equal(0.002376, stdev, 6); + isWeightedGetter(ref isWeighted); + Assert.True(isWeighted.IsTrue == (w == 1)); + } + double sum = 0; + double weightedSum = 0; + for (int f = 0; f < 2; f++) + { + for (int w = 0; w < 2; w++) + { + b = cursor.MoveNext(); + Assert.True(b); + double val = 0; + getter(ref val); + foldGetter(ref fold); + if (w == 1) + weightedSum += val; + else + sum += val; + Assert.True(fold.EqualsStr("Fold " + f)); + isWeightedGetter(ref isWeighted); + Assert.True(isWeighted.IsTrue == (w == 1)); + } } + Assert.Equal(weightedAvg, weightedSum / 2); Assert.Equal(avg, sum / 2); b = cursor.MoveNext(); Assert.False(b); @@ -681,9 +714,9 @@ public void TestCrossValidationMacroWithNonDefaultNames() getter(ref stdev); foldGetter(ref fold); Assert.True(fold.EqualsStr("Standard Deviation")); - Assert.Equal(2.462, stdev.Values[0], 3); - Assert.Equal(2.763, stdev.Values[1], 3); - Assert.Equal(3.273, stdev.Values[2], 3); + Assert.Equal(5.247, stdev.Values[0], 3); + Assert.Equal(4.703, stdev.Values[1], 3); + Assert.Equal(3.844, stdev.Values[2], 3); var sumBldr = new BufferBuilder(R8Adder.Instance); sumBldr.Reset(avg.Length, true);