diff --git a/src/Microsoft.ML.Data/Commands/SaveDataCommand.cs b/src/Microsoft.ML.Data/Commands/SaveDataCommand.cs index 75ddfc4ced..f4e535816e 100644 --- a/src/Microsoft.ML.Data/Commands/SaveDataCommand.cs +++ b/src/Microsoft.ML.Data/Commands/SaveDataCommand.cs @@ -130,11 +130,10 @@ private void RunCore(IChannel ch) if (!string.IsNullOrWhiteSpace(Args.Columns)) { - var args = new ChooseColumnsTransform.Arguments(); - args.Column = Args.Columns - .Split(new char[] { ',' }, StringSplitOptions.RemoveEmptyEntries).Select(s => new ChooseColumnsTransform.Column() { Name = s }).ToArray(); - if (Utils.Size(args.Column) > 0) - data = new ChooseColumnsTransform(Host, args, data); + var keepColumns = Args.Columns + .Split(new char[] { ',' }, StringSplitOptions.RemoveEmptyEntries).ToArray(); + if (Utils.Size(keepColumns) > 0) + data = SelectColumnsTransform.CreateKeep(Host, data, keepColumns); } IDataSaver saver; diff --git a/src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs b/src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs index 0d38960a04..96ce1d8c61 100644 --- a/src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs +++ b/src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs @@ -101,7 +101,7 @@ public static CommonOutputs.TransformOutput RenameBinaryPredictionScoreColumns(I } var copyColumn = new CopyColumnsTransform(env, copyCols.ToArray()).Transform(input.Data); - var dropColumn = new DropColumnsTransform(env, new DropColumnsTransform.Arguments() { Column = copyCols.Select(c => c.Source).ToArray() }, copyColumn); + var dropColumn = SelectColumnsTransform.CreateDrop(env, copyColumn, copyCols.Select(c => c.Source).ToArray()); return new CommonOutputs.TransformOutput { Model = new TransformModel(env, dropColumn, input.Data), OutputData = dropColumn }; } } diff --git a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs index 35735cb654..e7448e2b48 100644 --- a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs @@ -703,59 +703,42 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary() - { - new ChooseColumnsTransform.Column() - { - Name = string.Format(FoldDrAtKFormat, _k), - Source = AnomalyDetectionEvaluator.OverallMetrics.DrAtK - }, - new ChooseColumnsTransform.Column() - { - Name = string.Format(FoldDrAtPFormat, _p), - Source = AnomalyDetectionEvaluator.OverallMetrics.DrAtPFpr - }, - new ChooseColumnsTransform.Column() - { - Name = string.Format(FoldDrAtNumAnomaliesFormat, numAnomalies), - Source=AnomalyDetectionEvaluator.OverallMetrics.DrAtNumPos - }, - new ChooseColumnsTransform.Column() - { - Name=AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK - }, - new ChooseColumnsTransform.Column() - { - Name=AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP - }, - new ChooseColumnsTransform.Column() - { - Name=AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos - }, - new ChooseColumnsTransform.Column() - { - Name = BinaryClassifierEvaluator.Auc - } - }; + var kFormatName = string.Format(FoldDrAtKFormat, _k); + var pFormatName = string.Format(FoldDrAtPFormat, _p); + var numAnomName = string.Format(FoldDrAtNumAnomaliesFormat, numAnomalies); + + (string Source, string Name)[] cols = + { + (AnomalyDetectionEvaluator.OverallMetrics.DrAtK, kFormatName), + (AnomalyDetectionEvaluator.OverallMetrics.DrAtPFpr, pFormatName), + (AnomalyDetectionEvaluator.OverallMetrics.DrAtNumPos, numAnomName) + }; + + // List of columns to keep, note that the order specified determines the order of the output + var colsToKeep = new List(); + colsToKeep.Add(kFormatName); + colsToKeep.Add(pFormatName); + colsToKeep.Add(numAnomName); + colsToKeep.Add(AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK); + colsToKeep.Add(AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP); + colsToKeep.Add(AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos); + colsToKeep.Add(BinaryClassifierEvaluator.Auc); + + overall = new CopyColumnsTransform(Host, cols).Transform(overall); + IDataView fold = SelectColumnsTransform.CreateKeep(Host, overall, colsToKeep.ToArray()); - args.Column = cols.ToArray(); - IDataView fold = new ChooseColumnsTransform(Host, args, overall); string weightedFold; ch.Info(MetricWriter.GetPerFoldResults(Host, fold, out weightedFold)); } protected override IDataView GetOverallResultsCore(IDataView overall) { - var args = new DropColumnsTransform.Arguments(); - args.Column = new[] - { - AnomalyDetectionEvaluator.OverallMetrics.NumAnomalies, - AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK, - AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP, - AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos - }; - return new DropColumnsTransform(Host, args, overall); + return SelectColumnsTransform.CreateDrop(Host, + overall, + AnomalyDetectionEvaluator.OverallMetrics.NumAnomalies, + AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK, + AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP, + AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos); } protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) diff --git a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs index 67cb81eb98..bee170ecef 100644 --- a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs @@ -1333,43 +1333,33 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary() - { - new ChooseColumnsTransform.Column() - { - Name = FoldAccuracy, - Source = BinaryClassifierEvaluator.Accuracy - }, - new ChooseColumnsTransform.Column() - { - Name = FoldLogLoss, - Source = BinaryClassifierEvaluator.LogLoss - }, - new ChooseColumnsTransform.Column() - { - Name = BinaryClassifierEvaluator.Entropy - }, - new ChooseColumnsTransform.Column() - { - Name = FoldLogLosRed, - Source = BinaryClassifierEvaluator.LogLossReduction - }, - new ChooseColumnsTransform.Column() - { - Name = BinaryClassifierEvaluator.Auc - } - }; + (string Source, string Name)[] cols = + { + (BinaryClassifierEvaluator.Accuracy, FoldAccuracy), + (BinaryClassifierEvaluator.LogLoss, FoldLogLoss), + (BinaryClassifierEvaluator.LogLossReduction, FoldLogLosRed) + }; + + var colsToKeep = new List(); + colsToKeep.Add(FoldAccuracy); + colsToKeep.Add(FoldLogLoss); + colsToKeep.Add(BinaryClassifierEvaluator.Entropy); + colsToKeep.Add(FoldLogLosRed); + colsToKeep.Add(BinaryClassifierEvaluator.Auc); + int index; if (fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.IsWeighted, out index)) - cols.Add(new ChooseColumnsTransform.Column() { Name = MetricKinds.ColumnNames.IsWeighted }); + colsToKeep.Add(MetricKinds.ColumnNames.IsWeighted); if (fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out index)) - cols.Add(new ChooseColumnsTransform.Column() { Name = MetricKinds.ColumnNames.StratCol }); + colsToKeep.Add(MetricKinds.ColumnNames.StratCol); if (fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out index)) - cols.Add(new ChooseColumnsTransform.Column() { Name = MetricKinds.ColumnNames.StratVal }); + colsToKeep.Add(MetricKinds.ColumnNames.StratVal); + + fold = new CopyColumnsTransform(Host, cols).Transform(fold); + + // Select the columns that are specified in the Copy + fold = SelectColumnsTransform.CreateKeep(Host, fold, colsToKeep.ToArray()); - args.Column = cols.ToArray(); - fold = new ChooseColumnsTransform(Host, args, fold); string weightedConf; var unweightedConf = MetricWriter.GetConfusionTable(Host, conf, out weightedConf); string weightedFold; @@ -1386,9 +1376,7 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary[] metrics) diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs index ec98a93878..2f4607d849 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs @@ -933,7 +933,7 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, string variableSizeVectorColumnName, type); // Drop the old column that does not have variable length. - idv = new DropColumnsTransform(env, new DropColumnsTransform.Arguments() { Column = new[] { variableSizeVectorColumnName } }, idv); + idv = SelectColumnsTransform.CreateDrop(env, idv, variableSizeVectorColumnName); } return idv; }; @@ -1059,8 +1059,7 @@ internal static IDataView GetOverallMetricsData(IHostEnvironment env, IDataView { if (Utils.Size(nonAveragedCols) > 0) { - var dropArgs = new DropColumnsTransform.Arguments() { Column = nonAveragedCols.ToArray() }; - data = new DropColumnsTransform(env, dropArgs, data); + data = SelectColumnsTransform.CreateDrop(env, data, nonAveragedCols.ToArray()); } idvList.Add(data); } @@ -1734,9 +1733,7 @@ public static IDataView GetNonStratifiedMetrics(IHostEnvironment env, IDataView var found = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out stratVal); env.Check(found, "If stratification column exist, data view must also contain a StratVal column"); - var dropArgs = new DropColumnsTransform.Arguments(); - dropArgs.Column = new[] { data.Schema.GetColumnName(stratCol), data.Schema.GetColumnName(stratVal) }; - data = new DropColumnsTransform(env, dropArgs, data); + data = SelectColumnsTransform.CreateDrop(env, data, data.Schema.GetColumnName(stratCol), data.Schema.GetColumnName(stratVal)); return data; } } diff --git a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs index 06a8388175..5e3c77a9e8 100644 --- a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs @@ -213,13 +213,14 @@ private IDataView WrapPerInstance(RoleMappedData perInst) var idv = perInst.Data; // Make a list of column names that Maml outputs as part of the per-instance data view, and then wrap - // the per-instance data computed by the evaluator in a ChooseColumnsTransform. - var cols = new List(); + // the per-instance data computed by the evaluator in a SelectColumnsTransform. + var cols = new List<(string Source, string Name)>(); + var colsToKeep = new List(); // If perInst is the result of cross-validation and contains a fold Id column, include it. int foldCol; if (perInst.Schema.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.FoldIndex, out foldCol)) - cols.Add(new ChooseColumnsTransform.Column() { Source = MetricKinds.ColumnNames.FoldIndex }); + colsToKeep.Add(MetricKinds.ColumnNames.FoldIndex); // Maml always outputs a name column, if it doesn't exist add a GenerateNumberTransform. if (perInst.Schema.Name == null) @@ -228,22 +229,24 @@ private IDataView WrapPerInstance(RoleMappedData perInst) args.Column = new[] { new GenerateNumberTransform.Column() { Name = "Instance" } }; args.UseCounter = true; idv = new GenerateNumberTransform(Host, args, idv); - cols.Add(new ChooseColumnsTransform.Column() { Name = "Instance" }); + colsToKeep.Add("Instance"); } else - cols.Add(new ChooseColumnsTransform.Column() { Source = perInst.Schema.Name.Name, Name = "Instance" }); + { + cols.Add((perInst.Schema.Name.Name, "Instance")); + colsToKeep.Add("Instance"); + } // Maml outputs the weight column if it exists. if (perInst.Schema.Weight != null) - cols.Add(new ChooseColumnsTransform.Column() { Name = perInst.Schema.Weight.Name }); + colsToKeep.Add(perInst.Schema.Weight.Name); // Get the other columns from the evaluator. foreach (var col in GetPerInstanceColumnsToSave(perInst.Schema)) - cols.Add(new ChooseColumnsTransform.Column() { Name = col }); + colsToKeep.Add(col); - var chooseArgs = new ChooseColumnsTransform.Arguments(); - chooseArgs.Column = cols.ToArray(); - idv = new ChooseColumnsTransform(Host, chooseArgs, idv); + idv = new CopyColumnsTransform(Host, cols.ToArray()).Transform(idv); + idv = SelectColumnsTransform.CreateKeep(Host, idv, colsToKeep.ToArray()); return GetPerInstanceMetricsCore(idv, perInst.Schema); } diff --git a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs index b73465aabe..eb2e52b8a0 100644 --- a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs @@ -1051,22 +1051,14 @@ protected override IDataView GetOverallResultsCore(IDataView overall) private IDataView ChangeTopKAccColumnName(IDataView input) { input = new CopyColumnsTransform(Host, (MultiClassClassifierEvaluator.TopKAccuracy, string.Format(TopKAccuracyFormat, _outputTopKAcc))).Transform(input); - var dropArgs = new DropColumnsTransform.Arguments - { - Column = new[] { MultiClassClassifierEvaluator.TopKAccuracy } - }; - return new DropColumnsTransform(Host, dropArgs, input); + return SelectColumnsTransform.CreateDrop(Host, input, MultiClassClassifierEvaluator.TopKAccuracy ); } private IDataView DropPerClassColumn(IDataView input) { if (input.Schema.TryGetColumnIndex(MultiClassClassifierEvaluator.PerClassLogLoss, out int perClassCol)) { - var args = new DropColumnsTransform.Arguments - { - Column = new[] { MultiClassClassifierEvaluator.PerClassLogLoss } - }; - input = new DropColumnsTransform(Host, args, input); + input = SelectColumnsTransform.CreateDrop(Host, input, MultiClassClassifierEvaluator.PerClassLogLoss); } return input; } diff --git a/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs deleted file mode 100644 index 4462f2b8e2..0000000000 --- a/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs +++ /dev/null @@ -1,590 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Transforms; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using Float = System.Single; - -[assembly: LoadableClass(typeof(ChooseColumnsTransform), typeof(ChooseColumnsTransform.Arguments), typeof(SignatureDataTransform), - "Choose Columns Transform", "ChooseColumnsTransform", "ChooseColumns", "Choose", DocName = "transform/DropKeepChooseTransforms.md")] - -namespace Microsoft.ML.Transforms -{ - public sealed class ChooseColumnsTransform : RowToRowTransformBase - { - // These values are serialized so should not be changed. - public enum HiddenColumnOption : byte - { - Drop = 1, - Keep = 2, - Rename = 3 - } - - public sealed class Column : OneToOneColumn - { - [Argument(ArgumentType.Multiple, HelpText = "What to do with hidden columns")] - public HiddenColumnOption? Hidden; - - public static Column Parse(string str) - { - Contracts.AssertNonEmpty(str); - - var res = new Column(); - if (res.TryParse(str)) - return res; - return null; - } - - public bool TryUnparse(StringBuilder sb) - { - Contracts.AssertValue(sb); - if (Hidden != null) - return false; - return TryUnparseCore(sb); - } - } - - public sealed class Arguments - { - public Arguments() - { - - } - - internal Arguments(params string[] columns) - { - Column = new Column[columns.Length]; - for (int i = 0; i < columns.Length; i++) - { - Column[i] = new Column() { Source = columns[i], Name = columns[i] }; - } - } - - [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)] - public Column[] Column; - - [Argument(ArgumentType.Multiple, HelpText = "What to do with hidden columns")] - public HiddenColumnOption Hidden = HiddenColumnOption.Drop; - } - - private sealed class Bindings : ISchema - { - /// - /// This encodes the information specified in the Arguments object and is what is persisted. - /// - public sealed class RawColInfo - { - public readonly string Name; - public readonly string Source; - public readonly HiddenColumnOption Hid; - - public RawColInfo(string name, string source, HiddenColumnOption hid) - { - Contracts.AssertNonEmpty(name); - Contracts.AssertNonEmpty(source); - Contracts.Assert(Enum.IsDefined(typeof(HiddenColumnOption), hid)); - - Name = name; - Source = source; - Hid = hid; - } - } - - public sealed class ColInfo - { - public readonly string Name; - public readonly int Source; - public readonly ColumnType TypeSrc; - - public ColInfo(string name, int src, ColumnType typeSrc) - { - Contracts.AssertNonEmpty(name); - Contracts.Assert(src >= 0); - Contracts.AssertValue(typeSrc); - - Name = name; - Source = src; - TypeSrc = typeSrc; - } - } - - public readonly ISchema Input; - public readonly RawColInfo[] RawInfos; - public readonly HiddenColumnOption HidDefault; - public readonly ColInfo[] Infos; - public readonly Dictionary NameToInfoIndex; - - public Schema AsSchema { get; } - - public Bindings(Arguments args, ISchema schemaInput) - { - Contracts.AssertValue(args); - Contracts.AssertValue(schemaInput); - - Input = schemaInput; - - Contracts.Check(Enum.IsDefined(typeof(HiddenColumnOption), args.Hidden), "hidden"); - HidDefault = args.Hidden; - - RawInfos = new RawColInfo[Utils.Size(args.Column)]; - if (RawInfos.Length > 0) - { - var names = new HashSet(); - for (int i = 0; i < RawInfos.Length; i++) - { - var item = args.Column[i]; - string dst = item.Name; - string src = item.Source; - - if (string.IsNullOrWhiteSpace(src)) - src = dst; - else if (string.IsNullOrWhiteSpace(dst)) - dst = src; - Contracts.CheckUserArg(!string.IsNullOrWhiteSpace(dst), nameof(Column.Name)); - - if (!names.Add(dst)) - throw Contracts.ExceptUserArg(nameof(args.Column), "New column '{0}' specified multiple times", dst); - - var hid = item.Hidden ?? args.Hidden; - Contracts.CheckUserArg(Enum.IsDefined(typeof(HiddenColumnOption), hid), nameof(args.Hidden)); - - RawInfos[i] = new RawColInfo(dst, src, hid); - } - } - - BuildInfos(out Infos, out NameToInfoIndex, user: true); - AsSchema = Schema.Create(this); - } - - private void BuildInfos(out ColInfo[] infosArray, out Dictionary nameToCol, bool user) - { - var raws = RawInfos; - var tops = new List(); - - bool rename = false; - Dictionary> dups = null; - if (raws.Length == 0) - { - // Empty raws means take all with default HiddenColumnOption. - var rawList = new List(); - for (int col = 0; col < Input.ColumnCount; col++) - { - string src = Input.GetColumnName(col); - int tmp; - if (!Input.TryGetColumnIndex(src, out tmp)) - { - Contracts.Assert(false, "Why couldn't the schema find the name?"); - continue; - } - - if (tmp == col) - { - var raw = new RawColInfo(src, src, HidDefault); - rawList.Add(raw); - tops.Add(col); - } - else if (HidDefault != HiddenColumnOption.Drop) - { - if (dups == null) - dups = new Dictionary>(); - List list; - if (!dups.TryGetValue(src, out list)) - dups[src] = list = new List(); - list.Add(col); - } - } - if (dups != null && HidDefault == HiddenColumnOption.Rename) - rename = true; - - raws = rawList.ToArray(); - } - else - { - for (int i = 0; i < raws.Length; i++) - { - var raw = raws[i]; - - int col; - if (!Input.TryGetColumnIndex(raw.Source, out col)) - { - throw user ? - Contracts.ExceptUserArg(nameof(Arguments.Column), "source column '{0}' not found", raw.Source) : - Contracts.ExceptDecode("source column '{0}' not found", raw.Source); - } - tops.Add(col); - - if (raw.Hid != HiddenColumnOption.Drop) - { - if (dups == null) - dups = new Dictionary>(); - dups[raw.Source] = null; - if (raw.Hid == HiddenColumnOption.Rename) - rename = true; - } - } - - if (dups != null) - { - for (int col = 0; col < Input.ColumnCount; col++) - { - string src = Input.GetColumnName(col); - List list; - if (!dups.TryGetValue(src, out list)) - continue; - int tmp; - if (!Input.TryGetColumnIndex(src, out tmp)) - { - Contracts.Assert(false, "Why couldn't the schema find the name?"); - continue; - } - if (tmp == col) - continue; - if (list == null) - dups[src] = list = new List(); - list.Add(col); - } - } - } - Contracts.Assert(tops.Count == raws.Length); - - HashSet names = null; - if (rename) - names = new HashSet(raws.Select(r => r.Name)); - - var infos = new List(); - for (int i = 0; i < raws.Length; i++) - { - var raw = raws[i]; - - int colSrc; - ColumnType type; - ColInfo info; - - // Handle dups. - List list; - if (raw.Hid != HiddenColumnOption.Drop && - dups != null && dups.TryGetValue(raw.Source, out list) && list != null) - { - int iinfo = infos.Count; - int inc = 0; - for (int iv = list.Count; --iv >= 0; ) - { - colSrc = list[iv]; - type = Input.GetColumnType(colSrc); - string name = raw.Name; - if (raw.Hid == HiddenColumnOption.Rename) - name = GetUniqueName(names, name, ref inc); - info = new ColInfo(name, colSrc, type); - infos.Insert(iinfo, info); - } - } - - colSrc = tops[i]; - type = Input.GetColumnType(colSrc); - info = new ColInfo(raw.Name, colSrc, type); - infos.Add(info); - } - - infosArray = infos.ToArray(); - nameToCol = new Dictionary(Infos.Length); - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) - nameToCol[Infos[iinfo].Name] = iinfo; - } - - private static string GetUniqueName(HashSet names, string name, ref int inc) - { - for (; ; ) - { - string tmp = string.Format("{0}_{1:000}", name, ++inc); - if (names.Add(tmp)) - return tmp; - } - } - - public Bindings(ModelLoadContext ctx, ISchema schemaInput) - { - Contracts.AssertValue(ctx); - Contracts.AssertValue(schemaInput); - - Input = schemaInput; - - // *** Binary format *** - // byte: default HiddenColumnOption value - // int: number of raw column infos - // for each raw column info - // int: id of output column name - // int: id of input column name - // byte: HiddenColumnOption - HidDefault = (HiddenColumnOption)ctx.Reader.ReadByte(); - Contracts.CheckDecode(Enum.IsDefined(typeof(HiddenColumnOption), HidDefault)); - - int count = ctx.Reader.ReadInt32(); - Contracts.CheckDecode(count >= 0); - - RawInfos = new RawColInfo[count]; - if (count > 0) - { - var names = new HashSet(); - for (int i = 0; i < count; i++) - { - string dst = ctx.LoadNonEmptyString(); - Contracts.CheckDecode(names.Add(dst)); - string src = ctx.LoadNonEmptyString(); - - var hid = (HiddenColumnOption)ctx.Reader.ReadByte(); - Contracts.CheckDecode(Enum.IsDefined(typeof(HiddenColumnOption), hid)); - RawInfos[i] = new RawColInfo(dst, src, hid); - } - } - - BuildInfos(out Infos, out NameToInfoIndex, user: false); - - AsSchema = Schema.Create(this); - } - - public void Save(ModelSaveContext ctx) - { - Contracts.AssertValue(ctx); - - // *** Binary format *** - // byte: default HiddenColumnOption value - // int: number of raw column infos - // for each raw column info - // int: id of output column name - // int: id of input column name - // byte: HiddenColumnOption - Contracts.Assert((HiddenColumnOption)(byte)HidDefault == HidDefault); - ctx.Writer.Write((byte)HidDefault); - ctx.Writer.Write(RawInfos.Length); - for (int i = 0; i < RawInfos.Length; i++) - { - var raw = RawInfos[i]; - ctx.SaveNonEmptyString(raw.Name); - ctx.SaveNonEmptyString(raw.Source); - Contracts.Assert((HiddenColumnOption)(byte)raw.Hid == raw.Hid); - ctx.Writer.Write((byte)raw.Hid); - } - } - - public int ColumnCount - { - get { return Infos.Length; } - } - - public bool TryGetColumnIndex(string name, out int col) - { - Contracts.CheckValueOrNull(name); - if (name == null) - { - col = default(int); - return false; - } - return NameToInfoIndex.TryGetValue(name, out col); - } - - public string GetColumnName(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return Infos[col].Name; - } - - public ColumnType GetColumnType(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return Infos[col].TypeSrc; - } - - public IEnumerable> GetMetadataTypes(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return Input.GetMetadataTypes(Infos[col].Source); - } - - public ColumnType GetMetadataTypeOrNull(string kind, int col) - { - Contracts.CheckNonEmpty(kind, nameof(kind)); - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return Input.GetMetadataTypeOrNull(kind, Infos[col].Source); - } - - public void GetMetadata(string kind, int col, ref TValue value) - { - Contracts.CheckNonEmpty(kind, nameof(kind)); - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - Input.GetMetadata(kind, Infos[col].Source, ref value); - } - - internal bool[] GetActive(Func predicate) - { - return Utils.BuildArray(ColumnCount, predicate); - } - - internal Func GetDependencies(Func predicate) - { - Contracts.AssertValue(predicate); - var active = new bool[Input.ColumnCount]; - for (int i = 0; i < Infos.Length; i++) - { - if (predicate(i)) - active[Infos[i].Source] = true; - } - return col => 0 <= col && col < active.Length && active[col]; - } - } - - public const string LoaderSignature = "ChooseColumnsTransform"; - internal const string LoaderSignatureOld = "ChooseColumnsFunction"; - private static VersionInfo GetVersionInfo() - { - return new VersionInfo( - modelSignature: "CHSCOLSF", - verWrittenCur: 0x00010001, // Initial - verReadableCur: 0x00010001, - verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature, - loaderSignatureAlt: LoaderSignatureOld, - loaderAssemblyName: typeof(ChooseColumnsTransform).Assembly.FullName); - } - - private readonly Bindings _bindings; - - private const string RegistrationName = "ChooseColumns"; - - /// - /// Convenience constructor for public facing API. - /// - /// Host Environment. - /// Input . This is the output from previous transform or loader. - /// Names of the columns to choose. - public ChooseColumnsTransform(IHostEnvironment env, IDataView input, params string[] columns) - : this(env, new Arguments(columns), input) - { - } - - /// - /// Public constructor corresponding to SignatureDataTransform. - /// - public ChooseColumnsTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, RegistrationName, input) - { - Host.CheckValue(args, nameof(args)); - - _bindings = new Bindings(args, Source.Schema); - } - - private ChooseColumnsTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, input) - { - Host.AssertValue(ctx); - - // *** Binary format *** - // int: sizeof(Float) - // bindings - int cbFloat = ctx.Reader.ReadInt32(); - Host.CheckDecode(cbFloat == sizeof(Float)); - _bindings = new Bindings(ctx, Source.Schema); - } - - public static ChooseColumnsTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) - { - Contracts.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); - ctx.CheckAtModel(GetVersionInfo()); - return h.Apply("Loading Model", ch => new ChooseColumnsTransform(h, ctx, input)); - } - - public override void Save(ModelSaveContext ctx) - { - Host.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(); - ctx.SetVersionInfo(GetVersionInfo()); - - // *** Binary format *** - // int: sizeof(Float) - // bindings - ctx.Writer.Write(sizeof(Float)); - _bindings.Save(ctx); - } - - public override Schema Schema => _bindings.AsSchema; - - protected override bool? ShouldUseParallelCursors(Func predicate) - { - Host.AssertValue(predicate); - // Parallel doesn't matter to this transform. - return null; - } - - protected override IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null) - { - Host.AssertValue(predicate, "predicate"); - Host.AssertValueOrNull(rand); - - var inputPred = _bindings.GetDependencies(predicate); - var active = _bindings.GetActive(predicate); - var input = Source.GetRowCursor(inputPred, rand); - return new RowCursor(Host, _bindings, input, active); - } - - public sealed override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, - Func predicate, int n, IRandom rand = null) - { - Host.CheckValue(predicate, nameof(predicate)); - Host.CheckValueOrNull(rand); - - var inputPred = _bindings.GetDependencies(predicate); - var active = _bindings.GetActive(predicate); - var inputs = Source.GetRowCursorSet(out consolidator, inputPred, n, rand); - Host.AssertNonEmpty(inputs); - - // No need to split if this is given 1 input cursor. - var cursors = new IRowCursor[inputs.Length]; - for (int i = 0; i < inputs.Length; i++) - cursors[i] = new RowCursor(Host, _bindings, inputs[i], active); - return cursors; - } - - private sealed class RowCursor : SynchronizedCursorBase, IRowCursor - { - private readonly Bindings _bindings; - private readonly bool[] _active; - - public RowCursor(IChannelProvider provider, Bindings bindings, IRowCursor input, bool[] active) - : base(provider, input) - { - Ch.AssertValue(bindings); - Ch.Assert(active == null || active.Length == bindings.ColumnCount); - - _bindings = bindings; - _active = active; - } - - public Schema Schema => _bindings.AsSchema; - - public bool IsColumnActive(int col) - { - Ch.Check(0 <= col && col < _bindings.ColumnCount); - return _active == null || _active[col]; - } - - public ValueGetter GetGetter(int col) - { - Ch.Check(IsColumnActive(col)); - - var info = _bindings.Infos[col]; - return Input.GetGetter(info.Source); - } - } - } -} diff --git a/src/Microsoft.ML.Data/Transforms/DropColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/DropColumnsTransform.cs deleted file mode 100644 index badd3052fd..0000000000 --- a/src/Microsoft.ML.Data/Transforms/DropColumnsTransform.cs +++ /dev/null @@ -1,413 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Transforms; -using System; -using System.Collections.Generic; -using Float = System.Single; - -[assembly: LoadableClass(DropColumnsTransform.DropColumnsSummary, typeof(DropColumnsTransform), typeof(DropColumnsTransform.Arguments), typeof(SignatureDataTransform), - DropColumnsTransform.DropUserName, "DropColumns", "DropColumnsTransform", DropColumnsTransform.DropShortName, DocName = "transform/DropKeepChooseTransforms.md")] - -[assembly: LoadableClass(DropColumnsTransform.KeepColumnsSummary, typeof(DropColumnsTransform), typeof(DropColumnsTransform.KeepArguments), typeof(SignatureDataTransform), - DropColumnsTransform.KeepUserName, "KeepColumns", "KeepColumnsTransform", DropColumnsTransform.KeepShortName, DocName = "transform/DropKeepChooseTransforms.md")] - -namespace Microsoft.ML.Transforms -{ - /// - /// Transform to drop columns with the given names. Note that if there are names that - /// are not in the input schema, that is not an error. - /// - public sealed class DropColumnsTransform : RowToRowMapperTransformBase - { - public abstract class ArgumentsBase : TransformInputBase - { - internal abstract string[] Columns { get; } - } - - public sealed class Arguments : ArgumentsBase - { - [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Column name to drop", ShortName = "col", SortOrder = 1)] - public string[] Column; - - internal override string[] Columns => Column; - } - - public sealed class KeepArguments : ArgumentsBase - { - [Argument(ArgumentType.Multiple, HelpText = "Column name to keep", ShortName = "col", SortOrder = 1)] - public string[] Column; - - internal override string[] Columns => Column; - } - - private sealed class Bindings : ISchema - { - public readonly ISchema Input; - - // Whether to keep (vs drop) the named columns. - public readonly bool Keep; - // The column names to drop/keep. - public readonly HashSet Names; - // Map from our column indices to source column indices. - public readonly int[] ColMap; - // Map from names to our column indices. - public readonly Dictionary NameToCol; - - public Schema AsSchema { get; } - - public Bindings(ArgumentsBase args, bool keep, ISchema schemaInput) - { - Contracts.AssertValue(args); - Contracts.AssertNonEmpty(args.Columns); - Contracts.AssertValue(schemaInput); - - Keep = keep; - Input = schemaInput; - - Names = new HashSet(); - for (int i = 0; i < args.Columns.Length; i++) - { - var name = args.Columns[i]; - Contracts.CheckNonWhiteSpace(name, nameof(args.Columns)); - - // REVIEW: Should this just be a warning? - if (!Names.Add(name)) - throw Contracts.ExceptUserArg(nameof(args.Columns), "Column '{0}' specified multiple times", name); - } - - BuildMap(out ColMap, out NameToCol); - - AsSchema = Schema.Create(this); - } - - private void BuildMap(out int[] map, out Dictionary nameToCol) - { - var srcs = new List(); - nameToCol = new Dictionary(); - for (int src = 0; src < Input.ColumnCount; src++) - { - string name = Input.GetColumnName(src); - if (Names.Contains(name) == !Keep) - continue; - - // If the Input schema maps name to this src column, record it in our name table. - int tmp; - if (Input.TryGetColumnIndex(name, out tmp) && tmp == src) - nameToCol.Add(name, srcs.Count); - // Record the source column index. - srcs.Add(src); - } - map = srcs.ToArray(); - } - - public Bindings(ModelLoadContext ctx, ISchema schemaInput) - { - Contracts.AssertValue(ctx); - Contracts.AssertValue(schemaInput); - - Input = schemaInput; - - // *** Binary format *** - // bool: whether to keep (vs drop) the named columns - // int: number of names - // int[]: the ids of the names - Keep = ctx.Reader.ReadBoolByte(); - int count = ctx.Reader.ReadInt32(); - Contracts.CheckDecode(count > 0); - - Names = new HashSet(); - for (int i = 0; i < count; i++) - { - string name = ctx.LoadNonEmptyString(); - Contracts.CheckDecode(Names.Add(name)); - } - - BuildMap(out ColMap, out NameToCol); - AsSchema = Schema.Create(this); - } - - public void Save(ModelSaveContext ctx) - { - Contracts.AssertValue(ctx); - - // *** Binary format *** - // bool: whether to keep (vs drop) the named columns - // int: number of names - // int[]: the ids of the names - ctx.Writer.WriteBoolByte(Keep); - ctx.Writer.Write(Names.Count); - foreach (var name in Names) - ctx.SaveNonEmptyString(name); - } - - public int ColumnCount - { - get { return ColMap.Length; } - } - - public bool TryGetColumnIndex(string name, out int col) - { - Contracts.CheckValueOrNull(name); - - if (name == null) - { - col = default(int); - return false; - } - return NameToCol.TryGetValue(name, out col); - } - - public string GetColumnName(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return Input.GetColumnName(ColMap[col]); - } - - public ColumnType GetColumnType(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return Input.GetColumnType(ColMap[col]); - } - - public IEnumerable> GetMetadataTypes(int col) - { - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return Input.GetMetadataTypes(ColMap[col]); - } - - public ColumnType GetMetadataTypeOrNull(string kind, int col) - { - Contracts.CheckNonEmpty(kind, nameof(kind)); - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return Input.GetMetadataTypeOrNull(kind, ColMap[col]); - } - - public void GetMetadata(string kind, int col, ref TValue value) - { - Contracts.CheckNonEmpty(kind, nameof(kind)); - Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - Input.GetMetadata(kind, ColMap[col], ref value); - } - - internal bool[] GetActive(Func predicate) - { - return Utils.BuildArray(ColumnCount, predicate); - } - - internal Func GetDependencies(Func predicate) - { - Contracts.AssertValue(predicate); - var active = new bool[Input.ColumnCount]; - for (int i = 0; i < ColMap.Length; i++) - { - if (predicate(i)) - active[ColMap[i]] = true; - } - return col => 0 <= col && col < active.Length && active[col]; - } - } - - public const string DropColumnsSummary = "Removes a column or columns from the dataset."; - public const string KeepColumnsSummary = "Selects which columns from the dataset to keep."; - public const string DropUserName = "Drop Columns Transform"; - public const string KeepUserName = "Keep Columns Transform"; - public const string DropShortName = "Drop"; - public const string KeepShortName = "Keep"; - - public const string LoaderSignature = "DropColumnsTransform"; - private static VersionInfo GetVersionInfo() - { - return new VersionInfo( - modelSignature: "DRPCOLST", - // verWrittenCur: 0x00010001, // Initial - verWrittenCur: 0x00010002, // Added KeepColumns - verReadableCur: 0x00010002, - verWeCanReadBack: 0x00010002, - loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(DropColumnsTransform).Assembly.FullName); - } - - private readonly Bindings _bindings; - - private const string DropRegistrationName = "DropColumns"; - private const string KeepRegistrationName = "KeepColumns"; - - /// - /// Convenience constructor for public facing API. - /// - /// Host Environment. - /// Input . This is the output from previous transform or loader. - /// Name of the columns to be dropped. - public DropColumnsTransform(IHostEnvironment env, IDataView input, params string[] columnsToDrop) - :this(env, new Arguments() { Column = columnsToDrop }, input) - { - } - - /// - /// Public constructor corresponding to SignatureDataTransform. - /// - public DropColumnsTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, DropRegistrationName, input) - { - Host.CheckValue(args, nameof(args)); - Host.CheckNonEmpty(args.Column, nameof(args.Column)); - - _bindings = new Bindings(args, false, Source.Schema); - } - - /// - /// Public constructor corresponding to SignatureDataTransform. - /// - public DropColumnsTransform(IHostEnvironment env, KeepArguments args, IDataView input) - : base(env, KeepRegistrationName, input) - { - Host.CheckValue(args, nameof(args)); - Host.CheckNonEmpty(args.Column, nameof(args.Column)); - - _bindings = new Bindings(args, true, Source.Schema); - } - - private DropColumnsTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, input) - { - Host.AssertValue(ctx); - - // *** Binary format *** - // int: sizeof(Float) - // bindings - int cbFloat = ctx.Reader.ReadInt32(); - Host.CheckDecode(cbFloat == sizeof(Float)); - _bindings = new Bindings(ctx, Source.Schema); - } - - public static DropColumnsTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) - { - Contracts.CheckValue(env, nameof(env)); - var h = env.Register(DropRegistrationName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); - ctx.CheckAtModel(GetVersionInfo()); - return h.Apply("Loading Model", ch => new DropColumnsTransform(h, ctx, input)); - } - - public override void Save(ModelSaveContext ctx) - { - Host.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(); - ctx.SetVersionInfo(GetVersionInfo()); - - // *** Binary format *** - // int: sizeof(Float) - // bindings - ctx.Writer.Write(sizeof(Float)); - _bindings.Save(ctx); - } - - public override Schema Schema => _bindings.AsSchema; - - protected override bool? ShouldUseParallelCursors(Func predicate) - { - Host.AssertValue(predicate); - // Parallel doesn't matter to this transform. - return null; - } - - protected override IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null) - { - Host.AssertValue(predicate, "predicate"); - Host.AssertValueOrNull(rand); - - var inputPred = _bindings.GetDependencies(predicate); - var active = _bindings.GetActive(predicate); - var input = Source.GetRowCursor(inputPred, rand); - return new RowCursor(Host, _bindings, input, active); - } - - public sealed override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, - Func predicate, int n, IRandom rand = null) - { - Host.CheckValue(predicate, nameof(predicate)); - Host.CheckValueOrNull(rand); - - var inputPred = _bindings.GetDependencies(predicate); - var active = _bindings.GetActive(predicate); - var inputs = Source.GetRowCursorSet(out consolidator, inputPred, n, rand); - Host.AssertNonEmpty(inputs); - - // No need to split if this is given 1 input cursor. - var cursors = new IRowCursor[inputs.Length]; - for (int i = 0; i < inputs.Length; i++) - cursors[i] = new RowCursor(Host, _bindings, inputs[i], active); - return cursors; - } - - protected override Func GetDependenciesCore(Func predicate) - { - return _bindings.GetDependencies(predicate); - } - - protected override Delegate[] CreateGetters(IRow input, Func active, out Action disp) - { - disp = null; - return new Delegate[0]; - } - - protected override int MapColumnIndex(out bool isSrc, int col) - { - isSrc = true; - return _bindings.ColMap[col]; - } - - // REVIEW: Refactor so ChooseColumns can share the same cursor class. - private sealed class RowCursor : SynchronizedCursorBase, IRowCursor - { - private readonly Bindings _bindings; - private readonly bool[] _active; - - public RowCursor(IChannelProvider provider, Bindings bindings, IRowCursor input, bool[] active) - : base(provider, input) - { - Ch.AssertValue(bindings); - Ch.Assert(active == null || active.Length == bindings.ColumnCount); - - _bindings = bindings; - _active = active; - } - - public Schema Schema => _bindings.AsSchema; - - public bool IsColumnActive(int col) - { - Ch.Check(0 <= col && col < _bindings.ColumnCount); - return _active == null || _active[col]; - } - - public ValueGetter GetGetter(int col) - { - Ch.Check(IsColumnActive(col)); - return Input.GetGetter(_bindings.ColMap[col]); - } - } - } - - public class KeepColumnsTransform - { - /// - /// A helper method to create for public facing API. - /// - /// Host Environment. - /// Input . This is the output from previous transform or loader. - /// Name of the columns to be kept. All other columns will be removed. - /// - public static IDataTransform Create(IHostEnvironment env, IDataView input, params string[] columnsToKeep) - => new DropColumnsTransform(env, new DropColumnsTransform.KeepArguments() { Column = columnsToKeep }, input); - } -} diff --git a/src/Microsoft.ML.Data/Transforms/SelectColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/SelectColumnsTransform.cs index e6edfe7510..7f486fc77d 100644 --- a/src/Microsoft.ML.Data/Transforms/SelectColumnsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/SelectColumnsTransform.cs @@ -47,7 +47,7 @@ public sealed class SelectColumnsEstimator : TrivialEstimatorInstance of the host environment. /// The array of column names to keep. public SelectColumnsEstimator(IHostEnvironment env, params string[] keepColumns) - : this(env, keepColumns, null, true, true) + : this(env, keepColumns, null, SelectColumnsTransform.Defaults.KeepHidden, SelectColumnsTransform.Defaults.IgnoreMissing) { } /// @@ -61,8 +61,8 @@ public SelectColumnsEstimator(IHostEnvironment env, params string[] keepColumns) /// or that are missing from the input. If a missing colums exists a /// SchemaMistmatch exception is thrown. If true, the check is not made. public SelectColumnsEstimator(IHostEnvironment env, string[] keepColumns, - string[] dropColumns, bool keepHidden = true, - bool ignoreMissing= true) + string[] dropColumns, bool keepHidden = SelectColumnsTransform.Defaults.KeepHidden, + bool ignoreMissing = SelectColumnsTransform.Defaults.IgnoreMissing) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(SelectColumnsEstimator)), new SelectColumnsTransform(env, keepColumns, dropColumns, keepHidden, ignoreMissing)) { @@ -126,6 +126,12 @@ public sealed class SelectColumnsTransform : ITransformer, ICanSaveModel private readonly IHost _host; private string[] _selectedColumns; + internal static class Defaults + { + public const bool KeepHidden = false; + public const bool IgnoreMissing = false; + }; + public bool IsRowToRowMapper => true; public IEnumerable SelectColumns => _selectedColumns.AsReadOnly(); @@ -179,14 +185,14 @@ public sealed class Arguments : TransformInputBase public string[] DropColumns; [Argument(ArgumentType.AtMostOnce, HelpText = "Specifies whether to keep or remove hidden columns.", ShortName = "hidden", SortOrder = 3)] - public bool KeepHidden = true; + public bool KeepHidden = Defaults.KeepHidden; [Argument(ArgumentType.AtMostOnce, HelpText = "Specifies whether to ignore columns that are missing from the input.", ShortName = "ignore", SortOrder = 4)] - public bool IgnoreMissing = true; + public bool IgnoreMissing = Defaults.IgnoreMissing; } public SelectColumnsTransform(IHostEnvironment env, string[] keepColumns, string[] dropColumns, - bool keepHidden=true, bool ignoreMissing=true) + bool keepHidden=Defaults.KeepHidden, bool ignoreMissing=Defaults.IgnoreMissing) { _host = Contracts.CheckRef(env, nameof(env)).Register(nameof(SelectColumnsTransform)); _host.CheckValueOrNull(keepColumns); @@ -232,7 +238,7 @@ private static SelectColumnsTransform LoadDropColumnsTransform(IHostEnvironment // int: sizeof(Float) // bindings int cbFloat = ctx.Reader.ReadInt32(); - //env.CheckDecode(cbFloat == sizeof(Float)); + env.CheckDecode(cbFloat == sizeof(float)); // *** Binary format *** // bool: whether to keep (vs drop) the named columns @@ -256,7 +262,9 @@ private static SelectColumnsTransform LoadDropColumnsTransform(IHostEnvironment else dropColumns = names.ToArray(); - return new SelectColumnsTransform(env, keepColumns, dropColumns, keep); + // Note for backward compatibility, Drop/Keep Columns always preserves + // hidden columns + return new SelectColumnsTransform(env, keepColumns, dropColumns, true); } /// @@ -384,6 +392,30 @@ public static IDataView Create(IHostEnvironment env, ModelLoadContext ctx, IData return transform.Transform(input); } + public static IDataTransform CreateKeep(IHostEnvironment env, IDataView input, params string[] keepColumns) + { + var transform = new SelectColumnsTransform(env, keepColumns, null); + return new SelectColumnsDataTransform(env, transform, new Mapper(transform, input.Schema), input); + } + + public static IDataTransform CreateKeep(IHostEnvironment env, IDataView input, bool keepHidden, params string[] keepColumns) + { + var transform = new SelectColumnsTransform(env, keepColumns, null, keepHidden); + return new SelectColumnsDataTransform(env, transform, new Mapper(transform, input.Schema), input); + } + + public static IDataTransform CreateDrop(IHostEnvironment env, IDataView input, params string[] dropColumns) + { + var transform = new SelectColumnsTransform(env, null, dropColumns); + return new SelectColumnsDataTransform(env, transform, new Mapper(transform, input.Schema), input); + } + + public static IDataTransform CreateDrop(IHostEnvironment env, IDataView input, bool keepHidden, params string[] dropColumns) + { + var transform = new SelectColumnsTransform(env, null, dropColumns, keepHidden); + return new SelectColumnsDataTransform(env, transform, new Mapper(transform, input.Schema), input); + } + // Factory method for SignatureDataTransform. private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { @@ -490,21 +522,62 @@ private static int[] BuildOutputToInputMap(IEnumerable selectedColumns, { var outputToInputMapping = new List(); var columnCount = inputSchema.ColumnCount; - int outputIdx = 0; - for (int colIdx = 0; colIdx < columnCount; ++colIdx) + if (keepColumns) { - if (!keepHidden && inputSchema.IsHidden(colIdx)) - continue; + // With KeepColumns, the order that is specified is preserved in the mapping. + // For example if a given input has the columns of ABC and the select columns are + // specified as CA, then the output will be CA. + + // In order to account for keeping hidden columns, build a dictionary of + // column name-> list of column indices. This dictionary is used for + // building the final mapping. + var columnDict = new Dictionary>(); + for(int colIdx = 0; colIdx < inputSchema.ColumnCount; ++colIdx) + { + if (!keepHidden && inputSchema.IsHidden(colIdx)) + continue; + + var columnName = inputSchema[colIdx].Name; + if (columnDict.TryGetValue(columnName, out List columnList)) + columnList.Add(colIdx); + else + { + columnList = new List(); + columnList.Add(colIdx); + columnDict.Add(columnName, columnList); + } + } - var columnName = inputSchema[colIdx].Name; - var selected = selectedColumns.Contains(columnName); - selected = (keepColumns) ? selected : !selected; - if (selected) + // Since the ordering matters, iterate through the selected columns + // finding the associated index that should be used. + foreach(var columnName in selectedColumns) { + if (columnDict.TryGetValue(columnName, out List columnList)) + { + foreach(var colIdx in columnList) + { + outputToInputMapping.Add(colIdx); + } + } + } + } + else + { + // Handles the drop case, removing any columns specified from the input + // In the case of drop, the order of the output is modeled after the input + // given an input of ABC and dropping column B will result in AC. + for(int colIdx = 0; colIdx < inputSchema.ColumnCount; colIdx++) + { + if (!keepHidden && inputSchema.IsHidden(colIdx)) + continue; + + if (selectedColumns.Contains(inputSchema[colIdx].Name)) + continue; + outputToInputMapping.Add(colIdx); - outputIdx++; } + } return outputToInputMapping.ToArray(); @@ -513,9 +586,7 @@ private static int[] BuildOutputToInputMap(IEnumerable selectedColumns, private static Schema GenerateOutputSchema(IEnumerable map, Schema inputSchema) { - IEnumerable inputs = Enumerable.Range(0, inputSchema.ColumnCount); - var outputColumns = inputs.Where(idx=> map.Contains(idx)) - .Select(idx=>inputSchema[idx]); + var outputColumns = map.Select(x=>inputSchema[x]); return new Schema(outputColumns); } } diff --git a/src/Microsoft.ML.Legacy/CSharpApi.cs b/src/Microsoft.ML.Legacy/CSharpApi.cs index 9f5b2e8791..a73f925256 100644 --- a/src/Microsoft.ML.Legacy/CSharpApi.cs +++ b/src/Microsoft.ML.Legacy/CSharpApi.cs @@ -11768,12 +11768,12 @@ public sealed partial class ColumnSelector : Microsoft.ML.Runtime.EntryPoints.Co /// /// Specifies whether to keep or remove hidden columns. /// - public bool KeepHidden { get; set; } = true; + public bool KeepHidden { get; set; } = false; /// /// Specifies whether to ignore columns that are missing from the input. /// - public bool IgnoreMissing { get; set; } = true; + public bool IgnoreMissing { get; set; } = false; /// /// Input dataset diff --git a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/CVSplit.cs b/src/Microsoft.ML.Legacy/Runtime/EntryPoints/CVSplit.cs index f3d7a582a4..a99d3a9f6e 100644 --- a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/CVSplit.cs +++ b/src/Microsoft.ML.Legacy/Runtime/EntryPoints/CVSplit.cs @@ -68,11 +68,11 @@ public static Output Split(IHostEnvironment env, Input input) { var trainData = new RangeFilter(host, new RangeFilter.Arguments { Column = stratCol, Min = i * fraction, Max = (i + 1) * fraction, Complement = true }, data); - output.TrainData[i] = new DropColumnsTransform(host, new DropColumnsTransform.Arguments { Column = new[] { stratCol } }, trainData); + output.TrainData[i] = SelectColumnsTransform.CreateDrop(host, trainData, stratCol); var testData = new RangeFilter(host, new RangeFilter.Arguments { Column = stratCol, Min = i * fraction, Max = (i + 1) * fraction, Complement = false }, data); - output.TestData[i] = new DropColumnsTransform(host, new DropColumnsTransform.Arguments { Column = new[] { stratCol } }, testData); + output.TestData[i] = SelectColumnsTransform.CreateDrop(host, testData, stratCol); } return output; diff --git a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/TrainTestSplit.cs b/src/Microsoft.ML.Legacy/Runtime/EntryPoints/TrainTestSplit.cs index 5b7e735380..244e2893e8 100644 --- a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/TrainTestSplit.cs +++ b/src/Microsoft.ML.Legacy/Runtime/EntryPoints/TrainTestSplit.cs @@ -53,11 +53,11 @@ public static Output Split(IHostEnvironment env, Input input) IDataView trainData = new RangeFilter(host, new RangeFilter.Arguments { Column = stratCol, Min = 0, Max = input.Fraction, Complement = false }, data); - trainData = new DropColumnsTransform(host, new DropColumnsTransform.Arguments { Column = new[] { stratCol } }, trainData); + trainData = SelectColumnsTransform.CreateDrop(host, trainData, stratCol); IDataView testData = new RangeFilter(host, new RangeFilter.Arguments { Column = stratCol, Min = 0, Max = input.Fraction, Complement = true }, data); - testData = new DropColumnsTransform(host, new DropColumnsTransform.Arguments { Column = new[] { stratCol } }, testData); + testData = SelectColumnsTransform.CreateDrop(host, testData, stratCol); return new Output() { TrainData = trainData, TestData = testData }; } diff --git a/src/Microsoft.ML.Transforms/NAHandleTransform.cs b/src/Microsoft.ML.Transforms/NAHandleTransform.cs index 7ef5e9bb10..1a9116f071 100644 --- a/src/Microsoft.ML.Transforms/NAHandleTransform.cs +++ b/src/Microsoft.ML.Transforms/NAHandleTransform.cs @@ -229,7 +229,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV // Finally, drop the temporary indicator columns. if (dropCols.Count > 0) - output = new DropColumnsTransform(h, new DropColumnsTransform.Arguments() { Column = dropCols.ToArray() }, output); + output = SelectColumnsTransform.CreateDrop(h, output, dropCols.ToArray()); return output; } diff --git a/src/Microsoft.ML.Transforms/TermLookupTransform.cs b/src/Microsoft.ML.Transforms/TermLookupTransform.cs index f34be7c2d7..48c3b9a9d9 100644 --- a/src/Microsoft.ML.Transforms/TermLookupTransform.cs +++ b/src/Microsoft.ML.Transforms/TermLookupTransform.cs @@ -13,6 +13,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Reflection; using System.Text; @@ -507,13 +508,14 @@ private static byte[] GetBytesFromDataView(IHost host, IDataView lookup, string var typeTerm = schema.GetColumnType(colTerm); host.CheckUserArg(typeTerm.IsText, nameof(Arguments.TermColumn), "term column must contain text"); var typeValue = schema.GetColumnType(colValue); - - var args = new ChooseColumnsTransform.Arguments(); - args.Column = new[] { - new ChooseColumnsTransform.Column {Name = "Term", Source = termColumn}, - new ChooseColumnsTransform.Column {Name = "Value", Source = valueColumn}, + var cols = new List<(string Source, string Name)>() + { + (termColumn, "Term"), + (valueColumn, "Value") }; - var view = new ChooseColumnsTransform(host, args, lookup); + + var view = new CopyColumnsTransform(host, cols.ToArray()).Transform(lookup); + view = SelectColumnsTransform.CreateKeep(host, view, cols.Select(x=>x.Name).ToArray()); var saver = new BinarySaver(host, new BinarySaver.Arguments()); using (var strm = new MemoryStream()) diff --git a/src/Microsoft.ML.Transforms/Text/SentimentAnalyzerTransform.cs b/src/Microsoft.ML.Transforms/Text/SentimentAnalyzerTransform.cs index 0fc3733759..5a6edb4464 100644 --- a/src/Microsoft.ML.Transforms/Text/SentimentAnalyzerTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/SentimentAnalyzerTransform.cs @@ -95,7 +95,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV // 5. Drop all the columns created by the pretrained model, including the expected input column // and the output column, which we have copied to a temporary column in (4). - input = new DropColumnsTransform(env, new DropColumnsTransform.Arguments() { Column = _modelIntermediateColumnNames }, input); + input = SelectColumnsTransform.CreateDrop(env, input, _modelIntermediateColumnNames); // 6. Unalias all the original columns that were originally present in the IDataView, but may have // been shadowed by column names in the pretrained model. This method will also drop all the temporary @@ -107,7 +107,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV input = copyTransformer.Transform(input); // 8. Drop the temporary column with the score created in (4). - return new DropColumnsTransform(env, new DropColumnsTransform.Arguments() { Column = new[] { scoreTempName } }, input); + return SelectColumnsTransform.CreateDrop(env, input, scoreTempName); } /// @@ -148,10 +148,7 @@ private static IDataView UnaliasIfNeeded(IHostEnvironment env, IDataView input, Column = hiddenNames.Select(pair => new CopyColumnsTransform.Column() { Name = pair.Key, Source = pair.Value }).ToArray() }, input); - return new DropColumnsTransform(env, new DropColumnsTransform.Arguments() - { - Column = hiddenNames.Select(pair => pair.Value).ToArray() - }, input); + return SelectColumnsTransform.CreateDrop(env, input, hiddenNames.Select(pair => pair.Value).ToArray()); } private static IDataView LoadTransforms(IHostEnvironment env, IDataView input, string modelFile) diff --git a/src/Microsoft.ML.Transforms/Text/TextTransform.cs b/src/Microsoft.ML.Transforms/Text/TextTransform.cs index 2c77451906..e9e3416284 100644 --- a/src/Microsoft.ML.Transforms/Text/TextTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/TextTransform.cs @@ -14,6 +14,7 @@ using Microsoft.ML.Runtime.TextAnalytics; using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; +using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.Text; using System; using System.Collections.Generic; @@ -473,9 +474,7 @@ public ITransformer Fit(IDataView input) } } - view = new DropColumnsTransform(h, - new DropColumnsTransform.Arguments() { Column = tempCols.ToArray() }, view); - + view = SelectColumnsTransform.CreateDrop(h, view, tempCols.ToArray()); return new Transformer(_host, input, view); } diff --git a/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs index 4a9233dfa1..ba80d64ef0 100644 --- a/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs @@ -151,13 +151,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV view = NgramHashExtractorTransform.Create(h, featurizeArgs, view); // Since we added columns with new names, we need to explicitly drop them before we return the IDataTransform. - var dropColsArgs = - new DropColumnsTransform.Arguments() - { - Column = tmpColNames.ToArray() - }; - - return new DropColumnsTransform(h, dropColsArgs, view); + return SelectColumnsTransform.CreateDrop(h, view, tmpColNames.ToArray()); } } @@ -447,10 +441,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV }; view = new NgramHashTransform(h, ngramHashArgs, view); - return new DropColumnsTransform(h, new DropColumnsTransform.Arguments() - { - Column = tmpColNames.SelectMany(cols => cols).ToArray() - }, view); + return SelectColumnsTransform.CreateDrop(h, view, tmpColNames.SelectMany(cols => cols).ToArray()); } public static IDataTransform Create(NgramHashExtractorArguments extractorArgs, IHostEnvironment env, IDataView input, diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 3c174d996e..8e7acf16e3 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -17945,7 +17945,7 @@ "Required": false, "SortOrder": 3.0, "IsNullable": false, - "Default": true + "Default": false }, { "Name": "IgnoreMissing", @@ -17957,7 +17957,7 @@ "Required": false, "SortOrder": 4.0, "IsNullable": false, - "Default": true + "Default": false } ], "Outputs": [ diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeConcatUnknownLength-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeConcatUnknownLength-Schema.txt index 9ec6a3c4d9..20a6551b87 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeConcatUnknownLength-Schema.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeConcatUnknownLength-Schema.txt @@ -4,7 +4,7 @@ Single: I4 Text: Text Unknown: Vec ----- DelimitedTokenizeTransform ---- +---- RowToRowMapperTransform ---- 5 columns: Known: Vec Single: I4 @@ -82,6 +82,6 @@ Indicators: Vec Metadata 'IsNormalized': Bool: '1' All: Vec ----- ChooseColumnsTransform ---- +---- SelectColumnsDataTransform ---- 1 columns: All: Vec diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeConcatWithAliases-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeConcatWithAliases-Schema.txt index c4dc0025f4..20e25b8025 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeConcatWithAliases-Schema.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeConcatWithAliases-Schema.txt @@ -29,7 +29,7 @@ [0] 'A', [1] 'B.thickness', [2] 'B.uniform_size', [3] 'B.uniform_shape', [4] 'B.adhesion', [5] 'B.epit_size', [6] 'B.bare_nuclei', [7] 'B.bland_chromatin', [8] 'B.normal_nucleoli', [9] 'B.mitoses' [10] 'thickness', [11] 'uniform_size', [12] 'uniform_shape', [13] 'adhesion', [14] 'epit_size', [15] 'bare_nuclei', [16] 'bland_chromatin', [17] 'normal_nucleoli', [18] 'mitoses', [19] 'Vector.thickness' [20] 'Vector.uniform_size', [21] 'Vector.uniform_shape', [22] 'Vector.adhesion', [23] 'Vector.epit_size', [24] 'Vector.bare_nuclei', [25] 'Vector.bland_chromatin', [26] 'Vector.normal_nucleoli', [27] 'Vector.mitoses' ----- DropColumnsTransform ---- +---- SelectColumnsDataTransform ---- 2 columns: All: Vec Metadata 'SlotNames': Vec: Length=11, Count=11 diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeHash-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeHash-Schema.txt index e4547fc611..7b9f3a71dc 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeHash-Schema.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeHash-Schema.txt @@ -291,7 +291,7 @@ VarHash6: Vec> SingleHash: Key VarComb: Vec> ----- ChooseColumnsTransform ---- +---- SelectColumnsDataTransform ---- 15 columns: SingleHash: Key Hash0: Vec, 3> diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeKeyToVec-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeKeyToVec-Schema.txt index 4428207ee9..0feb5ad206 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeKeyToVec-Schema.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeKeyToVec-Schema.txt @@ -172,7 +172,7 @@ Metadata 'IsNormalized': Bool: '1' Metadata 'SlotNames': Vec: Length=7, Count=7 [0] 'Never-married', [1] 'Married-civ-spouse', [2] 'Widowed', [3] 'Divorced', [4] 'Separated', [5] 'Married-spouse-absent', [6] 'Married-AF-spouse' ----- ChooseColumnsTransform ---- +---- SelectColumnsDataTransform ---- 10 columns: MarKey: Key Metadata 'KeyValues': Vec: Length=7, Count=7 diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers-Schema.txt index 9516c7b1d7..6bfec04297 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers-Schema.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers-Schema.txt @@ -50,7 +50,7 @@ Metadata 'KeyValues': Vec: Length=7, Count=7 [0] 'Wirtschaft', [1] 'Gesundheit', [2] 'Deutschland', [3] 'Ausland', [4] 'Unterhaltung', [5] 'Sport', [6] 'Technik & Wissen' FileLabel: Key ----- ChooseColumnsTransform ---- +---- SelectColumnsDataTransform ---- 4 columns: RawLabel: Text AutoLabel: Key diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers1-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers1-Schema.txt index 9c13e8e7d8..2f805fd103 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers1-Schema.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers1-Schema.txt @@ -17,7 +17,7 @@ Metadata 'SlotNames': Vec: Length=2, Count=2 [0] 'weg fuer milliardenhilfe frei', [1] 'vor dem parlamentsgebaeude toben strassenkaempfe zwischen demonstranten drinnen haben die griechischen abgeordneten das drastische sparpaket am abend endgueltig beschlossen die entscheidung ist eine wichtige voraussetzung fuer die auszahlung von weiteren acht milliarden euro hilfsgeldern athen das griechische parlament hat einem umfassenden sparpaket endgueltig zugestimmt' FileLabel: Key ----- ChooseColumnsTransform ---- +---- SelectColumnsDataTransform ---- 2 columns: RawLabel: Text FileLabel: Key diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers2-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers2-Schema.txt index a6aa12c43a..7f097005cb 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers2-Schema.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers2-Schema.txt @@ -17,7 +17,7 @@ Metadata 'SlotNames': Vec: Length=2, Count=2 [0] 'weg fuer milliardenhilfe frei', [1] 'vor dem parlamentsgebaeude toben strassenkaempfe zwischen demonstranten drinnen haben die griechischen abgeordneten das drastische sparpaket am abend endgueltig beschlossen die entscheidung ist eine wichtige voraussetzung fuer die auszahlung von weiteren acht milliarden euro hilfsgeldern athen das griechische parlament hat einem umfassenden sparpaket endgueltig zugestimmt' FileLabel: Key ----- ChooseColumnsTransform ---- +---- SelectColumnsDataTransform ---- 2 columns: RawLabel: Text FileLabel: Key diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers3-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers3-Schema.txt index a897b22f5e..9c8630237b 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers3-Schema.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers3-Schema.txt @@ -17,7 +17,7 @@ Metadata 'SlotNames': Vec: Length=2, Count=2 [0] 'weg fuer milliardenhilfe frei', [1] 'vor dem parlamentsgebaeude toben strassenkaempfe zwischen demonstranten drinnen haben die griechischen abgeordneten das drastische sparpaket am abend endgueltig beschlossen die entscheidung ist eine wichtige voraussetzung fuer die auszahlung von weiteren acht milliarden euro hilfsgeldern athen das griechische parlament hat einem umfassenden sparpaket endgueltig zugestimmt' FileLabel: R4 ----- ChooseColumnsTransform ---- +---- SelectColumnsDataTransform ---- 2 columns: RawLabel: Text FileLabel: R4 diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers4-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers4-Schema.txt index 0bf6f2ca1f..e54b1a5652 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers4-Schema.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers4-Schema.txt @@ -28,7 +28,7 @@ [0] 'weg fuer milliardenhilfe frei', [1] 'vor dem parlamentsgebaeude toben strassenkaempfe zwischen demonstranten drinnen haben die griechischen abgeordneten das drastische sparpaket am abend endgueltig beschlossen die entscheidung ist eine wichtige voraussetzung fuer die auszahlung von weiteren acht milliarden euro hilfsgeldern athen das griechische parlament hat einem umfassenden sparpaket endgueltig zugestimmt' FileLabelNum: R4 FileLabelKey: Key ----- ChooseColumnsTransform ---- +---- SelectColumnsDataTransform ---- 3 columns: RawLabel: Text FileLabelNum: R4 diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers5-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers5-Schema.txt index 4140d7345a..76739e7fda 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers5-Schema.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeLabelParsers5-Schema.txt @@ -17,7 +17,7 @@ Metadata 'SlotNames': Vec: Length=2, Count=2 [0] 'weg fuer milliardenhilfe frei', [1] 'vor dem parlamentsgebaeude toben strassenkaempfe zwischen demonstranten drinnen haben die griechischen abgeordneten das drastische sparpaket am abend endgueltig beschlossen die entscheidung ist eine wichtige voraussetzung fuer die auszahlung von weiteren acht milliarden euro hilfsgeldern athen das griechische parlament hat einem umfassenden sparpaket endgueltig zugestimmt' FileLabel: Key ----- ChooseColumnsTransform ---- +---- SelectColumnsDataTransform ---- 2 columns: RawLabel: Text FileLabel: Key diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeTermDictionaryNgram-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeTermDictionaryNgram-Schema.txt index 3d2daa7558..83f4220b1c 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeTermDictionaryNgram-Schema.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeTermDictionaryNgram-Schema.txt @@ -89,7 +89,7 @@ [0] '*', [1] '*|*', [2] '*|fred', [3] 'fred', [4] 'fred|mcgriff', [5] 'mcgriff', [6] 'mcgriff|*', [7] '*|padres', [8] 'padres', [9] 'padres|*' [10] '*|free', [11] 'free', [12] 'free|agent', [13] 'agent', [14] '*|erythromycin', [15] 'erythromycin', [16] 'erythromycin|*', [17] '*|treating', [18] 'treating', [19] 'treating|pneumonia' [20] 'pneumonia' ----- ChooseColumnsTransform ---- +---- SelectColumnsDataTransform ---- 1 columns: Features: Vec Metadata 'SlotNames': Vec: Length=21, Count=21 diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeTermDictionaryNgramHash-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeTermDictionaryNgramHash-Schema.txt index e4e9f5cd8a..eeeb3948d5 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeTermDictionaryNgramHash-Schema.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeTermDictionaryNgramHash-Schema.txt @@ -93,6 +93,6 @@ T1: Text T2: Text Features: Vec ----- ChooseColumnsTransform ---- +---- SelectColumnsDataTransform ---- 1 columns: Features: Vec diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeTermDictionaryNgramTerms-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeTermDictionaryNgramTerms-Schema.txt index b906d350c4..7f049f63b6 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeTermDictionaryNgramTerms-Schema.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeTermDictionaryNgramTerms-Schema.txt @@ -82,7 +82,7 @@ Metadata 'SlotNames': Vec: Length=15, Count=15 [0] '*', [1] '*|*', [2] '*|mcgriff', [3] 'mcgriff', [4] 'mcgriff|*', [5] '*|padres', [6] 'padres', [7] 'padres|*', [8] '*|agent', [9] 'agent' [10] '*|erythromycin', [11] 'erythromycin', [12] 'erythromycin|*', [13] '*|pneumonia', [14] 'pneumonia' ----- ChooseColumnsTransform ---- +---- SelectColumnsDataTransform ---- 1 columns: Features: Vec Metadata 'SlotNames': Vec: Length=15, Count=15 diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 863d280fca..2d2ab5dae4 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -129,8 +129,8 @@ public void EntryPointFeatureCombiner() expected = Env.CreateTransform("KeyToVector{col=F1}", expected); expected = Env.CreateTransform("Concat{col=Features:F1,F2,Rest}", expected); - expected = Env.CreateTransform("ChooseColumns{col=Features}", expected); - result = Env.CreateTransform("ChooseColumns{col=Features}", result); + expected = Env.CreateTransform("SelectColumns{keepcol=Features hidden=-}", expected); + result = Env.CreateTransform("SelectColumns{keepcol=Features hidden=-}", result); CheckSameValues(result, expected); Done(); } @@ -436,24 +436,6 @@ public void EntryPointInputRangeChecks() Assert.True(EntryPointUtils.IsValueWithinRange(range, 0.0)); } - [Fact] - public void EntryPointInputArgsChecks() - { - var input = new DropColumnsTransform.KeepArguments(); - try - { - EntryPointUtils.CheckInputArgs(Env, input); - Assert.False(true); - } - catch - { - } - - input.Data = new EmptyDataView(Env, new Schema(new[] { new Schema.Column("ColA", NumberType.R4, null) })); - input.Column = new string[0]; - EntryPointUtils.CheckInputArgs(Env, input); - } - [Fact] public void EntryPointCreateEnsemble() { @@ -490,9 +472,7 @@ public void EntryPointCreateEnsemble() }, } }, individualScores[i]); - individualScores[i] = new DropColumnsTransform(Env, - new DropColumnsTransform.Arguments() { Column = new[] { MetadataUtils.Const.ScoreValueKind.Score } }, - individualScores[i]); + individualScores[i] = SelectColumnsTransform.CreateDrop(Env, individualScores[i], MetadataUtils.Const.ScoreValueKind.Score); } var avgEnsembleInput = new EnsembleCreator.ClassifierInput { Models = predictorModels, ModelCombiner = EnsembleCreator.ClassifierCombiner.Average }; diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index 5471ccf8e1..24fc01a5c4 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -422,7 +422,7 @@ public void NormalizerWithOnFit() // Just for fun, let's also write out some of the lines of the data to the console. using (var stream = new MemoryStream()) { - IDataView v = new ChooseColumnsTransform(env, tdata.AsDynamic, "r", "ncdf", "n", "b"); + IDataView v = SelectColumnsTransform.CreateKeep(env, tdata.AsDynamic, "r", "ncdf", "n", "b"); v = TakeFilter.Create(env, v, 10); var saver = new TextSaver(env, new TextSaver.Arguments() { diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index 3e1e653549..9edb7bd312 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -44,7 +44,7 @@ public void SavePipeLabelParsers() "xf=AutoLabel{col=AutoLabel:RawLabel}", "xf=Term{col=StringLabel:RawLabel terms={Wirtschaft,Gesundheit,Deutschland,Ausland,Unterhaltung,Sport,Technik & Wissen}}", string.Format("xf=TermLookup{{col=FileLabel:RawLabel data={{{0}}}}}", mappingPathData), - "xf=ChooseColumns{col=RawLabel col=AutoLabel col=StringLabel col=FileLabel}" + "xf=SelectColumns{keepcol=RawLabel keepcol=AutoLabel keepcol=StringLabel keepcol=FileLabel}" }); mappingPathData = DeleteOutputPath("SavePipe", "Mapping.txt"); @@ -64,7 +64,7 @@ public void SavePipeLabelParsers() new[] { "loader=Text{col=RawLabel:TXT:0 col=Names:TXT:1-2 col=Features:TXT:3-4 header+}", string.Format("xf=TermLookup{{col=FileLabel:RawLabel data={{{0}}}}}", mappingPathData), - "xf=ChooseColumns{col=RawLabel col=FileLabel}" + "xf=SelectColumns{keepcol=RawLabel keepcol=FileLabel}" }, suffix: "1"); mappingPathData = DeleteOutputPath("SavePipe", "Mapping.txt"); @@ -84,7 +84,7 @@ public void SavePipeLabelParsers() new[] { "loader=Text{col=RawLabel:TXT:0 col=Names:TXT:1-2 col=Features:TXT:3-4 header+}", string.Format("xf=TermLookup{{col=FileLabel:RawLabel data={{{0}}}}}", mappingPathData), - "xf=ChooseColumns{col=RawLabel col=FileLabel}" + "xf=SelectColumns{keepcol=RawLabel keepcol=FileLabel}" }, suffix: "2"); mappingPathData = DeleteOutputPath("SavePipe", "Mapping.txt"); @@ -104,7 +104,7 @@ public void SavePipeLabelParsers() new[] { "loader=Text{col=RawLabel:TXT:0 col=Names:TXT:1-2 col=Features:TXT:3-4 header+}", string.Format("xf=TermLookup{{key=- col=FileLabel:RawLabel data={{{0}}}}}", mappingPathData), - "xf=ChooseColumns{col=RawLabel col=FileLabel}" + "xf=SelectColumns{keepcol=RawLabel keepcol=FileLabel}" }, suffix: "3"); mappingPathData = DeleteOutputPath("SavePipe", "Mapping.txt"); @@ -130,7 +130,7 @@ public void SavePipeLabelParsers() "loader=Text{col=RawLabel:TXT:0 col=Names:TXT:1-2 col=Features:TXT:3-4 header+}", string.Format("xf=TermLookup{{key=- col=FileLabelNum:RawLabel data={{{0}}}}}", mappingPathData), string.Format("xf=TermLookup{{col=FileLabelKey:RawLabel data={{{0}}}}}", mappingPathData), - "xf=ChooseColumns{col=RawLabel col=FileLabelNum col=FileLabelKey}" + "xf=SelectColumns{keepcol=RawLabel keepcol=FileLabelNum keepcol=FileLabelKey}" }, suffix: "4"); writer.WriteLine(ProgressLogLine); Env.PrintProgress(); @@ -155,7 +155,7 @@ public void SavePipeLabelParsers() new[] { "loader=Text{col=RawLabel:TXT:0 col=Names:TXT:1-2 col=Features:TXT:3-4 header+}", string.Format("xf=TermLookup{{col=FileLabel:RawLabel data={{{0}}}}}", mappingPathData), - "xf=ChooseColumns{col=RawLabel col=FileLabel}" + "xf=SelectColumns{keepcol=RawLabel keepcol=FileLabel}" }, suffix: "5"); Done(); @@ -184,7 +184,7 @@ public void SavePipeWithHeader() Done(); } - [Fact(Skip = "Schema baseline comparison fails")] + [Fact] public void SavePipeKeyToVec() { string pathTerms = DeleteOutputPath("SavePipe", "Terms.txt"); @@ -205,7 +205,7 @@ public void SavePipeKeyToVec() "xf=Convert{col=MarKeyU8:U8:MarKey col=CombKeyU1:U1:CombKey}", "xf=KeyToVector{col={name=CombBagVec src=CombKey bag+} col={name=CombIndVec src=CombKey} col=MarVec:MarKey}", "xf=KeyToVector{col={name=CombBagVecU1 src=CombKeyU1 bag+} col={name=CombIndVecU1 src=CombKeyU1} col=MarVecU8:MarKeyU8}", - "xf=ChooseColumns{col=MarKey col=CombKey col=MarVec col=MarVecU8 col=CombBagVec col=CombBagVecU1 col=CombIndVec col=CombIndVecU1 col=Mar col=Comb}", + "xf=SelectColumns{keepcol=MarKey keepcol=CombKey keepcol=MarVec keepcol=MarVecU8 keepcol=CombBagVec keepcol=CombBagVecU1 keepcol=CombIndVec keepcol=CombIndVecU1 keepcol=Mar keepcol=Comb}", }, pipe => @@ -242,7 +242,7 @@ public void SavePipeKeyToVec() Done(); } - [Fact(Skip = "Schema baseline comparison fails")] + [Fact] public void SavePipeConcatUnknownLength() { string pathData = DeleteOutputPath("SavePipe", "ConcatUnknownLength.txt"); @@ -263,7 +263,7 @@ public void SavePipeConcatUnknownLength() "xf=Convert{col=Indicators type=R8}", "xf=Convert{col=Known col=Single col=Unknown type=R8}", "xf=Concat{col=All:Indicators,Known,Single,Unknown}", - "xf=ChooseColumns{col=All}" + "xf=SelectColumns{keepcol=All}" }); Done(); @@ -309,7 +309,7 @@ public void SavePipeNgramSparse() Done(); } - [Fact(Skip = "Schema baseline comparison fails")] + [Fact] public void SavePipeConcatWithAliases() { string pathData = GetDataPath("breast-cancer-withheader.txt"); @@ -319,7 +319,7 @@ public void SavePipeConcatWithAliases() "loader=Text{header+ col=A:0 col=B:1-9}", "xf=Concat{col={name=All source[First]=A src=A source[Rest]=B}}", "xf=Concat{col={name=All2 source=A source=B source[B]=B source[Vector]=B}}", - "xf=DropColumns{col=A col=B}" + "xf=SelectColumns{dropcol=A dropcol=B}" }); Done(); } @@ -351,7 +351,7 @@ public void SavePipeTermDictionary() new[] { "loader=Text{col=T1:TX:0 col=T2:TX:1}", string.Format(textSettings, dictFile), - "xf=ChooseColumns{col=Features}" + "xf=SelectColumns{keepcol=Features}" }, suffix: "Ngram"); textSettings = @@ -360,7 +360,7 @@ public void SavePipeTermDictionary() new[] { "loader=Text{col=T1:TX:0 col=T2:TX:1}", string.Format(textSettings, dictFile), - "xf=ChooseColumns{col=Features}" + "xf=SelectColumns{keepcol=Features}" }, suffix: "NgramHash"); @@ -371,7 +371,7 @@ public void SavePipeTermDictionary() new[] { "loader=Text{col=T1:TX:0 col=T2:TX:1}", string.Format(textSettings, terms), - "xf=ChooseColumns{col=Features}" + "xf=SelectColumns{keepcol=Features}" }, suffix: "NgramTerms"); terms = "sport,baseball,padres,med,erythromycin"; @@ -381,7 +381,7 @@ public void SavePipeTermDictionary() new[] { "loader=Text{col=T1:TX:0 col=T2:TX:1}", string.Format(textSettings, terms), - "xf=ChooseColumns{col=Features}" + "xf=SelectColumns{keepcol=Features}" }, suffix: "NgramHashTermsDropNA"); terms = "sport,baseball,mcgriff,med,erythromycin"; @@ -391,7 +391,7 @@ public void SavePipeTermDictionary() new[] { "loader=Text{col=T1:TX:0 col=T2:TX:1}", string.Format(textSettings, terms), - "xf=ChooseColumns{col=Features}" + "xf=SelectColumns{keepcol=Features}" }, suffix: "NgramTermsDropNA"); terms = "hello"; @@ -401,7 +401,7 @@ public void SavePipeTermDictionary() new[] { "loader=Text{col=T1:TX:0 col=T2:TX:1}", string.Format(textSettings, terms), - "xf=ChooseColumns{col=T1 col=T2 col=Features}" + "xf=SelectColumns{keepcol=T1 keepcol=T2 keepcol=Features}" }, suffix: "EmptyNgramTermsDropNA"); Done(); @@ -444,7 +444,7 @@ public void SavePipeHash() "xf=Hash{bits=7 ordered+ col={name=VarHash3 src=VarU1} col={name=VarHash4 src=VarU2} col={name=VarHash5 src=VarU4} col={name=VarHash6 src=VarU8}}", "xf=Hash{bits=4 col={name=SingleHash src=Single ordered+}}", "xf=Concat{col=VarComb:VarHash1,VarHash2,VarHash3,VarHash4,VarHash5,VarHash6}", - "xf=ChooseColumns{col=SingleHash col=Hash0 col=Hash1 col=Hash2 col=Hash3 col=Hash4 col=Hash5 col=Hash6 col=Hash7 col=Hash8 col=Hash9 col=Hash10 col=Hash11 col=Hash12 col=VarComb}", + "xf=SelectColumns{keepcol=SingleHash keepcol=Hash0 keepcol=Hash1 keepcol=Hash2 keepcol=Hash3 keepcol=Hash4 keepcol=Hash5 keepcol=Hash6 keepcol=Hash7 keepcol=Hash8 keepcol=Hash9 keepcol=Hash10 keepcol=Hash11 keepcol=Hash12 keepcol=VarComb}", }, logCurs: true); Done(); diff --git a/test/Microsoft.ML.Tests/SelectColumnsTransformsTests.cs b/test/Microsoft.ML.Tests/SelectColumnsTransformsTests.cs deleted file mode 100644 index 198c58e247..0000000000 --- a/test/Microsoft.ML.Tests/SelectColumnsTransformsTests.cs +++ /dev/null @@ -1,361 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Microsoft.ML.Runtime.Api; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.RunTests; -using Microsoft.ML.Runtime.Tools; -using Microsoft.ML.TestFramework; -using Microsoft.ML.Transforms; -using System; -using System.IO; -using Xunit; -using Xunit.Abstractions; - -namespace Microsoft.ML.Tests -{ - public class SelectColumnsTransformsTests : TestDataPipeBase - { - class TestClass - { - public int A; - public int B; - public int C; - } - - class TestClass2 - { - public int D; - public int E; - } - - public SelectColumnsTransformsTests(ITestOutputHelper output) : base(output) - { - } - - [Fact] - void TestSelect() - { - var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; - using (var env = new ConsoleEnvironment()) - { - var dataView = ComponentCreation.CreateDataView(env, data); - var est = new SelectColumnsEstimator(env, "A", "C"); - var transformer = est.Fit(dataView); - var result = transformer.Transform(dataView); - var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx); - var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx); - var foundColumnC = result.Schema.TryGetColumnIndex("C", out int cIdx); - - Assert.True(foundColumnA); - Assert.Equal(0, aIdx); - Assert.False(foundColumnB); - Assert.Equal(0, bIdx); - Assert.True(foundColumnC); - Assert.Equal(1, cIdx); - } - } - - [Fact] - void TestSelectWorkout() - { - var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; - var invalidData = new [] { new TestClass2 { D = 3, E = 5} }; - using (var env = new ConsoleEnvironment()) - { - var dataView = ComponentCreation.CreateDataView(env, data); - var invalidDataView = ComponentCreation.CreateDataView(env, invalidData); - - // Workout on keep columns - var est = new SelectColumnsEstimator(env, new[] {"A", "B"}, null, true, false); - TestEstimatorCore(est, validFitInput: dataView, invalidInput: invalidDataView); - - // Workout on drop columns - est = new SelectColumnsEstimator(env, null, new[] {"A", "B"}, true, false); - TestEstimatorCore(est, validFitInput: dataView, invalidInput: invalidDataView); - - // Workout on keep columns with ignore mismatch -- using invalid data set - est = new SelectColumnsEstimator(env, null, new[] {"A", "B"}, true); - TestEstimatorCore(est, validFitInput: invalidDataView); - } - } - - [Fact] - void TestSelectColumnsNoMatch() - { - var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; - using (var env = new ConsoleEnvironment()) - { - var dataView = ComponentCreation.CreateDataView(env, data); - var est = new SelectColumnsEstimator(env, new[] {"D", "G"}); - var transformer = est.Fit(dataView); - var result = transformer.Transform(dataView); - var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx); - var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx); - var foundColumnC = result.Schema.TryGetColumnIndex("C", out int cIdx); - - Assert.False(foundColumnA); - Assert.Equal(0, aIdx); - Assert.False(foundColumnB); - Assert.Equal(0, bIdx); - Assert.False(foundColumnC); - Assert.Equal(0, cIdx); - } - } - - [Fact] - void TestSelectColumnsWithSameName() - { - var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; - using (var env = new ConsoleEnvironment()) - { - var dataView = ComponentCreation.CreateDataView(env, data); - var est = new CopyColumnsEstimator(env, new[] {("A", "A"), ("B", "B")}); - var chain = est.Append(new SelectColumnsEstimator(env, new[]{"A", "C" })); - var transformer = chain.Fit(dataView); - var result = transformer.Transform(dataView); - - // Copied columns should equal AABBC, however we chose to keep A and C - // so the result is AAC - Assert.Equal(3, result.Schema.ColumnCount); - var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx); - var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx); - var foundColumnC = result.Schema.TryGetColumnIndex("C", out int cIdx); - Assert.True(foundColumnA); - Assert.Equal(1, aIdx); - Assert.False(foundColumnB); - Assert.Equal(0, bIdx); - Assert.True(foundColumnC); - Assert.Equal(2, cIdx); - } - } - - [Fact] - void TestSelectColumnsWithNoKeepHidden() - { - var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; - using (var env = new ConsoleEnvironment()) - { - var dataView = ComponentCreation.CreateDataView(env, data); - var est = new CopyColumnsEstimator(env, new[] {("A", "A"), ("B", "B")}); - var chain = est.Append(new SelectColumnsEstimator(env, new[] {"A", "B" }, null, false)); - var transformer = chain.Fit(dataView); - var result = transformer.Transform(dataView); - - // Copied columns should equal AABBC, however we chose to keep A and B - // and not keeping the columns hidden, so the result should be AB - Assert.Equal(2, result.Schema.ColumnCount); - var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx); - var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx); - var foundColumnC = result.Schema.TryGetColumnIndex("C", out int cIdx); - Assert.True(foundColumnA); - Assert.Equal(0, aIdx); - Assert.True(foundColumnB); - Assert.Equal(1, bIdx); - Assert.False(foundColumnC); - Assert.Equal(0, cIdx); - } - } - - [Fact] - void TestSelectWithKeepAndDropSet() - { - using (var env = new ConsoleEnvironment()) - { - var test = new string[]{ "D", "G"}; - Assert.Throws(() => new SelectColumnsEstimator(env, test, test)); - } - } - - [Fact] - void TestSelectNoKeepAndDropSet() - { - using (var env = new ConsoleEnvironment()) - { - Assert.Throws(() => new SelectColumnsEstimator(env, null, null)); - } - } - - [Fact] - void TestSelectSavingAndLoading() - { - var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; - using (var env = new ConsoleEnvironment()) - { - var dataView = ComponentCreation.CreateDataView(env, data); - var est = new SelectColumnsEstimator(env, new[] { "A", "B" }); - var transformer = est.Fit(dataView); - using (var ms = new MemoryStream()) - { - transformer.SaveTo(env, ms); - ms.Position = 0; - var loadedTransformer = TransformerChain.LoadFrom(env, ms); - var result = loadedTransformer.Transform(dataView); - Assert.Equal(2, result.Schema.ColumnCount); - Assert.Equal("A", result.Schema.GetColumnName(0)); - Assert.Equal("B", result.Schema.GetColumnName(1)); - } - } - } - - [Fact] - void TestSelectSavingAndLoadingWithNoKeepHidden() - { - var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; - using (var env = new ConsoleEnvironment()) - { - var dataView = ComponentCreation.CreateDataView(env, data); - var est = new CopyColumnsEstimator(env, new[] {("A", "A"), ("B", "B")}).Append( - new SelectColumnsEstimator(env, new[] { "A", "B" }, null, false)); - var transformer = est.Fit(dataView); - using (var ms = new MemoryStream()) - { - transformer.SaveTo(env, ms); - ms.Position = 0; - var loadedTransformer = TransformerChain.LoadFrom(env, ms); - var result = loadedTransformer.Transform(dataView); - Assert.Equal(2, result.Schema.ColumnCount); - Assert.Equal("A", result.Schema.GetColumnName(0)); - Assert.Equal("B", result.Schema.GetColumnName(1)); - } - } - } - - [Fact] - void TestSelectBackCompatDropColumns() - { - var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; - using (var env = new ConsoleEnvironment()) - { - var dataView = ComponentCreation.CreateDataView(env, data); - string dropModelPath = GetDataPath("backcompat/drop-model.zip"); - using (FileStream fs = File.OpenRead(dropModelPath)) - { - var result = ModelFileUtils.LoadTransforms(env, dataView, fs); - var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx); - var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx); - var foundColumnC = result.Schema.TryGetColumnIndex("C", out int cIdx); - Assert.False(foundColumnA); - Assert.Equal(0, aIdx); - Assert.True(foundColumnB); - Assert.Equal(0, bIdx); - Assert.True(foundColumnC); - Assert.Equal(1, cIdx); - } - } - } - - [Fact] - void TestSelectBackCompatKeepColumns() - { - var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; - using (var env = new ConsoleEnvironment()) - { - var dataView = ComponentCreation.CreateDataView(env, data); - string dropModelPath = GetDataPath("backcompat/keep-model.zip"); - using (FileStream fs = File.OpenRead(dropModelPath)) - { - var result = ModelFileUtils.LoadTransforms(env, dataView, fs); - var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx); - var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx); - var foundColumnC = result.Schema.TryGetColumnIndex("C", out int cIdx); - Assert.True(foundColumnA); - Assert.Equal(0, aIdx); - Assert.True(foundColumnB); - Assert.Equal(1, bIdx); - Assert.False(foundColumnC); - Assert.Equal(0, cIdx); - } - } - } - - [Fact] - void TestSelectBackCompatChooseColumns() - { - var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; - using (var env = new ConsoleEnvironment()) - { - var dataView = ComponentCreation.CreateDataView(env, data); - string dropModelPath = GetDataPath("backcompat/choose-model.zip"); - using (FileStream fs = File.OpenRead(dropModelPath)) - { - var result = ModelFileUtils.LoadTransforms(env, dataView, fs); - var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx); - var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx); - var foundColumnC = result.Schema.TryGetColumnIndex("C", out int cIdx); - Assert.True(foundColumnA); - Assert.Equal(0, aIdx); - Assert.True(foundColumnB); - Assert.Equal(1, bIdx); - Assert.False(foundColumnC); - Assert.Equal(0, cIdx); - } - } - } - - [Fact] - void TestSelectBackCompatChooseColumnsWithKeep() - { - // Model generated with: xf=copy{col=A:A col=B:B} xf=choose{col=Label col=Features col=A col=B hidden=keep} - // Output expected is AABB - var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; - using (var env = new ConsoleEnvironment()) - { - var dataView = ComponentCreation.CreateDataView(env, data); - string chooseModelPath = GetDataPath("backcompat/choose-keep-model.zip"); - using (FileStream fs = File.OpenRead(chooseModelPath)) - { - var result = ModelFileUtils.LoadTransforms(env, dataView, fs); - Assert.Equal(4, result.Schema.ColumnCount); - var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx); - var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx); - var foundColumnC = result.Schema.TryGetColumnIndex("C", out int cIdx); - Assert.True(foundColumnA); - Assert.Equal(1, aIdx); - Assert.True(foundColumnB); - Assert.Equal(3, bIdx); - Assert.False(foundColumnC); - Assert.Equal(0, cIdx); - } - } - } - - [Fact] - void TestCommandLineWithKeep() - { - using (var env = new ConsoleEnvironment()) - { - Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0 col=B:R4:1 col=C:R4:2} xf=select{keepcol=A keepcol=B} in=f:\1.txt" }), (int)0); - } - } - - [Fact] - void TestCommandLineWithDrop() - { - using (var env = new ConsoleEnvironment()) - { - Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0 col=B:R4:1 col=C:R4:2} xf=select{dropcol=A dropcol=B} in=f:\1.txt" }), (int)0); - } - } - - [Fact] - void TestCommandLineKeepWithoutHidden() - { - using (var env = new ConsoleEnvironment()) - { - Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0 col=B:R4:1 col=C:R4:2} xf=select{keepcol=A keepcol=B hidden=-} in=f:\1.txt" }), (int)0); - } - } - - [Fact] - void TestCommandLineKeepWithIgnoreMismatch() - { - using (var env = new ConsoleEnvironment()) - { - Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0 col=B:R4:1 col=C:R4:2} xf=select{keepcol=A keepcol=B ignore=-} in=f:\1.txt" }), (int)0); - } - } - } -} diff --git a/test/Microsoft.ML.Tests/Transformers/CategoricalHashTests.cs b/test/Microsoft.ML.Tests/Transformers/CategoricalHashTests.cs index 6bed53f855..94efae4a3b 100644 --- a/test/Microsoft.ML.Tests/Transformers/CategoricalHashTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/CategoricalHashTests.cs @@ -87,9 +87,9 @@ public void CategoricalHashStatic() { var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); var savedData = TakeFilter.Create(Env, est.Fit(data).Transform(data).AsDynamic, 4); - savedData = new ChooseColumnsTransform(Env, savedData, "A", "B", "C", "D", "E"); + var view = SelectColumnsTransform.CreateKeep(Env, savedData, false, "A", "B", "C", "D", "E"); using (var fs = File.Create(outputPath)) - DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); + DataSaverUtils.SaveDataView(ch, saver, view, fs, keepHidden: true); } CheckEquality("CategoricalHash", "featurized.tsv"); diff --git a/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs b/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs index 6604de22ec..c1bd93f17f 100644 --- a/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs @@ -9,7 +9,6 @@ using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.Runtime.Tools; using Microsoft.ML.Transforms; -using Microsoft.ML.Transforms.Categorical; using System; using System.IO; using System.Linq; @@ -91,9 +90,9 @@ public void CategoricalStatic() { var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); var savedData = TakeFilter.Create(Env, est.Fit(data).Transform(data).AsDynamic, 4); - savedData = new ChooseColumnsTransform(Env, savedData, "A", "B", "C", "D", "E"); + var view = new SelectColumnsTransform(Env, new string[]{"A", "B", "C", "D", "E" }, null, false).Transform(savedData); using (var fs = File.Create(outputPath)) - DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); + DataSaverUtils.SaveDataView(ch, saver, view, fs, keepHidden: true); } CheckEquality("Categorical", "featurized.tsv"); diff --git a/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs b/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs index ac7b9dfde0..9133bdbcf9 100644 --- a/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs @@ -63,7 +63,7 @@ ColumnType GetType(Schema schema, string name) t = GetType(data.Schema, "f4"); Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 0); - data = new ChooseColumnsTransform(Env, data, "f1", "f2", "f3", "f4"); + data = SelectColumnsTransform.CreateKeep(Env, data, "f1", "f2", "f3", "f4"); var subdir = Path.Combine("Transform", "Concat"); var outputPath = GetOutputPath(subdir, "Concat1.tsv"); @@ -115,7 +115,7 @@ ColumnType GetType(Schema schema, string name) t = GetType(data.Schema, "f3"); Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 5); - data = new ChooseColumnsTransform(Env, data, "f2", "f3"); + data = SelectColumnsTransform.CreateKeep(Env, data, "f2", "f3"); var subdir = Path.Combine("Transform", "Concat"); var outputPath = GetOutputPath(subdir, "Concat2.tsv"); diff --git a/test/Microsoft.ML.Tests/CopyColumnEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs similarity index 100% rename from test/Microsoft.ML.Tests/CopyColumnEstimatorTests.cs rename to test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs diff --git a/test/Microsoft.ML.Tests/Transformers/FeatureSelectionTests.cs b/test/Microsoft.ML.Tests/Transformers/FeatureSelectionTests.cs index 6db0fdd329..ccbb809982 100644 --- a/test/Microsoft.ML.Tests/Transformers/FeatureSelectionTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/FeatureSelectionTests.cs @@ -42,8 +42,8 @@ public void FeatureSelectionWorkout() using (var ch = Env.Start("save")) { var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); - IDataView savedData = TakeFilter.Create(Env, est.Fit(data.AsDynamic).Transform(data.AsDynamic), 4); - savedData = new ChooseColumnsTransform(Env, savedData, "bag_of_words_count", "bag_of_words_mi"); + IDataView savedData = TakeFilter.Create(Env, est.Fit(data.AsDynamic).Transform(data.AsDynamic), 4); + savedData = SelectColumnsTransform.CreateKeep(Env, savedData, "bag_of_words_count", "bag_of_words_mi"); using (var fs = File.Create(outputPath)) DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToValueTests.cs b/test/Microsoft.ML.Tests/Transformers/KeyToValueTests.cs index c55912a9ed..50bbe0c2c9 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToValueTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToValueTests.cs @@ -97,8 +97,9 @@ public void KeyToValuePigsty() TestEstimatorCore(est.AsDynamic, data2.AsDynamic, invalidInput: data.AsDynamic); // Check that term and ToValue are round-trippable. - var dataLeft = new ChooseColumnsTransform(Env, data.AsDynamic, "ScalarString", "VectorString"); - var dataRight = new ChooseColumnsTransform(Env, est.Fit(data2).Transform(data2).AsDynamic, "ScalarString", "VectorString"); + var dataLeft = SelectColumnsTransform.CreateKeep(Env, data.AsDynamic, "ScalarString", "VectorString"); + var dataRight = SelectColumnsTransform.CreateKeep(Env, est.Fit(data2).Transform(data2).AsDynamic, "ScalarString", "VectorString"); + CheckSameSchemas(dataLeft.Schema, dataRight.Schema); CheckSameValues(dataLeft, dataRight); Done(); diff --git a/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs b/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs index 501cf6f014..200ee73e06 100644 --- a/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs @@ -81,9 +81,9 @@ public void NAReplaceStatic() { var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); var savedData = TakeFilter.Create(Env, est.Fit(data).Transform(data).AsDynamic, 4); - savedData = new ChooseColumnsTransform(Env, savedData, "A", "B", "C", "D"); + var view = SelectColumnsTransform.CreateKeep(Env, savedData, "A", "B", "C", "D"); using (var fs = File.Create(outputPath)) - DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); + DataSaverUtils.SaveDataView(ch, saver, view, fs, keepHidden: true); } CheckEquality("NAReplace", "featurized.tsv"); diff --git a/test/Microsoft.ML.Tests/Transformers/NormalizerTests.cs b/test/Microsoft.ML.Tests/Transformers/NormalizerTests.cs index 5b86db03ab..56fbb2d976 100644 --- a/test/Microsoft.ML.Tests/Transformers/NormalizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/NormalizerTests.cs @@ -70,7 +70,10 @@ public void NormalizerWorkout() { var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); using (var fs = File.Create(outputPath)) - DataSaverUtils.SaveDataView(ch, saver, new DropColumnsTransform(Env, est.Fit(data).Transform(data), "float0"), fs, keepHidden: true); + { + var dataView = SelectColumnsTransform.CreateDrop(Env, est.Fit(data).Transform(data), true, "float0"); + DataSaverUtils.SaveDataView(ch, saver, dataView, fs, keepHidden: true); + } } CheckEquality("NormalizerEstimator", "normalized.tsv"); @@ -141,7 +144,7 @@ public void LpGcNormAndWhiteningWorkout() { var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true, OutputHeader = false }); IDataView savedData = TakeFilter.Create(Env, est.Fit(data.AsDynamic).Transform(data.AsDynamic), 4); - savedData = new ChooseColumnsTransform(Env, savedData, "lpnorm", "gcnorm", "whitened"); + savedData = SelectColumnsTransform.CreateKeep(Env, savedData, "lpnorm", "gcnorm", "whitened"); using (var fs = File.Create(outputPath)) DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); diff --git a/test/Microsoft.ML.Tests/Transformers/PcaTests.cs b/test/Microsoft.ML.Tests/Transformers/PcaTests.cs index 57593e543d..291b635662 100644 --- a/test/Microsoft.ML.Tests/Transformers/PcaTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/PcaTests.cs @@ -62,7 +62,7 @@ public void TestPcaEstimator() using (var ch = _env.Start("save")) { IDataView savedData = TakeFilter.Create(_env, est.Fit(data.AsDynamic).Transform(data.AsDynamic), 4); - savedData = new ChooseColumnsTransform(_env, savedData, "pca"); + savedData = SelectColumnsTransform.CreateKeep(_env, savedData, "pca"); using (var fs = File.Create(outputPath)) DataSaverUtils.SaveDataView(ch, _saver, savedData, fs, keepHidden: true); diff --git a/test/Microsoft.ML.Tests/Transformers/SelectColumnsTests.cs b/test/Microsoft.ML.Tests/Transformers/SelectColumnsTests.cs new file mode 100644 index 0000000000..d8900f5b6c --- /dev/null +++ b/test/Microsoft.ML.Tests/Transformers/SelectColumnsTests.cs @@ -0,0 +1,404 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.RunTests; +using Microsoft.ML.Runtime.Tools; +using Microsoft.ML.Transforms; +using System; +using System.IO; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Tests.Transformers +{ + public class SelectColumnsTransformsTests : TestDataPipeBase + { + class TestClass + { + public int A; + public int B; + public int C; + } + + class TestClass2 + { + public int D; + public int E; + } + class TestClass3 + { + public string Label; + public string Features; + public int A; + public int B; + public int C; + }; + + public SelectColumnsTransformsTests(ITestOutputHelper output) : base(output) + { + } + + [Fact] + void TestSelectKeep() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + var est = new SelectColumnsEstimator(Env, "A", "C"); + var transformer = est.Fit(dataView); + var result = transformer.Transform(dataView); + var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx); + var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx); + var foundColumnC = result.Schema.TryGetColumnIndex("C", out int cIdx); + + Assert.True(foundColumnA); + Assert.Equal(0, aIdx); + Assert.False(foundColumnB); + Assert.Equal(0, bIdx); + Assert.True(foundColumnC); + Assert.Equal(1, cIdx); + } + + [Fact] + void TestSelectKeepWithOrder() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + + // Expected output will be CA + var est = new SelectColumnsEstimator(Env, "C", "A"); + var transformer = est.Fit(dataView); + var result = transformer.Transform(dataView); + var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx); + var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx); + var foundColumnC = result.Schema.TryGetColumnIndex("C", out int cIdx); + + Assert.True(foundColumnA); + Assert.Equal(1, aIdx); + Assert.False(foundColumnB); + Assert.Equal(0, bIdx); + Assert.True(foundColumnC); + Assert.Equal(0, cIdx); + } + + [Fact] + void TestSelectDrop() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + var est = new SelectColumnsEstimator(Env, null, new string[] { "A", "C" }); + var transformer = est.Fit(dataView); + var result = transformer.Transform(dataView); + var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx); + var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx); + var foundColumnC = result.Schema.TryGetColumnIndex("C", out int cIdx); + + Assert.False(foundColumnA); + Assert.Equal(0, aIdx); + Assert.True(foundColumnB); + Assert.Equal(0, bIdx); + Assert.False(foundColumnC); + Assert.Equal(0, cIdx); + } + + [Fact] + void TestSelectWorkout() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + var invalidData = new [] { new TestClass2 { D = 3, E = 5} }; + var dataView = ComponentCreation.CreateDataView(Env, data); + var invalidDataView = ComponentCreation.CreateDataView(Env, invalidData); + + // Workout on keep columns + var est = new SelectColumnsEstimator(Env, new[] {"A", "B"}, null, true, false); + TestEstimatorCore(est, validFitInput: dataView, invalidInput: invalidDataView); + + // Workout on drop columns + est = new SelectColumnsEstimator(Env, null, new[] {"A", "B"}, true, false); + TestEstimatorCore(est, validFitInput: dataView, invalidInput: invalidDataView); + + // Workout on keep columns with ignore mismatch -- using invalid data set + est = new SelectColumnsEstimator(Env, new[] {"A", "B"}, null, true, true); + TestEstimatorCore(est, validFitInput: invalidDataView); + } + + [Fact] + void TestSelectColumnsWithMissing() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + var est = new SelectColumnsEstimator(Env, new[] {"D", "G"}); + Assert.Throws(() => est.Fit(dataView)); + } + + [Fact] + void TestSelectColumnsWithSameName() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + var est = new CopyColumnsEstimator(Env, new[] {("A", "A"), ("B", "B")}); + var chain = est.Append(new SelectColumnsEstimator(Env, new[]{"C", "A" })); + var transformer = chain.Fit(dataView); + var result = transformer.Transform(dataView); + + // Copied columns should equal AABBC, however we chose to keep A and C + // so the result is AC + Assert.Equal(2, result.Schema.ColumnCount); + var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx); + var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx); + var foundColumnC = result.Schema.TryGetColumnIndex("C", out int cIdx); + Assert.True(foundColumnA); + Assert.Equal(1, aIdx); + Assert.False(foundColumnB); + Assert.Equal(0, bIdx); + Assert.True(foundColumnC); + Assert.Equal(0, cIdx); + } + + [Fact] + void TestSelectColumnsWithKeepHidden() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + var est = new CopyColumnsEstimator(Env, new[] {("A", "A"), ("B", "B")}); + var chain = est.Append(new SelectColumnsEstimator(Env, new[] {"B", "A" }, null, true)); + var transformer = chain.Fit(dataView); + var result = transformer.Transform(dataView); + + // Input for SelectColumns should be AABBC, we chose to keep A and B + // and keep hidden columns is true, therefore the output should be AABB + Assert.Equal(4, result.Schema.ColumnCount); + var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx); + var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx); + var foundColumnC = result.Schema.TryGetColumnIndex("C", out int cIdx); + Assert.True(foundColumnA); + Assert.Equal(3, aIdx); + Assert.True(foundColumnB); + Assert.Equal(1, bIdx); + Assert.False(foundColumnC); + Assert.Equal(0, cIdx); + } + + [Fact] + void TestSelectColumnsDropWithKeepHidden() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + var est = new CopyColumnsEstimator(Env, new[] {("A", "A"), ("B", "B")}); + var chain = est.Append(new SelectColumnsEstimator(Env, null, new[] { "A" }, true)); + var transformer = chain.Fit(dataView); + var result = transformer.Transform(dataView); + + // Input for SelectColumns should be AABBC, we chose to drop A + // and keep hidden columns is true, therefore the output should be BBC + Assert.Equal(3, result.Schema.ColumnCount); + var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx); + var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx); + var foundColumnC = result.Schema.TryGetColumnIndex("C", out int cIdx); + Assert.False(foundColumnA); + Assert.Equal(0, aIdx); + Assert.True(foundColumnB); + Assert.Equal(1, bIdx); + Assert.True(foundColumnC); + Assert.Equal(2, cIdx); + } + + [Fact] + void TestSelectWithKeepAndDropSet() + { + // Setting both keep and drop is not allowed. + var test = new string[]{ "D", "G"}; + Assert.Throws(() => new SelectColumnsEstimator(Env, test, test)); + } + + [Fact] + void TestSelectNoKeepAndDropSet() + { + // Passing null to both keep and drop is not allowed. + Assert.Throws(() => new SelectColumnsEstimator(Env, null, null)); + } + + [Fact] + void TestSelectSavingAndLoading() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + var est = new SelectColumnsEstimator(Env, new[] { "A", "B" }); + var transformer = est.Fit(dataView); + using (var ms = new MemoryStream()) + { + transformer.SaveTo(Env, ms); + ms.Position = 0; + var loadedTransformer = TransformerChain.LoadFrom(Env, ms); + var result = loadedTransformer.Transform(dataView); + Assert.Equal(2, result.Schema.ColumnCount); + Assert.Equal("A", result.Schema.GetColumnName(0)); + Assert.Equal("B", result.Schema.GetColumnName(1)); + } + } + + [Fact] + void TestSelectSavingAndLoadingWithNoKeepHidden() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + var est = new CopyColumnsEstimator(Env, new[] {("A", "A"), ("B", "B")}).Append( + new SelectColumnsEstimator(Env, new[] { "A", "B" }, null, false)); + var transformer = est.Fit(dataView); + using (var ms = new MemoryStream()) + { + transformer.SaveTo(Env, ms); + ms.Position = 0; + var loadedTransformer = TransformerChain.LoadFrom(Env, ms); + var result = loadedTransformer.Transform(dataView); + Assert.Equal(2, result.Schema.ColumnCount); + Assert.Equal("A", result.Schema.GetColumnName(0)); + Assert.Equal("B", result.Schema.GetColumnName(1)); + } + } + + [Fact] + void TestSelectBackCompatDropColumns() + { + // Model generated with: xf=drop{col=A} + // Expected output: Features Label B C + var data = new[] { new TestClass3() { Label="foo", Features="bar", A = 1, B = 2, C = 3, } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + string dropModelPath = GetDataPath("backcompat/drop-model.zip"); + using (FileStream fs = File.OpenRead(dropModelPath)) + { + var result = ModelFileUtils.LoadTransforms(Env, dataView, fs); + var foundColumnFeature = result.Schema.TryGetColumnIndex("Features", out int featureIdx); + var foundColumnLabel = result.Schema.TryGetColumnIndex("Label", out int labelIdx); + var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx); + var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx); + var foundColumnC = result.Schema.TryGetColumnIndex("C", out int cIdx); + Assert.True(foundColumnLabel); + Assert.Equal(0, labelIdx); + Assert.True(foundColumnFeature); + Assert.Equal(1, featureIdx); + Assert.False(foundColumnA); + Assert.Equal(0, aIdx); + Assert.True(foundColumnB); + Assert.Equal(2, bIdx); + Assert.True(foundColumnC); + Assert.Equal(3, cIdx); + } + } + + [Fact] + void TestSelectBackCompatKeepColumns() + { + // Model generated with: xf=keep{col=Label col=Features col=A col=B} + // Expected output: Label Features A B + var data = new[] { new TestClass3() { Label="foo", Features="bar", A = 1, B = 2, C = 3, } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + string dropModelPath = GetDataPath("backcompat/keep-model.zip"); + using (FileStream fs = File.OpenRead(dropModelPath)) + { + var result = ModelFileUtils.LoadTransforms(Env, dataView, fs); + var foundColumnFeature = result.Schema.TryGetColumnIndex("Features", out int featureIdx); + var foundColumnLabel = result.Schema.TryGetColumnIndex("Label", out int labelIdx); + var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx); + var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx); + var foundColumnC = result.Schema.TryGetColumnIndex("C", out int cIdx); + Assert.True(foundColumnLabel); + Assert.Equal(0, labelIdx); + Assert.True(foundColumnFeature); + Assert.Equal(1, featureIdx); + Assert.True(foundColumnA); + Assert.Equal(2, aIdx); + Assert.True(foundColumnB); + Assert.Equal(3, bIdx); + Assert.False(foundColumnC); + Assert.Equal(0, cIdx); + } + } + + [Fact] + void TestSelectBackCompatChooseColumns() + { + // Model generated with: xf=choose{col=Label col=Features col=A col=B} + // Output expected is Label Features A B + var data = new[] { new TestClass3() { Label="foo", Features="bar", A = 1, B = 2, C = 3, } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + string dropModelPath = GetDataPath("backcompat/choose-model.zip"); + using (FileStream fs = File.OpenRead(dropModelPath)) + { + var result = ModelFileUtils.LoadTransforms(Env, dataView, fs); + var foundColumnFeature = result.Schema.TryGetColumnIndex("Features", out int featureIdx); + var foundColumnLabel = result.Schema.TryGetColumnIndex("Label", out int labelIdx); + var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx); + var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx); + var foundColumnC = result.Schema.TryGetColumnIndex("C", out int cIdx); + Assert.True(foundColumnLabel); + Assert.Equal(0, labelIdx); + Assert.True(foundColumnFeature); + Assert.Equal(1, featureIdx); + Assert.True(foundColumnA); + Assert.Equal(2, aIdx); + Assert.True(foundColumnB); + Assert.Equal(3, bIdx); + Assert.False(foundColumnC); + Assert.Equal(0, cIdx); + } + } + + [Fact] + void TestSelectBackCompatChooseColumnsWithKeep() + { + // Model generated with: xf=copy{col=A:A col=B:B} xf=choose{col=Label col=Features col=A col=B hidden=keep} + // Output expected is Label Features A A B B + var data = new[] { new TestClass3() { Label="foo", Features="bar", A = 1, B = 2, C = 3, } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + string chooseModelPath = GetDataPath("backcompat/choose-keep-model.zip"); + using (FileStream fs = File.OpenRead(chooseModelPath)) + { + var result = ModelFileUtils.LoadTransforms(Env, dataView, fs); + Assert.Equal(6, result.Schema.ColumnCount); + var foundColumnFeature = result.Schema.TryGetColumnIndex("Features", out int featureIdx); + var foundColumnLabel = result.Schema.TryGetColumnIndex("Label", out int labelIdx); + var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx); + var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx); + var foundColumnC = result.Schema.TryGetColumnIndex("C", out int cIdx); + Assert.True(foundColumnLabel); + Assert.Equal(0, labelIdx); + Assert.True(foundColumnFeature); + Assert.Equal(1, featureIdx); + Assert.True(foundColumnA); + Assert.Equal(3, aIdx); + Assert.True(foundColumnB); + Assert.Equal(5, bIdx); + Assert.False(foundColumnC); + Assert.Equal(0, cIdx); + } + } + + [Fact] + void TestCommandLineWithKeep() + { + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0 col=B:R4:1 col=C:R4:2} xf=select{keepcol=A keepcol=B} in=f:\1.txt" }), (int)0); + } + + [Fact] + void TestCommandLineWithDrop() + { + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0 col=B:R4:1 col=C:R4:2} xf=select{dropcol=A dropcol=B} in=f:\1.txt" }), (int)0); + } + + [Fact] + void TestCommandLineKeepWithoutHidden() + { + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0 col=B:R4:1 col=C:R4:2} xf=select{keepcol=A keepcol=B hidden=-} in=f:\1.txt" }), (int)0); + } + + [Fact] + void TestCommandLineKeepWithIgnoreMismatch() + { + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0 col=B:R4:1 col=C:R4:2} xf=select{keepcol=A keepcol=B ignore=-} in=f:\1.txt" }), (int)0); + } + } +} diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs index 2abb38fcd1..1196cff7c1 100644 --- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs @@ -46,7 +46,7 @@ public void TextFeaturizerWorkout() { var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); IDataView savedData = TakeFilter.Create(Env, feat.Fit(data).Transform(data).AsDynamic, 4); - savedData = new ChooseColumnsTransform(Env, savedData, "Data", "Data_TransformedText"); + savedData = SelectColumnsTransform.CreateKeep(Env, savedData, "Data", "Data_TransformedText"); using (var fs = File.Create(outputPath)) DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); @@ -80,7 +80,7 @@ public void TextTokenizationWorkout() { var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); IDataView savedData = TakeFilter.Create(Env, est.Fit(data.AsDynamic).Transform(data.AsDynamic), 4); - savedData = new ChooseColumnsTransform(Env, savedData, "text", "words", "chars"); + savedData = SelectColumnsTransform.CreateKeep(Env, savedData, "text", "words", "chars"); using (var fs = File.Create(outputPath)) DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); @@ -101,7 +101,7 @@ public void TokenizeWithSeparators() var est = new WordTokenizingEstimator(Env, "text", "words", separators: new[] { ' ', '?', '!', '.', ','}); var outdata = TakeFilter.Create(Env, est.Fit(data).Transform(data), 4); - var savedData = new ChooseColumnsTransform(Env, outdata, "words"); + var savedData = SelectColumnsTransform.CreateKeep(Env, outdata, "words"); var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); var outputPath = GetOutputPath("Text", "tokenizedWithSeparators.tsv"); @@ -151,7 +151,7 @@ public void TextNormalizationAndStopwordRemoverWorkout() { var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); IDataView savedData = TakeFilter.Create(Env, est.Fit(data.AsDynamic).Transform(data.AsDynamic), 4); - savedData = new ChooseColumnsTransform(Env, savedData, "text", "words_without_stopwords"); + savedData = SelectColumnsTransform.CreateKeep(Env, savedData, "text", "words_without_stopwords"); using (var fs = File.Create(outputPath)) DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); @@ -187,7 +187,7 @@ public void WordBagWorkout() { var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); IDataView savedData = TakeFilter.Create(Env, est.Fit(data.AsDynamic).Transform(data.AsDynamic), 4); - savedData = new ChooseColumnsTransform(Env, savedData, "text", "bag_of_words", "bag_of_wordshash"); + savedData = SelectColumnsTransform.CreateKeep(Env, savedData, "text", "bag_of_words", "bag_of_wordshash"); using (var fs = File.Create(outputPath)) DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); @@ -225,7 +225,7 @@ public void NgramWorkout() { var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); IDataView savedData = TakeFilter.Create(Env, est.Fit(data.AsDynamic).Transform(data.AsDynamic), 4); - savedData = new ChooseColumnsTransform(Env, savedData, "text", "terms", "ngrams", "ngramshash"); + savedData = SelectColumnsTransform.CreateKeep(Env, savedData, "text", "terms", "ngrams", "ngramshash"); using (var fs = File.Create(outputPath)) DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); @@ -265,7 +265,7 @@ public void LdaWorkout() { var saver = new TextSaver(env, new TextSaver.Arguments { Silent = true, OutputHeader = false, Dense = true }); IDataView savedData = TakeFilter.Create(env, est.Fit(data.AsDynamic).Transform(data.AsDynamic), 4); - savedData = new ChooseColumnsTransform(env, savedData, "topics"); + savedData = SelectColumnsTransform.CreateKeep(env, savedData, "topics"); using (var fs = File.Create(outputPath)) DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true);