diff --git a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs index e0d7483df0..a6fd218f5c 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs @@ -199,13 +199,13 @@ public ColumnSelectingTransformer(IHostEnvironment env, string[] keepColumns, st _host.CheckValueOrNull(keepColumns); _host.CheckValueOrNull(dropColumns); - bool keepValid = keepColumns != null && keepColumns.Count() > 0; - bool dropValid = dropColumns != null && dropColumns.Count() > 0; + bool keepValid = Utils.Size(keepColumns) > 0; + bool dropValid = Utils.Size(dropColumns) > 0; // Check that both are not valid - _host.Check(!(keepValid && dropValid), "Both keepColumns and dropColumns are set, only one can be specified."); + _host.Check(!(keepValid && dropValid), "Both " + nameof(keepColumns) + " and " + nameof(dropColumns) + " are set. Exactly one can be specified."); // Check that both are invalid - _host.Check(!(!keepValid && !dropValid), "Neither keepColumns or dropColumns is set, one must be specified."); + _host.Check(!(!keepValid && !dropValid), "Neither " + nameof(keepColumns) + " nor " + nameof(dropColumns) + " is set. Exactly one must be specified."); _selectedColumns = (keepValid) ? keepColumns : dropColumns; KeepColumns = keepValid; @@ -558,7 +558,7 @@ private static int[] BuildOutputToInputMap(IEnumerable selectedColumns, // given an input of ABC and dropping column B will result in AC. // In drop mode, we drop all columns with the specified names and keep all the rest, // ignoring the keepHidden argument. 
- for(int colIdx = 0; colIdx < inputSchema.Count; colIdx++) + for (int colIdx = 0; colIdx < inputSchema.Count; colIdx++) { if (selectedColumns.Contains(inputSchema[colIdx].Name)) continue; diff --git a/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs b/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs index d2481fb4be..a6bb1baee0 100644 --- a/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs +++ b/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs @@ -112,19 +112,16 @@ public static ValueToKeyMappingEstimator MapValueToKey(this TransformsCatalog.Co => new ValueToKeyMappingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, maxNumTerms, sort); /// - /// Converts value types into loading the keys to use from . + /// Converts value types into , optionally loading the keys to use from . /// /// The categorical transform's catalog. /// The data columns to map to keys. - /// The path of the file containing the terms. - /// - /// + /// The data view containing the terms. If specified, this should be a single column data + /// view, and the key-values will be taken from that column. If unspecified, the key-values will be determined + /// from the input data upon fitting. 
public static ValueToKeyMappingEstimator MapValueToKey(this TransformsCatalog.ConversionTransforms catalog, - ValueToKeyMappingTransformer.ColumnInfo[] columns, - string file = null, - string termsColumn = null, - IComponentFactory loaderFactory = null) - => new ValueToKeyMappingEstimator(CatalogUtils.GetEnvironment(catalog), columns, file, termsColumn, loaderFactory); + ValueToKeyMappingTransformer.ColumnInfo[] columns, IDataView keyData = null) + => new ValueToKeyMappingEstimator(CatalogUtils.GetEnvironment(catalog), columns, keyData); /// /// Maps specified keys to specified values diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs index 3885563165..dfb1f5c19c 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs @@ -9,7 +9,7 @@ namespace Microsoft.ML.Transforms.Conversions { /// - public sealed class ValueToKeyMappingEstimator: IEstimator + public sealed class ValueToKeyMappingEstimator : IEstimator { public static class Defaults { @@ -19,9 +19,7 @@ public static class Defaults private readonly IHost _host; private readonly ValueToKeyMappingTransformer.ColumnInfo[] _columns; - private readonly string _file; - private readonly string _termsColumn; - private readonly IComponentFactory _loaderFactory; + private readonly IDataView _keyData; /// /// Initializes a new instance of . @@ -33,23 +31,28 @@ public static class Defaults /// How items should be ordered when vectorized. If choosen they will be in the order encountered. /// If , items are sorted according to their default comparison, for example, text sorting will be case sensitive (for example, 'A' then 'Z' then 'a'). 
public ValueToKeyMappingEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, int maxNumTerms = Defaults.MaxNumTerms, ValueToKeyMappingTransformer.SortOrder sort = Defaults.Sort) : - this(env, new [] { new ValueToKeyMappingTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn, maxNumTerms, sort) }) + this(env, new[] { new ValueToKeyMappingTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn, maxNumTerms, sort) }) { } - public ValueToKeyMappingEstimator(IHostEnvironment env, ValueToKeyMappingTransformer.ColumnInfo[] columns, - string file = null, string termsColumn = null, - IComponentFactory loaderFactory = null) + public ValueToKeyMappingEstimator(IHostEnvironment env, ValueToKeyMappingTransformer.ColumnInfo[] columns, IDataView keyData = null) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(ValueToKeyMappingEstimator)); + _host.CheckNonEmpty(columns, nameof(columns)); + _host.CheckValueOrNull(keyData); + if (keyData != null && keyData.Schema.Count != 1) + { + throw _host.ExceptParam(nameof(keyData), "If specified, this data view should contain only a single column " + + $"containing the terms to map, but this had {keyData.Schema.Count} columns."); + + } + _columns = columns; - _file = file; - _termsColumn = termsColumn; - _loaderFactory = loaderFactory; + _keyData = keyData; } - public ValueToKeyMappingTransformer Fit(IDataView input) => new ValueToKeyMappingTransformer(_host, input, _columns, _file, _termsColumn, _loaderFactory); + public ValueToKeyMappingTransformer Fit(IDataView input) => new ValueToKeyMappingTransformer(_host, input, _columns, _keyData, false); public SchemaShape GetOutputSchema(SchemaShape inputSchema) { diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs index eb91b94a00..1a559d636a 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs +++ 
b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs @@ -290,24 +290,20 @@ private ColInfo[] CreateInfos(Schema inputSchema) internal ValueToKeyMappingTransformer(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) : - this(env, input, columns, null, null, null) + this(env, input, columns, null, false) { } internal ValueToKeyMappingTransformer(IHostEnvironment env, IDataView input, - ColumnInfo[] columns, - string file = null, string termsColumn = null, - IComponentFactory loaderFactory = null) + ColumnInfo[] columns, IDataView keyData, bool autoConvert) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) { using (var ch = Host.Start("Training")) { var infos = CreateInfos(input.Schema); - _unboundMaps = Train(Host, ch, infos, file, termsColumn, loaderFactory, columns, input); + _unboundMaps = Train(Host, ch, infos, keyData, columns, input, autoConvert); _textMetadata = new bool[_unboundMaps.Length]; for (int iinfo = 0; iinfo < columns.Length; ++iinfo) - { _textMetadata[iinfo] = columns[iinfo].TextKeyValues; - } ch.Assert(_unboundMaps.Length == columns.Length); } } @@ -348,8 +344,9 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat item.TextKeyValues ?? args.TextKeyValues); cols[i].Terms = item.Terms ?? args.Terms; }; + var keyData = GetKeyDataViewOrNull(env, ch, args.DataFile, args.TermsColumn, args.Loader, out bool autoLoaded); + return new ValueToKeyMappingTransformer(env, input, cols, keyData, autoLoaded).MakeDataTransform(input); } - return new ValueToKeyMappingTransformer(env, input, cols, args.DataFile, args.TermsColumn, args.Loader).MakeDataTransform(input); } // Factory method for SignatureLoadModel. @@ -416,29 +413,44 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Sch => Create(env, ctx).MakeRowMapper(inputSchema); /// - /// Utility method to create the file-based . 
+ /// Returns a single-column , based on values from , + /// in the case where is set. If that is not set, this will + /// return . /// - private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, string file, string termsColumn, - IComponentFactory loaderFactory, Builder bldr) + /// The host environment. + /// The host channel to use to mark exceptions and log messages. + /// The name of the file. Must be specified if this method is called. + /// The single column to select out of this transform. If not specified, + /// this method will attempt to guess. + /// The loader creator. If we will attempt to determine + /// this + /// Whether we should try to convert to the desired type by ourselves when doing + /// the term map. This will not be true in the case that the loader was adequately specified automatically. + /// The single-column data containing the term data from the file. + [BestFriend] + internal static IDataView GetKeyDataViewOrNull(IHostEnvironment env, IChannel ch, + string file, string termsColumn, IComponentFactory loaderFactory, + out bool autoConvert) { - Contracts.AssertValue(ch); ch.AssertValue(env); - ch.Assert(!string.IsNullOrWhiteSpace(file)); - ch.AssertValue(bldr); + ch.AssertValueOrNull(file); + ch.AssertValueOrNull(termsColumn); + ch.AssertValueOrNull(loaderFactory); + + // If the user manually specifies a loader, or this is already a pre-processed binary + // file, then we assume the user knows what they're doing when they are so explicit, + // and do not attempt to convert to the desired type ourselves. + autoConvert = false; + if (string.IsNullOrWhiteSpace(file)) + return null; // First column using the file. string src = termsColumn; IMultiStreamSource fileSource = new MultiFileSource(file); - // If the user manually specifies a loader, or this is already a pre-processed binary - // file, then we assume the user knows what they're doing and do not attempt to convert - // to the desired type ourselves. 
- bool autoConvert = false; - IDataView termData; + IDataView keyData; if (loaderFactory != null) - { - termData = loaderFactory.CreateComponent(env, fileSource); - } + keyData = loaderFactory.CreateComponent(env, fileSource); else { // Determine the default loader from the extension. @@ -451,11 +463,11 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, stri ch.CheckUserArg(!string.IsNullOrWhiteSpace(src), nameof(termsColumn), "Must be specified"); if (isBinary) - termData = new BinaryLoader(env, new BinaryLoader.Arguments(), fileSource); + keyData = new BinaryLoader(env, new BinaryLoader.Arguments(), fileSource); else { ch.Assert(isTranspose); - termData = new TransposeLoader(env, new TransposeLoader.Arguments(), fileSource); + keyData = new TransposeLoader(env, new TransposeLoader.Arguments(), fileSource); } } else @@ -463,32 +475,53 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, stri if (!string.IsNullOrWhiteSpace(src)) { ch.Warning( - "{0} should not be specified when default loader is TextLoader. Ignoring {0}={1}", + "{0} should not be specified when default loader is " + nameof(TextLoader) + ". Ignoring {0}={1}", nameof(Arguments.TermsColumn), src); } - termData = new TextLoader(env, + keyData = new TextLoader(env, columns: new[] { new TextLoader.Column("Term", DataKind.TX, 0) }, dataSample: fileSource) .Read(fileSource); src = "Term"; + // In this case they are relying on heuristics, so auto-loading in this case is most appropriate. autoConvert = true; } } ch.AssertNonEmpty(src); - - int colSrc; - if (!termData.Schema.TryGetColumnIndex(src, out colSrc)) + if (keyData.Schema.GetColumnOrNull(src) == null) throw ch.ExceptUserArg(nameof(termsColumn), "Unknown column '{0}'", src); - var typeSrc = termData.Schema[colSrc].Type; + // Now, remove everything but that one column. 
+ var selectTransformer = new ColumnSelectingTransformer(env, new string[] { src }, null); + keyData = selectTransformer.Transform(keyData); + ch.Assert(keyData.Schema.Count == 1); + return keyData; + } + + /// + /// Utility method to create the file-based . + /// + private static TermMap CreateTermMapFromData(IHostEnvironment env, IChannel ch, IDataView keyData, bool autoConvert, Builder bldr) + { + Contracts.AssertValue(ch); + ch.AssertValue(env); + ch.AssertValue(keyData); + ch.AssertValue(bldr); + if (keyData.Schema.Count != 1) + { + throw ch.ExceptParam(nameof(keyData), $"Input data containing terms should contain exactly one column, but " + + $"had {keyData.Schema.Count} instead. Consider using {nameof(ColumnSelectingEstimator)} on that data first."); + } + + var typeSrc = keyData.Schema[0].Type; if (!autoConvert && !typeSrc.Equals(bldr.ItemType)) - throw ch.ExceptUserArg(nameof(termsColumn), "Must be of type '{0}' but was '{1}'", bldr.ItemType, typeSrc); + throw ch.ExceptUserArg(nameof(keyData), "Input data's column must be of type '{0}' but was '{1}'", bldr.ItemType, typeSrc); - using (var cursor = termData.GetRowCursor(termData.Schema[colSrc])) - using (var pch = env.StartProgressChannel("Building term dictionary from file")) + using (var cursor = keyData.GetRowCursor(keyData.Schema[0])) + using (var pch = env.StartProgressChannel("Building dictionary from term data")) { var header = new ProgressHeader(new[] { "Total Terms" }, new[] { "examples" }); - var trainer = Trainer.Create(cursor, colSrc, autoConvert, int.MaxValue, bldr); - double rowCount = termData.GetRowCount() ?? double.NaN; + var trainer = Trainer.Create(cursor, 0, autoConvert, int.MaxValue, bldr); + double rowCount = keyData.GetRowCount() ?? 
double.NaN; long rowCur = 0; pch.SetHeader(header, e => @@ -501,7 +534,7 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, stri while (cursor.MoveNext() && trainer.ProcessRow()) rowCur++; if (trainer.Count == 0) - ch.Warning("Term map loaded from file resulted in an empty map."); + ch.Warning("Map from the term data resulted in an empty map."); pch.Checkpoint(trainer.Count, rowCur); return trainer.Finish(); } @@ -511,12 +544,12 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, stri /// This builds the instances per column. /// private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] infos, - string file, string termsColumn, - IComponentFactory loaderFactory, ColumnInfo[] columns, IDataView trainingData) + IDataView keyData, ColumnInfo[] columns, IDataView trainingData, bool autoConvert) { Contracts.AssertValue(env); env.AssertValue(ch); ch.AssertValue(infos); + ch.AssertValueOrNull(keyData); ch.AssertValue(columns); ch.AssertValue(trainingData); @@ -544,13 +577,13 @@ private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] info bldr.ParseAddTermArg(termsArray, ch); termMap[iinfo] = bldr.Finish(); } - else if (!string.IsNullOrWhiteSpace(file)) + else if (keyData != null) { // First column using this file. if (termsFromFile == null) { var bldr = Builder.Create(infos[iinfo].TypeSrc, columns[iinfo].Sort); - termsFromFile = CreateFileTermMap(env, ch, file, termsColumn, loaderFactory, bldr); + termsFromFile = CreateTermMapFromData(env, ch, keyData, autoConvert, bldr); } if (!termsFromFile.ItemType.Equals(infos[iinfo].TypeSrc.GetItemType())) { @@ -559,7 +592,7 @@ private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] info // a complicated feature would be, and also because it's difficult to see how we // can logically reconcile "reinterpretation" for different types with the resulting // data view having an actual type. 
- throw ch.ExceptUserArg(nameof(file), "Data file terms loaded as type '{0}' but mismatches column '{1}' item type '{2}'", + throw ch.ExceptParam(nameof(keyData), "Terms from input data type '{0}' but mismatches column '{1}' item type '{2}'", termsFromFile.ItemType, infos[iinfo].Name, infos[iinfo].TypeSrc.GetItemType()); } termMap[iinfo] = termsFromFile; diff --git a/src/Microsoft.ML.Transforms/OneHotEncoding.cs b/src/Microsoft.ML.Transforms/OneHotEncoding.cs index faf682da4c..028d7c4ef4 100644 --- a/src/Microsoft.ML.Transforms/OneHotEncoding.cs +++ b/src/Microsoft.ML.Transforms/OneHotEncoding.cs @@ -139,7 +139,15 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat col.SetTerms(column.Terms ?? args.Terms); columns.Add(col); } - return new OneHotEncodingEstimator(env, columns.ToArray(), args.DataFile, args.TermsColumn, args.Loader).Fit(input).Transform(input) as IDataTransform; + IDataView keyData = null; + if (!string.IsNullOrEmpty(args.DataFile)) + { + using (var ch = h.Start("Load term data")) + keyData = ValueToKeyMappingTransformer.GetKeyDataViewOrNull(env, ch, args.DataFile, args.TermsColumn, args.Loader, out bool autoLoaded); + h.AssertValue(keyData); + } + var transformed = new OneHotEncodingEstimator(env, columns.ToArray(), keyData).Fit(input).Transform(input); + return (IDataTransform)transformed; } private readonly TransformerChain _transformer; @@ -220,13 +228,11 @@ public OneHotEncodingEstimator(IHostEnvironment env, string inputColumn, { } - public OneHotEncodingEstimator(IHostEnvironment env, ColumnInfo[] columns, - string file = null, string termsColumn = null, - IComponentFactory loaderFactory = null) + public OneHotEncodingEstimator(IHostEnvironment env, ColumnInfo[] columns, IDataView keyData = null) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(OneHotEncodingEstimator)); - _term = new ValueToKeyMappingEstimator(_host, columns, file, termsColumn, loaderFactory); + _term = new 
ValueToKeyMappingEstimator(_host, columns, keyData); var binaryCols = new List<(string input, string output)>(); var cols = new List<(string input, string output, bool bag)>(); for (int i = 0; i < columns.Length; i++) diff --git a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs index 2b3eecd95b..d503442016 100644 --- a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs @@ -33,20 +33,20 @@ [assembly: LoadableClass(typeof(IRowMapper), typeof(StopWordsRemovingTransformer), null, typeof(SignatureLoadRowMapper), "Stopwords Remover Transform", StopWordsRemovingTransformer.LoaderSignature)] -[assembly: LoadableClass(CustomStopWordsRemovingTransform.Summary, typeof(IDataTransform), typeof(CustomStopWordsRemovingTransform), typeof(CustomStopWordsRemovingTransform.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(CustomStopWordsRemovingTransformer.Summary, typeof(IDataTransform), typeof(CustomStopWordsRemovingTransformer), typeof(CustomStopWordsRemovingTransformer.Arguments), typeof(SignatureDataTransform), "Custom Stopwords Remover Transform", "CustomStopWordsRemoverTransform", "CustomStopWords")] -[assembly: LoadableClass(CustomStopWordsRemovingTransform.Summary, typeof(IDataTransform), typeof(CustomStopWordsRemovingTransform), null, typeof(SignatureLoadDataTransform), - "Custom Stopwords Remover Transform", CustomStopWordsRemovingTransform.LoaderSignature)] +[assembly: LoadableClass(CustomStopWordsRemovingTransformer.Summary, typeof(IDataTransform), typeof(CustomStopWordsRemovingTransformer), null, typeof(SignatureLoadDataTransform), + "Custom Stopwords Remover Transform", CustomStopWordsRemovingTransformer.LoaderSignature)] -[assembly: LoadableClass(CustomStopWordsRemovingTransform.Summary, typeof(CustomStopWordsRemovingTransform), null, typeof(SignatureLoadModel), - "Custom Stopwords 
Remover Transform", CustomStopWordsRemovingTransform.LoaderSignature)] +[assembly: LoadableClass(CustomStopWordsRemovingTransformer.Summary, typeof(CustomStopWordsRemovingTransformer), null, typeof(SignatureLoadModel), + "Custom Stopwords Remover Transform", CustomStopWordsRemovingTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(IRowMapper), typeof(CustomStopWordsRemovingTransform), null, typeof(SignatureLoadRowMapper), - "Custom Stopwords Remover Transform", CustomStopWordsRemovingTransform.LoaderSignature)] +[assembly: LoadableClass(typeof(IRowMapper), typeof(CustomStopWordsRemovingTransformer), null, typeof(SignatureLoadRowMapper), + "Custom Stopwords Remover Transform", CustomStopWordsRemovingTransformer.LoaderSignature)] [assembly: EntryPointModule(typeof(PredefinedStopWordsRemoverFactory))] -[assembly: EntryPointModule(typeof(CustomStopWordsRemovingTransform.LoaderArguments))] +[assembly: EntryPointModule(typeof(CustomStopWordsRemovingTransformer.LoaderArguments))] namespace Microsoft.ML.Transforms.Text { @@ -596,7 +596,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) /// This is usually applied after tokenizing text, so it compares individual tokens /// (case-insensitive comparison) to the stopwords. 
/// - public sealed class CustomStopWordsRemovingTransform : OneToOneTransformerBase + public sealed class CustomStopWordsRemovingTransformer : OneToOneTransformerBase { public sealed class Column : OneToOneColumn { @@ -646,9 +646,9 @@ public sealed class LoaderArguments : ArgumentsBase, IStopWordsRemoverFactory public IDataTransform CreateComponent(IHostEnvironment env, IDataView input, OneToOneColumn[] column) { if (Utils.Size(Stopword) > 0) - return new CustomStopWordsRemovingTransform(env, Stopword, column.Select(x => (x.Source, x.Name)).ToArray()).Transform(input) as IDataTransform; + return new CustomStopWordsRemovingTransformer(env, Stopword, column.Select(x => (x.Source, x.Name)).ToArray()).Transform(input) as IDataTransform; else - return new CustomStopWordsRemovingTransform(env, Stopwords, DataFile, StopwordsColumn, Loader, column.Select(x => (x.Source, x.Name)).ToArray()).Transform(input) as IDataTransform; + return new CustomStopWordsRemovingTransformer(env, Stopwords, DataFile, StopwordsColumn, Loader, column.Select(x => (x.Source, x.Name)).ToArray()).Transform(input) as IDataTransform; } } @@ -665,7 +665,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(CustomStopWordsRemovingTransform).Assembly.FullName); + loaderAssemblyName: typeof(CustomStopWordsRemovingTransformer).Assembly.FullName); } private const string StopwordsManagerLoaderSignature = "CustomStopWordsManager"; @@ -678,7 +678,7 @@ private static VersionInfo GetStopwordsManagerVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: StopwordsManagerLoaderSignature, - loaderAssemblyName: typeof(CustomStopWordsRemovingTransform).Assembly.FullName); + loaderAssemblyName: typeof(CustomStopWordsRemovingTransformer).Assembly.FullName); } private static readonly ColumnType _outputType = new VectorType(TextType.Instance); @@ -808,7 +808,7 @@ private 
void LoadStopWords(IChannel ch, ReadOnlyMemory stopwords, string d /// The environment. /// Array of words to remove. /// Pairs of columns to remove stop words from. - public CustomStopWordsRemovingTransform(IHostEnvironment env, string[] stopwords, params (string input, string output)[] columns) : + public CustomStopWordsRemovingTransformer(IHostEnvironment env, string[] stopwords, params (string input, string output)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), columns) { _stopWordsMap = new NormStr.Pool(); @@ -826,7 +826,7 @@ public CustomStopWordsRemovingTransform(IHostEnvironment env, string[] stopwords } } - internal CustomStopWordsRemovingTransform(IHostEnvironment env, string stopwords, + internal CustomStopWordsRemovingTransformer(IHostEnvironment env, string stopwords, string dataFile, string stopwordsColumn, IComponentFactory loader, params (string input, string output)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), columns) { @@ -874,7 +874,7 @@ public override void Save(ModelSaveContext ctx) }); } - private CustomStopWordsRemovingTransform(IHost host, ModelLoadContext ctx) : + private CustomStopWordsRemovingTransformer(IHost host, ModelLoadContext ctx) : base(host, ctx) { var columnsLength = ColumnPairs.Length; @@ -919,13 +919,13 @@ private CustomStopWordsRemovingTransform(IHost host, ModelLoadContext ctx) : } // Factory method for SignatureLoadModel. - private static CustomStopWordsRemovingTransform Create(IHostEnvironment env, ModelLoadContext ctx) + private static CustomStopWordsRemovingTransformer Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); var host = env.Register(RegistrationName); host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new CustomStopWordsRemovingTransform(host, ctx); + return new CustomStopWordsRemovingTransformer(host, ctx); } // Factory method for SignatureDataTransform. 
@@ -942,11 +942,11 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat var item = args.Column[i]; cols[i] = (item.Source ?? item.Name, item.Name); } - CustomStopWordsRemovingTransform transfrom = null; + CustomStopWordsRemovingTransformer transfrom = null; if (Utils.Size(args.Stopword) > 0) - transfrom = new CustomStopWordsRemovingTransform(env, args.Stopword, cols); + transfrom = new CustomStopWordsRemovingTransformer(env, args.Stopword, cols); else - transfrom = new CustomStopWordsRemovingTransform(env, args.Stopwords, args.DataFile, args.StopwordsColumn, args.Loader, cols); + transfrom = new CustomStopWordsRemovingTransformer(env, args.Stopwords, args.DataFile, args.StopwordsColumn, args.Loader, cols); return transfrom.MakeDataTransform(input); } @@ -963,9 +963,9 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Sch private sealed class Mapper : OneToOneMapperBase { private readonly ColumnType[] _types; - private readonly CustomStopWordsRemovingTransform _parent; + private readonly CustomStopWordsRemovingTransformer _parent; - public Mapper(CustomStopWordsRemovingTransform parent, Schema inputSchema) + public Mapper(CustomStopWordsRemovingTransformer parent, Schema inputSchema) : base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), parent, inputSchema) { _parent = parent; @@ -1036,7 +1036,7 @@ protected override Delegate MakeGetter(Row input, int iinfo, Func act /// This is usually applied after tokenizing text, so it compares individual tokens /// (case-insensitive comparison) to the stopwords. /// - public sealed class CustomStopWordsRemovingEstimator : TrivialEstimator + public sealed class CustomStopWordsRemovingEstimator : TrivialEstimator { internal const string ExpectedColumnType = "vector of Text type"; @@ -1061,7 +1061,7 @@ public CustomStopWordsRemovingEstimator(IHostEnvironment env, string inputColumn /// Pairs of columns to remove stop words on. 
/// Array of words to remove. public CustomStopWordsRemovingEstimator(IHostEnvironment env, (string input, string output)[] columns, string[] stopwords) : - base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CustomStopWordsRemovingEstimator)), new CustomStopWordsRemovingTransform(env, stopwords, columns)) + base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CustomStopWordsRemovingEstimator)), new CustomStopWordsRemovingTransformer(env, stopwords, columns)) { } diff --git a/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs b/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs index 5345cfdc90..2ac24d981d 100644 --- a/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs @@ -24,14 +24,14 @@ public CategoricalTests(ITestOutputHelper output) : base(output) { } - private class TestClass + private sealed class TestClass { public int A; public int B; public int C; } - private class TestMeta + private sealed class TestMeta { [VectorType(2)] public string[] A; @@ -47,6 +47,11 @@ private class TestMeta public string H; } + private sealed class TestStringClass + { + public string A; + } + [Fact] public void CategoricalWorkout() { @@ -98,6 +103,39 @@ public void CategoricalOneHotEncoding() Done(); } + /// + /// In which we take a categorical value and map it to a vector, but we get the mapping from a side data view + /// rather than the data we are fitting. + /// + [Fact] + public void CategoricalOneHotEncodingFromSideData() + { + // In this case, whatever the value of the input, the term mapping should come from the optional side data if specified. 
+ var data = new[] { new TestStringClass() { A = "Stay" }, new TestStringClass() { A = "awhile and listen" } }; + + var mlContext = new MLContext(); + var dataView = mlContext.Data.ReadFromEnumerable(data); + + var sideDataBuilder = new ArrayDataViewBuilder(mlContext); + sideDataBuilder.AddColumn("Hello", "hello", "my", "friend"); + var sideData = sideDataBuilder.GetDataView(); + + var ci = new OneHotEncodingEstimator.ColumnInfo("A", "CatA", OneHotEncodingTransformer.OutputKind.Bag); + var pipe = new OneHotEncodingEstimator(mlContext, new[] { ci }, sideData); + + var output = pipe.Fit(dataView).Transform(dataView); + + VBuffer> slotNames = default; + output.Schema["CatA"].GetSlotNames(ref slotNames); + + Assert.Equal(3, slotNames.Length); + Assert.Equal("hello", slotNames.GetItemOrDefault(0).ToString()); + Assert.Equal("my", slotNames.GetItemOrDefault(1).ToString()); + Assert.Equal("friend", slotNames.GetItemOrDefault(2).ToString()); + + Done(); + } + [Fact] public void CategoricalStatic() { diff --git a/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs b/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs index 720f60f86e..60167e302b 100644 --- a/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs @@ -25,7 +25,7 @@ public ConvertTests(ITestOutputHelper output) : base(output) { } - private class TestPrimitiveClass + private sealed class TestPrimitiveClass { [VectorType(2)] public string[] AA; @@ -53,20 +53,23 @@ private class TestPrimitiveClass public double[] AN; } - private class TestClass + private sealed class TestClass { public int A; [VectorType(2)] public int[] B; } - public class MetaClass + private sealed class MetaClass { public float A; public string B; - } + private sealed class TestStringClass + { + public string A; + } [Fact] public void TestConvertWorkout() @@ -142,6 +145,40 @@ public void TestConvertWorkout() Done(); } + /// + /// Apply with side data. 
+ /// + [Fact] + public void ValueToKeyFromSideData() + { + // In this case, whatever the value of the input, the term mapping should come from the optional side data if specified. + var data = new[] { new TestStringClass() { A = "Stay" }, new TestStringClass() { A = "awhile and listen" } }; + + var mlContext = new MLContext(); + var dataView = mlContext.Data.ReadFromEnumerable(data); + + var sideDataBuilder = new ArrayDataViewBuilder(mlContext); + sideDataBuilder.AddColumn("Hello", "hello", "my", "friend"); + var sideData = sideDataBuilder.GetDataView(); + + // For some reason the column info is on the *transformer*, not the estimator. Already tracked as issue #1760. + var ci = new ValueToKeyMappingTransformer.ColumnInfo("A", "CatA"); + var pipe = mlContext.Transforms.Conversion.MapValueToKey(new[] { ci }, sideData); + var output = pipe.Fit(dataView).Transform(dataView); + + VBuffer> slotNames = default; + output.Schema["CatA"].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref slotNames); + + Assert.Equal(3, slotNames.Length); + Assert.Equal("hello", slotNames.GetItemOrDefault(0).ToString()); + Assert.Equal("my", slotNames.GetItemOrDefault(1).ToString()); + Assert.Equal("friend", slotNames.GetItemOrDefault(2).ToString()); + + Done(); + } + + + [Fact] public void TestCommandLine() {