diff --git a/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs b/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs index e7b1ff6741..6633e2535f 100644 --- a/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs +++ b/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs @@ -190,6 +190,8 @@ protected Delegate CreateGetter(int col) public ValueGetter GetGetter(int col) { Ch.Check(IsColumnActive(col), "The column must be active against the defined predicate."); + if (!(Getters[col] is ValueGetter)) + throw Ch.Except($"Invalid TValue in GetGetter: '{typeof(TValue)}'"); return Getters[col] as ValueGetter; } diff --git a/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs b/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs index 2230329c78..1ff3daee02 100644 --- a/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs +++ b/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs @@ -473,9 +473,9 @@ public float Cost } } - public EntryPointNode(IHostEnvironment env, ModuleCatalog moduleCatalog, RunContext context, + private EntryPointNode(IHostEnvironment env, IChannel ch, ModuleCatalog moduleCatalog, RunContext context, string id, string entryPointName, JObject inputs, JObject outputs, bool checkpoint = false, - string stageId = "", float cost = float.NaN) + string stageId = "", float cost = float.NaN, string label = null, string group = null, string weight = null) { Contracts.AssertValue(env); env.AssertNonEmpty(id); @@ -497,6 +497,7 @@ public EntryPointNode(IHostEnvironment env, ModuleCatalog moduleCatalog, RunCont _inputMap = new Dictionary(); _inputBindingMap = new Dictionary>(); _inputBuilder = new InputBuilder(_host, _entryPoint.InputType, moduleCatalog); + // REVIEW: This logic should move out of Node eventually and be delegated to // a class that can nest to handle Components with variables. if (inputs != null) @@ -508,6 +509,51 @@ public EntryPointNode(IHostEnvironment env, ModuleCatalog moduleCatalog, RunCont if (missing.Length > 0) throw _host.Except($"The following required inputs were not provided: {String.Join(", ", missing)}"); + var inputInstance = _inputBuilder.GetInstance(); + var warning = "Different {0} column specified in trainer and in macro: '{1}', '{2}'." + + " Using column '{2}'. 
To column use '{1}' instead, please specify this name in" + + "the trainer node arguments."; + if (!string.IsNullOrEmpty(label) && Utils.Size(_entryPoint.InputKinds) > 0 && + _entryPoint.InputKinds.Contains(typeof(CommonInputs.ITrainerInputWithLabel))) + { + var labelColField = _inputBuilder.GetFieldNameOrNull("LabelColumn"); + ch.AssertNonEmpty(labelColField); + var labelColFieldType = _inputBuilder.GetFieldTypeOrNull(labelColField); + ch.Assert(labelColFieldType == typeof(string)); + var inputLabel = inputInstance.GetType().GetField(labelColField).GetValue(inputInstance); + if (label != (string)inputLabel) + ch.Warning(warning, "label", label, inputLabel); + else + _inputBuilder.TrySetValue(labelColField, label); + } + if (!string.IsNullOrEmpty(group) && Utils.Size(_entryPoint.InputKinds) > 0 && + _entryPoint.InputKinds.Contains(typeof(CommonInputs.ITrainerInputWithGroupId))) + { + var groupColField = _inputBuilder.GetFieldNameOrNull("GroupIdColumn"); + ch.AssertNonEmpty(groupColField); + var groupColFieldType = _inputBuilder.GetFieldTypeOrNull(groupColField); + ch.Assert(groupColFieldType == typeof(string)); + var inputGroup = inputInstance.GetType().GetField(groupColField).GetValue(inputInstance); + if (group != (Optional)inputGroup) + ch.Warning(warning, "group Id", label, inputGroup); + else + _inputBuilder.TrySetValue(groupColField, label); + } + if (!string.IsNullOrEmpty(weight) && Utils.Size(_entryPoint.InputKinds) > 0 && + (_entryPoint.InputKinds.Contains(typeof(CommonInputs.ITrainerInputWithWeight)) || + _entryPoint.InputKinds.Contains(typeof(CommonInputs.IUnsupervisedTrainerWithWeight)))) + { + var weightColField = _inputBuilder.GetFieldNameOrNull("WeightColumn"); + ch.AssertNonEmpty(weightColField); + var weightColFieldType = _inputBuilder.GetFieldTypeOrNull(weightColField); + ch.Assert(weightColFieldType == typeof(string)); + var inputWeight = inputInstance.GetType().GetField(weightColField).GetValue(inputInstance); + if (weight != (Optional)inputWeight) + ch.Warning(warning, "weight", label, inputWeight); + else + _inputBuilder.TrySetValue(weightColField, label); + } + // Validate outputs. 
_outputHelper = new OutputHelper(_host, _entryPoint.OutputType); _outputMap = new Dictionary(); @@ -550,10 +596,15 @@ public static EntryPointNode Create( var inputBuilder = new InputBuilder(env, info.InputType, catalog); var outputHelper = new OutputHelper(env, info.OutputType); - var entryPointNode = new EntryPointNode(env, catalog, context, context.GenerateId(entryPointName), entryPointName, - inputBuilder.GetJsonObject(arguments, inputBindingMap, inputMap), - outputHelper.GetJsonObject(outputMap), checkpoint, stageId, cost); - return entryPointNode; + using (var ch = env.Start("Create EntryPointNode")) + { + var entryPointNode = new EntryPointNode(env, ch, catalog, context, context.GenerateId(entryPointName), entryPointName, + inputBuilder.GetJsonObject(arguments, inputBindingMap, inputMap), + outputHelper.GetJsonObject(outputMap), checkpoint, stageId, cost); + + ch.Done(); + return entryPointNode; + } } public static EntryPointNode Create( @@ -850,7 +901,8 @@ private object BuildParameterValue(List bindings) throw _host.ExceptNotImpl("Unsupported ParameterBinding"); } - public static List ValidateNodes(IHostEnvironment env, RunContext context, JArray nodes, ModuleCatalog moduleCatalog) + public static List ValidateNodes(IHostEnvironment env, RunContext context, JArray nodes, + ModuleCatalog moduleCatalog, string label = null, string group = null, string weight = null) { Contracts.AssertValue(env); env.AssertValue(context); @@ -890,8 +942,10 @@ public static List ValidateNodes(IHostEnvironment env, RunContex ch.Warning("Node '{0}' has unexpected fields that are ignored: {1}", id, string.Join(", ", unexpectedFields.Select(x => x.Name))); } - result.Add(new EntryPointNode(env, moduleCatalog, context, id, name, inputs, outputs, checkpoint, stageId, cost)); + result.Add(new EntryPointNode(env, ch, moduleCatalog, context, id, name, inputs, outputs, checkpoint, stageId, cost, label, group, weight)); } + + ch.Done(); } return result; } diff --git a/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs b/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs index 939d49893b..e5afd8dbb5 100644 --- a/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs +++ b/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs @@ -4,12 +4,11 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Reflection; using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Internal.Utilities; using Newtonsoft.Json.Linq; namespace Microsoft.ML.Runtime.EntryPoints.JsonUtils @@ -405,7 +404,11 @@ private static object ParseJsonValue(IExceptionContext ectx, Type type, Attribut return null; if (type.IsGenericType && (type.GetGenericTypeDefinition() == typeof(Optional<>) || type.GetGenericTypeDefinition() == typeof(Nullable<>))) + { + if (type.GetGenericTypeDefinition() == typeof(Optional<>) && value.HasValues) + value = value.Values().FirstOrDefault(); type = type.GetGenericArguments()[0]; + } if (type.IsGenericType && (type.GetGenericTypeDefinition() == typeof(Var<>))) { diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index 058e8bafe3..e35fb2ff81 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -2165,6 +2165,16 @@ public sealed partial class CrossValidationResultsCombiner /// public string LabelColumn { get; set; } = "Label"; + /// + /// Column to use for example weight + /// + public 
Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + + /// + /// Column to use for grouping + /// + public Microsoft.ML.Runtime.EntryPoints.Optional GroupColumn { get; set; } + /// /// Specifies the trainer kind, which determines the evaluator to be used. /// @@ -2270,6 +2280,21 @@ public sealed partial class CrossValidator /// public Models.MacroUtilsTrainerKinds Kind { get; set; } = Models.MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer; + /// + /// Column to use for labels + /// + public string LabelColumn { get; set; } = "Label"; + + /// + /// Column to use for example weight + /// + public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + + /// + /// Column to use for grouping + /// + public Microsoft.ML.Runtime.EntryPoints.Optional GroupColumn { get; set; } + public sealed class Output { @@ -3456,6 +3481,21 @@ public sealed partial class TrainTestEvaluator /// public bool IncludeTrainingMetrics { get; set; } = false; + /// + /// Column to use for labels + /// + public string LabelColumn { get; set; } = "Label"; + + /// + /// Column to use for example weight + /// + public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + + /// + /// Column to use for grouping + /// + public Microsoft.ML.Runtime.EntryPoints.Optional GroupColumn { get; set; } + public sealed class Output { @@ -6173,7 +6213,7 @@ public enum KMeansPlusPlusTrainerInitAlgorithm /// /// K-means is a popular clustering algorithm. With K-means, the data is clustered into a specified number of clusters in order to minimize the within-cluster sum of squares. K-means++ improves upon K-means by using a better method for choosing the initial cluster centers. /// - public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem + public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IUnsupervisedTrainerWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem { @@ -6208,6 +6248,11 @@ public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.Entry /// public int? NumThreads { get; set; } + /// + /// Column to use for example weight + /// + public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + /// /// The data to be used for training /// @@ -7024,7 +7069,7 @@ namespace Trainers /// /// Train an PCA Anomaly model. /// - public sealed partial class PcaAnomalyDetector : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem + public sealed partial class PcaAnomalyDetector : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IUnsupervisedTrainerWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem { diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs index 325b8a093a..e711b87117 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs @@ -77,6 +77,15 @@ public sealed class Arguments // (and the same for the TrainTest macro). I currently do not know how to do this, so this should be revisited in the future. 
[Argument(ArgumentType.Required, HelpText = "Specifies the trainer kind, which determines the evaluator to be used.", SortOrder = 8)] public MacroUtils.TrainerKinds Kind = MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer; + + [Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 10)] + public string LabelColumn = DefaultColumnNames.Label; + + [Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 11)] + public Optional WeightColumn = Optional.Implicit(DefaultColumnNames.Weight); + + [Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 12)] + public Optional GroupColumn = Optional.Implicit(DefaultColumnNames.GroupId); } // REVIEW: This output would be much better as an array of CommonOutputs.ClassificationEvaluateOutput, @@ -121,6 +130,12 @@ public sealed class CombineMetricsInput [Argument(ArgumentType.AtMostOnce, HelpText = "The label column name", ShortName = "Label", SortOrder = 5)] public string LabelColumn = DefaultColumnNames.Label; + [Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 6)] + public Optional WeightColumn = Optional.Implicit(DefaultColumnNames.Weight); + + [Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 12)] + public Optional GroupColumn = Optional.Implicit(DefaultColumnNames.GroupId); + [Argument(ArgumentType.Required, HelpText = "Specifies the trainer kind, which determines the evaluator to be used.", SortOrder = 6)] public MacroUtils.TrainerKinds Kind = MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer; } @@ -188,7 +203,10 @@ public static CommonOutputs.MacroOutput CrossValidate( var args = new TrainTestMacro.Arguments { Nodes = new JArray(graph.Select(n => n.ToJson()).ToArray()), - TransformModel = null + TransformModel = null, + LabelColumn = input.LabelColumn, + GroupColumn = input.GroupColumn, + WeightColumn = input.WeightColumn }; if (transformModelVarName != null) @@ -356,6 +374,9 @@ public static CommonOutputs.MacroOutput CrossValidate( var combineArgs = new CombineMetricsInput(); combineArgs.Kind = input.Kind; + combineArgs.LabelColumn = input.LabelColumn; + combineArgs.WeightColumn = input.WeightColumn; + combineArgs.GroupColumn = input.GroupColumn; // Set the input bindings for the CombineMetrics entry point. 
var combineInputBindingMap = new Dictionary>(); @@ -383,10 +404,12 @@ public static CommonOutputs.MacroOutput CrossValidate( var combineInstanceMetric = new Var(); combineInstanceMetric.VarName = node.GetOutputVariableName(nameof(Output.PerInstanceMetrics)); combineOutputMap.Add(nameof(Output.PerInstanceMetrics), combineInstanceMetric.VarName); - var combineConfusionMatrix = new Var(); - combineConfusionMatrix.VarName = node.GetOutputVariableName(nameof(Output.ConfusionMatrix)); - combineOutputMap.Add(nameof(TrainTestMacro.Output.ConfusionMatrix), combineConfusionMatrix.VarName); - + if (confusionMatricesOutput != null) + { + var combineConfusionMatrix = new Var(); + combineConfusionMatrix.VarName = node.GetOutputVariableName(nameof(Output.ConfusionMatrix)); + combineOutputMap.Add(nameof(TrainTestMacro.Output.ConfusionMatrix), combineConfusionMatrix.VarName); + } subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes(), node.Catalog)); subGraphNodes.Add(EntryPointNode.Create(env, "Models.CrossValidationResultsCombiner", combineArgs, node.Catalog, node.Context, combineInputBindingMap, combineInputMap, combineOutputMap)); return new CommonOutputs.MacroOutput() { Nodes = subGraphNodes }; @@ -398,7 +421,12 @@ public static CombinedOutput CombineMetrics(IHostEnvironment env, CombineMetrics var eval = GetEvaluator(env, input.Kind); var perInst = EvaluateUtils.ConcatenatePerInstanceDataViews(env, eval, true, true, input.PerInstanceMetrics.Select( - idv => RoleMappedData.Create(idv, RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, input.LabelColumn))).ToArray(), + idv => RoleMappedData.CreateOpt(idv, new[] + { + RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, input.LabelColumn), + RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Weight, input.WeightColumn.Value), + RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Group, input.GroupColumn.Value) + })).ToArray(), out var variableSizeVectorColumnNames); var warnings = input.Warnings != null ? 
new List(input.Warnings) : new List(); diff --git a/src/Microsoft.ML/Runtime/EntryPoints/MacroUtils.cs b/src/Microsoft.ML/Runtime/EntryPoints/MacroUtils.cs index aa5e8abc16..7e7f6acbac 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/MacroUtils.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/MacroUtils.cs @@ -30,9 +30,8 @@ public sealed class EvaluatorSettings { public string LabelColumn { get; set; } public string NameColumn { get; set; } - public string ScoreColumn { get; set; } - public string[] StratColumn { get; set; } public string WeightColumn { get; set; } + public string GroupColumn { get; set; } public string FeatureColumn { get; set; } public EvaluatorSettings() @@ -61,8 +60,6 @@ private static Dictionary { LabelColumn = settings.LabelColumn, NameColumn = settings.NameColumn, - ScoreColumn = settings.ScoreColumn, - StratColumn = settings.StratColumn, WeightColumn = settings.WeightColumn }, EvaluatorOutput = () => new Models.BinaryClassificationEvaluator.Output() @@ -77,8 +74,6 @@ private static Dictionary { LabelColumn = settings.LabelColumn, NameColumn = settings.NameColumn, - ScoreColumn = settings.ScoreColumn, - StratColumn = settings.StratColumn, WeightColumn = settings.WeightColumn }, EvaluatorOutput = () => new Models.ClassificationEvaluator.Output() @@ -93,9 +88,8 @@ private static Dictionary { LabelColumn = settings.LabelColumn, NameColumn = settings.NameColumn, - ScoreColumn = settings.ScoreColumn, - StratColumn = settings.StratColumn, - WeightColumn = settings.WeightColumn + WeightColumn = settings.WeightColumn, + GroupIdColumn = settings.GroupColumn }, EvaluatorOutput = () => new Models.RankerEvaluator.Output() } @@ -109,8 +103,6 @@ private static Dictionary { LabelColumn = settings.LabelColumn, NameColumn = settings.NameColumn, - ScoreColumn = settings.ScoreColumn, - StratColumn = settings.StratColumn, WeightColumn = settings.WeightColumn }, EvaluatorOutput = () => new Models.RegressionEvaluator.Output() @@ -125,8 +117,6 @@ private static Dictionary { LabelColumn = settings.LabelColumn, NameColumn = settings.NameColumn, - ScoreColumn = settings.ScoreColumn, - StratColumn = settings.StratColumn, WeightColumn = settings.WeightColumn, }, EvaluatorOutput = () => new Models.MultiOutputRegressionEvaluator.Output() @@ -141,8 +131,6 @@ private static Dictionary { LabelColumn = settings.LabelColumn, NameColumn = settings.NameColumn, - ScoreColumn = settings.ScoreColumn, - StratColumn = settings.StratColumn, WeightColumn = settings.WeightColumn }, EvaluatorOutput = () => new Models.AnomalyDetectionEvaluator.Output() @@ -157,8 +145,6 @@ private static Dictionary { LabelColumn = settings.LabelColumn, NameColumn = settings.NameColumn, - ScoreColumn = settings.ScoreColumn, - StratColumn = settings.StratColumn, WeightColumn = settings.WeightColumn, FeatureColumn = settings.FeatureColumn }, diff --git a/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs index c3c06b1031..3e2ba94615 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs @@ -62,6 +62,15 @@ public sealed class Arguments [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates whether to include and output training dataset metrics.", SortOrder = 9)] public Boolean IncludeTrainingMetrics = false; + + [Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 10)] + public string LabelColumn = DefaultColumnNames.Label; + + 
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 11)] + public Optional WeightColumn = Optional.Implicit(DefaultColumnNames.Weight); + + [Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for grouping", ShortName = "group", SortOrder = 12)] + public Optional GroupColumn = Optional.Implicit(DefaultColumnNames.GroupId); } public sealed class Output @@ -110,7 +119,8 @@ public static CommonOutputs.MacroOutput TrainTest( // Parse the subgraph. var subGraphRunContext = new RunContext(env); - var subGraphNodes = EntryPointNode.ValidateNodes(env, subGraphRunContext, input.Nodes, node.Catalog); + var subGraphNodes = EntryPointNode.ValidateNodes(env, subGraphRunContext, input.Nodes, node.Catalog, input.LabelColumn, + input.GroupColumn.IsExplicit ? input.GroupColumn.Value : null, input.WeightColumn.IsExplicit ? input.WeightColumn.Value : null); // Change the subgraph to use the training data as input. var varName = input.Inputs.Data.VarName; @@ -206,11 +216,13 @@ public static CommonOutputs.MacroOutput TrainTest( // Do not double-add previous nodes. exp.Reset(); - // REVIEW: we need to extract the proper label column name here to pass to the evaluators. - // This is where you would add code to do it. + + // REVIEW: add similar support for NameColumn and FeatureColumn. var settings = new MacroUtils.EvaluatorSettings { - LabelColumn = DefaultColumnNames.Label + LabelColumn = input.LabelColumn, + WeightColumn = input.WeightColumn.IsExplicit ? input.WeightColumn.Value : null, + GroupColumn = input.GroupColumn.IsExplicit ? input.GroupColumn.Value : null }; string outVariableName; diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 7e06c77821..e237902f6f 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -1333,6 +1333,18 @@ "IsNullable": false, "Default": "Label" }, + { + "Name": "WeightColumn", + "Type": "String", + "Desc": "Column to use for example weight", + "Aliases": [ + "weight" + ], + "Required": false, + "SortOrder": 6.0, + "IsNullable": false, + "Default": "Weight" + }, { "Name": "Kind", "Type": { @@ -1352,6 +1364,18 @@ "SortOrder": 6.0, "IsNullable": false, "Default": "SignatureBinaryClassifierTrainer" + }, + { + "Name": "GroupColumn", + "Type": "String", + "Desc": "Column to use for grouping", + "Aliases": [ + "group" + ], + "Required": false, + "SortOrder": 12.0, + "IsNullable": false, + "Default": "GroupId" } ], "Outputs": [ @@ -1504,6 +1528,42 @@ "SortOrder": 8.0, "IsNullable": false, "Default": "SignatureBinaryClassifierTrainer" + }, + { + "Name": "LabelColumn", + "Type": "String", + "Desc": "Column to use for labels", + "Aliases": [ + "lab" + ], + "Required": false, + "SortOrder": 10.0, + "IsNullable": false, + "Default": "Label" + }, + { + "Name": "WeightColumn", + "Type": "String", + "Desc": "Column to use for example weight", + "Aliases": [ + "weight" + ], + "Required": false, + "SortOrder": 11.0, + "IsNullable": false, + "Default": "Weight" + }, + { + "Name": "GroupColumn", + "Type": "String", + "Desc": "Column to use for grouping", + "Aliases": [ + "group" + ], + "Required": false, + "SortOrder": 12.0, + "IsNullable": false, + "Default": "GroupId" } ], "Outputs": [ @@ -3107,6 +3167,42 @@ "SortOrder": 9.0, "IsNullable": false, "Default": false + }, + { + "Name": "LabelColumn", + "Type": "String", + "Desc": "Column to use for 
labels", + "Aliases": [ + "lab" + ], + "Required": false, + "SortOrder": 10.0, + "IsNullable": false, + "Default": "Label" + }, + { + "Name": "WeightColumn", + "Type": "String", + "Desc": "Column to use for example weight", + "Aliases": [ + "weight" + ], + "Required": false, + "SortOrder": 11.0, + "IsNullable": false, + "Default": "Weight" + }, + { + "Name": "GroupColumn", + "Type": "String", + "Desc": "Column to use for grouping", + "Aliases": [ + "group" + ], + "Required": false, + "SortOrder": 12.0, + "IsNullable": false, + "Default": "GroupId" } ], "Outputs": [ diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index 4c1969c2d9..bbcee502d7 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -270,16 +270,22 @@ public void TestCrossValidationMacro() var nop = new ML.Transforms.NoOperation(); var nopOutput = subGraph.Add(nop); - var learnerInput = new ML.Trainers.StochasticDualCoordinateAscentRegressor + var generate = new ML.Transforms.RandomNumberGenerator(); + generate.Column = new[] { new ML.Transforms.GenerateNumberTransformColumn() { Name = "Weight1" } }; + generate.Data = nopOutput.OutputData; + var generateOutput = subGraph.Add(generate); + + var learnerInput = new ML.Trainers.PoissonRegressor { - TrainingData = nopOutput.OutputData, - NumThreads = 1 + TrainingData = generateOutput.OutputData, + NumThreads = 1, + WeightColumn = "Weight1" }; var learnerOutput = subGraph.Add(learnerInput); var modelCombine = new ML.Transforms.ManyHeterogeneousModelCombiner { - TransformModels = new ArrayVar(nopOutput.Model), + TransformModels = new ArrayVar(nopOutput.Model, generateOutput.Model), PredictorModel = learnerOutput.PredictorModel }; var modelCombineOutput = subGraph.Add(modelCombine); @@ -316,7 +322,8 @@ public void TestCrossValidationMacro() Data = importOutput.Data, Nodes = subGraph, Kind = ML.Models.MacroUtilsTrainerKinds.SignatureRegressorTrainer, - TransformModel = null + TransformModel = null, + WeightColumn = "Weight1" }; crossValidate.Inputs.Data = nop.Data; crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel; @@ -332,40 +339,66 @@ public void TestCrossValidationMacro() Assert.True(b); b = schema.TryGetColumnIndex("Fold Index", out int foldCol); Assert.True(b); - using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol)) + b = schema.TryGetColumnIndex("IsWeighted", out int isWeightedCol); + using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol || col == isWeightedCol)) { var getter = cursor.GetGetter(metricCol); var foldGetter = cursor.GetGetter(foldCol); + var isWeightedGetter = cursor.GetGetter(isWeightedCol); DvText fold = default; + DvBool isWeighted = default; - // Get the verage. - b = cursor.MoveNext(); - Assert.True(b); double avg = 0; - getter(ref avg); - foldGetter(ref fold); - Assert.True(fold.EqualsStr("Average")); - - // Get the standard deviation. - b = cursor.MoveNext(); - Assert.True(b); - double stdev = 0; - getter(ref stdev); - foldGetter(ref fold); - Assert.True(fold.EqualsStr("Standard Deviation")); - Assert.Equal(0.0013, stdev, 4); - - double sum = 0; - double val = 0; - for (int f = 0; f < 2; f++) + double weightedAvg = 0; + for (int w = 0; w < 2; w++) { + // Get the average. 
b = cursor.MoveNext(); Assert.True(b); - getter(ref val); + if (w == 1) + getter(ref weightedAvg); + else + getter(ref avg); foldGetter(ref fold); - sum += val; - Assert.True(fold.EqualsStr("Fold " + f)); + Assert.True(fold.EqualsStr("Average")); + isWeightedGetter(ref isWeighted); + Assert.True(isWeighted.IsTrue == (w == 1)); + + // Get the standard deviation. + b = cursor.MoveNext(); + Assert.True(b); + double stdev = 0; + getter(ref stdev); + foldGetter(ref fold); + Assert.True(fold.EqualsStr("Standard Deviation")); + if (w == 1) + Assert.Equal(0.002827, stdev, 6); + else + Assert.Equal(0.002376, stdev, 6); + isWeightedGetter(ref isWeighted); + Assert.True(isWeighted.IsTrue == (w == 1)); + } + double sum = 0; + double weightedSum = 0; + for (int f = 0; f < 2; f++) + { + for (int w = 0; w < 2; w++) + { + b = cursor.MoveNext(); + Assert.True(b); + double val = 0; + getter(ref val); + foldGetter(ref fold); + if (w == 1) + weightedSum += val; + else + sum += val; + Assert.True(fold.EqualsStr("Fold " + f)); + isWeightedGetter(ref isWeighted); + Assert.True(isWeighted.IsTrue == (w == 1)); + } } + Assert.Equal(weightedAvg, weightedSum / 2); Assert.Equal(avg, sum / 2); b = cursor.MoveNext(); Assert.False(b); @@ -593,5 +626,118 @@ public void TestCrossValidationMacroWithStratification() } } } + + [Fact] + public void TestCrossValidationMacroWithNonDefaultNames() + { + string dataPath = GetDataPath(@"adult.tiny.with-schema.txt"); + using (var env = new TlcEnvironment(42)) + { + var subGraph = env.CreateExperiment(); + + var textToKey = new ML.Transforms.TextToKeyConverter(); + textToKey.Column = new[] { new ML.Transforms.TermTransformColumn() { Name = "Label1", Source = "Label" } }; + var textToKeyOutput = subGraph.Add(textToKey); + + var hash = new ML.Transforms.HashConverter(); + hash.Column = new[] { new ML.Transforms.HashJoinTransformColumn() { Name = "GroupId1", Source = "Workclass" } }; + hash.Data = textToKeyOutput.OutputData; + var hashOutput = subGraph.Add(hash); + + var learnerInput = new Trainers.FastTreeRanker + { + TrainingData = hashOutput.OutputData, + NumThreads = 1, + LabelColumn = "Label1", + GroupIdColumn = "GroupId1" + }; + var learnerOutput = subGraph.Add(learnerInput); + + var modelCombine = new ML.Transforms.ManyHeterogeneousModelCombiner + { + TransformModels = new ArrayVar(textToKeyOutput.Model, hashOutput.Model), + PredictorModel = learnerOutput.PredictorModel + }; + var modelCombineOutput = subGraph.Add(modelCombine); + + var experiment = env.CreateExperiment(); + var importInput = new ML.Data.TextLoader(dataPath); + importInput.Arguments.HasHeader = true; + importInput.Arguments.Column = new TextLoaderColumn[] + { + new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } }, + new TextLoaderColumn { Name = "Workclass", Source = new[] { new TextLoaderRange(1) }, Type = DataKind.Text }, + new TextLoaderColumn { Name = "Features", Source = new[] { new TextLoaderRange(9, 14) } } + }; + var importOutput = experiment.Add(importInput); + + var crossValidate = new Models.CrossValidator + { + Data = importOutput.Data, + Nodes = subGraph, + TransformModel = null, + LabelColumn = "Label1", + GroupColumn = "GroupId1", + Kind = Models.MacroUtilsTrainerKinds.SignatureRankerTrainer + }; + crossValidate.Inputs.Data = textToKey.Data; + crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel; + var crossValidateOutput = experiment.Add(crossValidate); + experiment.Compile(); + experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, 
dataPath, false, false)); + experiment.Run(); + var data = experiment.GetOutput(crossValidateOutput.OverallMetrics); + + var schema = data.Schema; + var b = schema.TryGetColumnIndex("NDCG", out int metricCol); + Assert.True(b); + b = schema.TryGetColumnIndex("Fold Index", out int foldCol); + Assert.True(b); + using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol)) + { + var getter = cursor.GetGetter>(metricCol); + var foldGetter = cursor.GetGetter(foldCol); + DvText fold = default; + + // Get the verage. + b = cursor.MoveNext(); + Assert.True(b); + var avg = default(VBuffer); + getter(ref avg); + foldGetter(ref fold); + Assert.True(fold.EqualsStr("Average")); + + // Get the standard deviation. + b = cursor.MoveNext(); + Assert.True(b); + var stdev = default(VBuffer); + getter(ref stdev); + foldGetter(ref fold); + Assert.True(fold.EqualsStr("Standard Deviation")); + Assert.Equal(5.247, stdev.Values[0], 3); + Assert.Equal(4.703, stdev.Values[1], 3); + Assert.Equal(3.844, stdev.Values[2], 3); + + var sumBldr = new BufferBuilder(R8Adder.Instance); + sumBldr.Reset(avg.Length, true); + var val = default(VBuffer); + for (int f = 0; f < 2; f++) + { + b = cursor.MoveNext(); + Assert.True(b); + getter(ref val); + foldGetter(ref fold); + sumBldr.AddFeatures(0, ref val); + Assert.True(fold.EqualsStr("Fold " + f)); + } + var sum = default(VBuffer); + sumBldr.GetResult(ref sum); + for (int i = 0; i < avg.Length; i++) + Assert.Equal(avg.Values[i], sum.Values[i] / 2); + b = cursor.MoveNext(); + Assert.False(b); + } + } + } } }
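
Usage sketch (illustrative only, not part of the patch): the LabelColumn, WeightColumn and GroupColumn arguments this change adds to Models.CrossValidator are set from the C# API the same way the updated tests do. The snippet below is a minimal, hypothetical example; it assumes a trainer sub-graph (subGraph) and a text-loader output (importOutput) built as in TestCrossValidationMacro above, and it only shows how the new column names are wired through the macro.

    // Hypothetical sketch: route non-default column names through the cross-validation macro.
    // WeightColumn and GroupColumn are Optional<string>; leaving them unset falls back to the
    // implicit defaults ("Weight" and "GroupId"), while assigning a name, as the tests do,
    // selects that column explicitly.
    var crossValidate = new Models.CrossValidator
    {
        Data = importOutput.Data,
        Nodes = subGraph,
        Kind = Models.MacroUtilsTrainerKinds.SignatureRegressorTrainer,
        TransformModel = null,
        LabelColumn = "Label",
        WeightColumn = "Weight1",
        GroupColumn = "GroupId1"
    };

The macro forwards these names to the trainer nodes in the sub-graph (warning when a trainer explicitly names a different column) and to the evaluator settings, which is what the EntryPointNode and MacroUtils changes above implement.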