From 92e6f2da0bc1f9e1ba4dc7fb4009219e4435f361 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Sun, 3 Feb 2019 18:17:27 +0000 Subject: [PATCH 01/10] lda --- .../LdaStaticExtensions.cs | 4 +- .../EntryPoints/TextAnalytics.cs | 4 +- .../Text/LdaTransform.cs | 376 +++++++++--------- .../Text/TextCatalog.cs | 2 +- .../Common/EntryPoints/core_ep-list.tsv | 2 +- .../DataPipe/TestDataPipe.cs | 8 +- 6 files changed, 196 insertions(+), 200 deletions(-) diff --git a/src/Microsoft.ML.StaticPipe/LdaStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/LdaStaticExtensions.cs index f87e5ef4a1..22d17db2ae 100644 --- a/src/Microsoft.ML.StaticPipe/LdaStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/LdaStaticExtensions.cs @@ -101,13 +101,13 @@ public override IEstimator Reconcile(IHostEnvironment env, IReadOnlyDictionary outputNames, IReadOnlyCollection usedNames) { - var infos = new LatentDirichletAllocationTransformer.ColumnInfo[toOutput.Length]; + var infos = new LatentDirichletAllocationEstimator.ColumnInfo[toOutput.Length]; Action onFit = null; for (int i = 0; i < toOutput.Length; ++i) { var tcol = (ILdaCol)toOutput[i]; - infos[i] = new LatentDirichletAllocationTransformer.ColumnInfo(outputNames[toOutput[i]], + infos[i] = new LatentDirichletAllocationEstimator.ColumnInfo(outputNames[toOutput[i]], inputNames[tcol.Input], tcol.Config.NumTopic, tcol.Config.AlphaSum, diff --git a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs index 4f5caadc87..f5abedc092 100644 --- a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs +++ b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs @@ -114,13 +114,13 @@ public static CommonOutputs.TransformOutput CharTokenize(IHostEnvironment env, T Desc = LatentDirichletAllocationTransformer.Summary, UserName = LatentDirichletAllocationTransformer.UserName, ShortName = LatentDirichletAllocationTransformer.ShortName)] - public static CommonOutputs.TransformOutput 
LightLda(IHostEnvironment env, LatentDirichletAllocationTransformer.Arguments input) + public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LatentDirichletAllocationTransformer.Options input) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(input, nameof(input)); var h = EntryPointUtils.CheckArgsAndCreateHost(env, "LightLda", input); - var cols = input.Columns.Select(colPair => new LatentDirichletAllocationTransformer.ColumnInfo(colPair, input)).ToArray(); + var cols = input.Columns.Select(colPair => new LatentDirichletAllocationEstimator.ColumnInfo(colPair, input)).ToArray(); var est = new LatentDirichletAllocationEstimator(h, cols); var view = est.Fit(input.Data).Transform(input.Data); diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index e9f8e5a7a9..c6520c6d99 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -19,7 +19,7 @@ using Microsoft.ML.TextAnalytics; using Microsoft.ML.Transforms.Text; -[assembly: LoadableClass(LatentDirichletAllocationTransformer.Summary, typeof(IDataTransform), typeof(LatentDirichletAllocationTransformer), typeof(LatentDirichletAllocationTransformer.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(LatentDirichletAllocationTransformer.Summary, typeof(IDataTransform), typeof(LatentDirichletAllocationTransformer), typeof(LatentDirichletAllocationTransformer.Options), typeof(SignatureDataTransform), "Latent Dirichlet Allocation Transform", LatentDirichletAllocationTransformer.LoaderSignature, "Lda")] [assembly: LoadableClass(LatentDirichletAllocationTransformer.Summary, typeof(IDataTransform), typeof(LatentDirichletAllocationTransformer), null, typeof(SignatureLoadDataTransform), @@ -51,7 +51,7 @@ namespace Microsoft.ML.Transforms.Text /// public sealed class LatentDirichletAllocationTransformer : OneToOneTransformerBase { - public sealed class Arguments : 
TransformInputBase + internal sealed class Options : TransformInputBase { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:srcs)", Name = "Column", ShortName = "col", SortOrder = 49)] public Column[] Columns; @@ -161,175 +161,6 @@ internal bool TryUnparse(StringBuilder sb) } } - public sealed class ColumnInfo - { - public readonly string Name; - public readonly string InputColumnName; - public readonly int NumTopic; - public readonly float AlphaSum; - public readonly float Beta; - public readonly int MHStep; - public readonly int NumIter; - public readonly int LikelihoodInterval; - public readonly int NumThread; - public readonly int NumMaxDocToken; - public readonly int NumSummaryTermPerTopic; - public readonly int NumBurninIter; - public readonly bool ResetRandomGenerator; - - /// - /// Describes how the transformer handles one column pair. - /// - /// The column containing the output scores over a set of topics, represented as a vector of floats. - /// The column representing the document as a vector of floats.A null value for the column means is replaced. - /// The number of topics. - /// Dirichlet prior on document-topic vectors. - /// Dirichlet prior on vocab-topic vectors. - /// Number of Metropolis Hasting step. - /// Number of iterations. - /// Compute log likelihood over local dataset on this iteration interval. - /// The number of training threads. Default value depends on number of logical processors. - /// The threshold of maximum count of tokens per doc. - /// The number of words to summarize the topic. - /// The number of burn-in iterations. - /// Reset the random number generator for each document. 
- public ColumnInfo(string name, - string inputColumnName = null, - int numTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic, - float alphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum, - float beta = LatentDirichletAllocationEstimator.Defaults.Beta, - int mhStep = LatentDirichletAllocationEstimator.Defaults.Mhstep, - int numIter = LatentDirichletAllocationEstimator.Defaults.NumIterations, - int likelihoodInterval = LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval, - int numThread = LatentDirichletAllocationEstimator.Defaults.NumThreads, - int numMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken, - int numSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic, - int numBurninIter = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations, - bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator) - { - Contracts.CheckValue(name, nameof(name)); - Contracts.CheckValueOrNull(inputColumnName); - Contracts.CheckParam(numTopic > 0, nameof(numTopic), "Must be positive."); - Contracts.CheckParam(mhStep > 0, nameof(mhStep), "Must be positive."); - Contracts.CheckParam(numIter > 0, nameof(numIter), "Must be positive."); - Contracts.CheckParam(likelihoodInterval > 0, nameof(likelihoodInterval), "Must be positive."); - Contracts.CheckParam(numThread >= 0, nameof(numThread), "Must be positive or zero."); - Contracts.CheckParam(numMaxDocToken > 0, nameof(numMaxDocToken), "Must be positive."); - Contracts.CheckParam(numSummaryTermPerTopic > 0, nameof(numSummaryTermPerTopic), "Must be positive"); - Contracts.CheckParam(numBurninIter >= 0, nameof(numBurninIter), "Must be non-negative."); - - Name = name; - InputColumnName = inputColumnName ?? 
name; - NumTopic = numTopic; - AlphaSum = alphaSum; - Beta = beta; - MHStep = mhStep; - NumIter = numIter; - LikelihoodInterval = likelihoodInterval; - NumThread = numThread; - NumMaxDocToken = numMaxDocToken; - NumSummaryTermPerTopic = numSummaryTermPerTopic; - NumBurninIter = numBurninIter; - ResetRandomGenerator = resetRandomGenerator; - } - - internal ColumnInfo(Column item, Arguments args) : - this(item.Name, - item.Source ?? item.Name, - item.NumTopic ?? args.NumTopic, - item.AlphaSum ?? args.AlphaSum, - item.Beta ?? args.Beta, - item.Mhstep ?? args.Mhstep, - item.NumIterations ?? args.NumIterations, - item.LikelihoodInterval ?? args.LikelihoodInterval, - item.NumThreads ?? args.NumThreads, - item.NumMaxDocToken ?? args.NumMaxDocToken, - item.NumSummaryTermPerTopic ?? args.NumSummaryTermPerTopic, - item.NumBurninIterations ?? args.NumBurninIterations, - item.ResetRandomGenerator ?? args.ResetRandomGenerator) - { - } - - internal ColumnInfo(IExceptionContext ectx, ModelLoadContext ctx) - { - Contracts.AssertValue(ectx); - ectx.AssertValue(ctx); - - // *** Binary format *** - // int NumTopic; - // float AlphaSum; - // float Beta; - // int MHStep; - // int NumIter; - // int LikelihoodInterval; - // int NumThread; - // int NumMaxDocToken; - // int NumSummaryTermPerTopic; - // int NumBurninIter; - // byte ResetRandomGenerator; - - NumTopic = ctx.Reader.ReadInt32(); - ectx.CheckDecode(NumTopic > 0); - - AlphaSum = ctx.Reader.ReadSingle(); - - Beta = ctx.Reader.ReadSingle(); - - MHStep = ctx.Reader.ReadInt32(); - ectx.CheckDecode(MHStep > 0); - - NumIter = ctx.Reader.ReadInt32(); - ectx.CheckDecode(NumIter > 0); - - LikelihoodInterval = ctx.Reader.ReadInt32(); - ectx.CheckDecode(LikelihoodInterval > 0); - - NumThread = ctx.Reader.ReadInt32(); - ectx.CheckDecode(NumThread >= 0); - - NumMaxDocToken = ctx.Reader.ReadInt32(); - ectx.CheckDecode(NumMaxDocToken > 0); - - NumSummaryTermPerTopic = ctx.Reader.ReadInt32(); - ectx.CheckDecode(NumSummaryTermPerTopic > 0); - - 
NumBurninIter = ctx.Reader.ReadInt32(); - ectx.CheckDecode(NumBurninIter >= 0); - - ResetRandomGenerator = ctx.Reader.ReadBoolByte(); - } - - internal void Save(ModelSaveContext ctx) - { - Contracts.AssertValue(ctx); - - // *** Binary format *** - // int NumTopic; - // float AlphaSum; - // float Beta; - // int MHStep; - // int NumIter; - // int LikelihoodInterval; - // int NumThread; - // int NumMaxDocToken; - // int NumSummaryTermPerTopic; - // int NumBurninIter; - // byte ResetRandomGenerator; - - ctx.Writer.Write(NumTopic); - ctx.Writer.Write(AlphaSum); - ctx.Writer.Write(Beta); - ctx.Writer.Write(MHStep); - ctx.Writer.Write(NumIter); - ctx.Writer.Write(LikelihoodInterval); - ctx.Writer.Write(NumThread); - ctx.Writer.Write(NumMaxDocToken); - ctx.Writer.Write(NumSummaryTermPerTopic); - ctx.Writer.Write(NumBurninIter); - ctx.Writer.WriteBoolByte(ResetRandomGenerator); - } - } - /// /// Provide details about the topics discovered by LightLDA. /// @@ -365,7 +196,7 @@ internal LdaSummary GetLdaDetails(int iinfo) private sealed class LdaState : IDisposable { - internal readonly ColumnInfo InfoEx; + internal readonly LatentDirichletAllocationEstimator.ColumnInfo InfoEx; private readonly int _numVocab; private readonly object _preparationSyncRoot; private readonly object _testSyncRoot; @@ -378,7 +209,7 @@ private LdaState() _testSyncRoot = new object(); } - internal LdaState(IExceptionContext ectx, ColumnInfo ex, int numVocab) + internal LdaState(IExceptionContext ectx, LatentDirichletAllocationEstimator.ColumnInfo ex, int numVocab) : this() { Contracts.AssertValue(ectx); @@ -415,7 +246,7 @@ internal LdaState(IExceptionContext ectx, ModelLoadContext ctx) // (serializing term by term, for one term) // int: term_id, int: topic_num, KeyValuePair[]: termTopicVector - InfoEx = new ColumnInfo(ectx, ctx); + InfoEx = new LatentDirichletAllocationEstimator.ColumnInfo(ectx, ctx); _numVocab = ctx.Reader.ReadInt32(); ectx.CheckDecode(_numVocab > 0); @@ -771,7 +602,7 @@ private 
static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(LatentDirichletAllocationTransformer).Assembly.FullName); } - private readonly ColumnInfo[] _columns; + private readonly LatentDirichletAllocationEstimator.ColumnInfo[] _columns; private readonly LdaState[] _ldas; private readonly List>> _columnMappings; @@ -781,7 +612,7 @@ private static VersionInfo GetVersionInfo() internal const string UserName = "Latent Dirichlet Allocation Transform"; internal const string ShortName = "LightLda"; - private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(LatentDirichletAllocationEstimator.ColumnInfo[] columns) { Contracts.CheckValue(columns, nameof(columns)); return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); @@ -797,7 +628,7 @@ private static (string outputColumnName, string inputColumnName)[] GetColumnPair private LatentDirichletAllocationTransformer(IHostEnvironment env, LdaState[] ldas, List>> columnMappings, - params ColumnInfo[] columns) + params LatentDirichletAllocationEstimator.ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LatentDirichletAllocationTransformer)), GetColumnPairs(columns)) { Host.AssertNonEmpty(ColumnPairs); @@ -817,7 +648,7 @@ private LatentDirichletAllocationTransformer(IHost host, ModelLoadContext ctx) : // Note: columnsLength would be just one in most cases. 
var columnsLength = ColumnPairs.Length; - _columns = new ColumnInfo[columnsLength]; + _columns = new LatentDirichletAllocationEstimator.ColumnInfo[columnsLength]; _ldas = new LdaState[columnsLength]; for (int i = 0; i < _ldas.Length; i++) { @@ -826,7 +657,7 @@ private LatentDirichletAllocationTransformer(IHost host, ModelLoadContext ctx) : } } - internal static LatentDirichletAllocationTransformer TrainLdaTransformer(IHostEnvironment env, IDataView inputData, params ColumnInfo[] columns) + internal static LatentDirichletAllocationTransformer TrainLdaTransformer(IHostEnvironment env, IDataView inputData, params LatentDirichletAllocationEstimator.ColumnInfo[] columns) { var ldas = new LdaState[columns.Length]; @@ -869,14 +700,14 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Sch => Create(env, ctx).MakeRowMapper(inputSchema); // Factory method for SignatureDataTransform. - private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - env.CheckValue(args.Columns, nameof(args.Columns)); + env.CheckValue(options.Columns, nameof(options.Columns)); - var cols = args.Columns.Select(colPair => new ColumnInfo(colPair, args)).ToArray(); + var cols = options.Columns.Select(colPair => new LatentDirichletAllocationEstimator.ColumnInfo(colPair, options)).ToArray(); return TrainLdaTransformer(env, input, cols).MakeDataTransform(input); } @@ -929,7 +760,7 @@ private static int GetFrequency(double value) return result; } - private static List>> Train(IHostEnvironment env, IChannel ch, IDataView inputData, LdaState[] states, params ColumnInfo[] columns) + private static List>> Train(IHostEnvironment env, IChannel ch, IDataView inputData, LdaState[] states, params 
LatentDirichletAllocationEstimator.ColumnInfo[] columns) { env.AssertValue(ch); ch.AssertValue(inputData); @@ -1104,7 +935,7 @@ internal static class Defaults } private readonly IHost _host; - private readonly ImmutableArray _columns; + private readonly ImmutableArray _columns; /// /// The environment. @@ -1121,7 +952,7 @@ internal static class Defaults /// The number of words to summarize the topic. /// The number of burn-in iterations. /// Reset the random number generator for each document. - public LatentDirichletAllocationEstimator(IHostEnvironment env, + internal LatentDirichletAllocationEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, int numTopic = Defaults.NumTopic, float alphaSum = Defaults.AlphaSum, @@ -1134,7 +965,7 @@ public LatentDirichletAllocationEstimator(IHostEnvironment env, int numSummaryTermPerTopic = Defaults.NumSummaryTermPerTopic, int numBurninIterations = Defaults.NumBurninIterations, bool resetRandomGenerator = Defaults.ResetRandomGenerator) - : this(env, new[] { new LatentDirichletAllocationTransformer.ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName, + : this(env, new[] { new ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName, numTopic, alphaSum, beta, mhstep, numIterations, likelihoodInterval, numThreads, numMaxDocToken, numSummaryTermPerTopic, numBurninIterations, resetRandomGenerator) }) { } @@ -1142,13 +973,182 @@ public LatentDirichletAllocationEstimator(IHostEnvironment env, /// /// The environment. /// Describes the parameters of the LDA process for each column pair. 
- public LatentDirichletAllocationEstimator(IHostEnvironment env, params LatentDirichletAllocationTransformer.ColumnInfo[] columns) + internal LatentDirichletAllocationEstimator(IHostEnvironment env, params ColumnInfo[] columns) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(LatentDirichletAllocationEstimator)); _columns = columns.ToImmutableArray(); } + public sealed class ColumnInfo + { + public readonly string Name; + public readonly string InputColumnName; + public readonly int NumTopic; + public readonly float AlphaSum; + public readonly float Beta; + public readonly int MHStep; + public readonly int NumIter; + public readonly int LikelihoodInterval; + public readonly int NumThread; + public readonly int NumMaxDocToken; + public readonly int NumSummaryTermPerTopic; + public readonly int NumBurninIter; + public readonly bool ResetRandomGenerator; + + /// <summary> + /// Describes how the transformer handles one column pair. + /// </summary> + /// <param name="name">The column containing the output scores over a set of topics, represented as a vector of floats.</param> + /// <param name="inputColumnName">The column representing the document as a vector of floats. If null, <paramref name="name"/> is used as the input column.</param> + /// <param name="numTopic">The number of topics.</param> + /// <param name="alphaSum">Dirichlet prior on document-topic vectors.</param> + /// <param name="beta">Dirichlet prior on vocab-topic vectors.</param> + /// <param name="mhStep">Number of Metropolis-Hastings steps.</param> + /// <param name="numIter">Number of iterations.</param> + /// <param name="likelihoodInterval">Compute log likelihood over local dataset on this iteration interval.</param> + /// <param name="numThread">The number of training threads. Default value depends on number of logical processors.</param> + /// <param name="numMaxDocToken">The threshold of maximum count of tokens per doc.</param> + /// <param name="numSummaryTermPerTopic">The number of words to summarize the topic.</param> + /// <param name="numBurninIter">The number of burn-in iterations.</param> + /// <param name="resetRandomGenerator">Reset the random number generator for each document.</param>
+ public ColumnInfo(string name, + string inputColumnName = null, + int numTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic, + float alphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum, + float beta = LatentDirichletAllocationEstimator.Defaults.Beta, + int mhStep = LatentDirichletAllocationEstimator.Defaults.Mhstep, + int numIter = LatentDirichletAllocationEstimator.Defaults.NumIterations, + int likelihoodInterval = LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval, + int numThread = LatentDirichletAllocationEstimator.Defaults.NumThreads, + int numMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken, + int numSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic, + int numBurninIter = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations, + bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator) + { + Contracts.CheckValue(name, nameof(name)); + Contracts.CheckValueOrNull(inputColumnName); + Contracts.CheckParam(numTopic > 0, nameof(numTopic), "Must be positive."); + Contracts.CheckParam(mhStep > 0, nameof(mhStep), "Must be positive."); + Contracts.CheckParam(numIter > 0, nameof(numIter), "Must be positive."); + Contracts.CheckParam(likelihoodInterval > 0, nameof(likelihoodInterval), "Must be positive."); + Contracts.CheckParam(numThread >= 0, nameof(numThread), "Must be positive or zero."); + Contracts.CheckParam(numMaxDocToken > 0, nameof(numMaxDocToken), "Must be positive."); + Contracts.CheckParam(numSummaryTermPerTopic > 0, nameof(numSummaryTermPerTopic), "Must be positive"); + Contracts.CheckParam(numBurninIter >= 0, nameof(numBurninIter), "Must be non-negative."); + + Name = name; + InputColumnName = inputColumnName ?? 
name; + NumTopic = numTopic; + AlphaSum = alphaSum; + Beta = beta; + MHStep = mhStep; + NumIter = numIter; + LikelihoodInterval = likelihoodInterval; + NumThread = numThread; + NumMaxDocToken = numMaxDocToken; + NumSummaryTermPerTopic = numSummaryTermPerTopic; + NumBurninIter = numBurninIter; + ResetRandomGenerator = resetRandomGenerator; + } + + internal ColumnInfo(LatentDirichletAllocationTransformer.Column item, LatentDirichletAllocationTransformer.Options options) : + this(item.Name, + item.Source ?? item.Name, + item.NumTopic ?? options.NumTopic, + item.AlphaSum ?? options.AlphaSum, + item.Beta ?? options.Beta, + item.Mhstep ?? options.Mhstep, + item.NumIterations ?? options.NumIterations, + item.LikelihoodInterval ?? options.LikelihoodInterval, + item.NumThreads ?? options.NumThreads, + item.NumMaxDocToken ?? options.NumMaxDocToken, + item.NumSummaryTermPerTopic ?? options.NumSummaryTermPerTopic, + item.NumBurninIterations ?? options.NumBurninIterations, + item.ResetRandomGenerator ?? 
options.ResetRandomGenerator) + { + } + + internal ColumnInfo(IExceptionContext ectx, ModelLoadContext ctx) + { + Contracts.AssertValue(ectx); + ectx.AssertValue(ctx); + + // *** Binary format *** + // int NumTopic; + // float AlphaSum; + // float Beta; + // int MHStep; + // int NumIter; + // int LikelihoodInterval; + // int NumThread; + // int NumMaxDocToken; + // int NumSummaryTermPerTopic; + // int NumBurninIter; + // byte ResetRandomGenerator; + + NumTopic = ctx.Reader.ReadInt32(); + ectx.CheckDecode(NumTopic > 0); + + AlphaSum = ctx.Reader.ReadSingle(); + + Beta = ctx.Reader.ReadSingle(); + + MHStep = ctx.Reader.ReadInt32(); + ectx.CheckDecode(MHStep > 0); + + NumIter = ctx.Reader.ReadInt32(); + ectx.CheckDecode(NumIter > 0); + + LikelihoodInterval = ctx.Reader.ReadInt32(); + ectx.CheckDecode(LikelihoodInterval > 0); + + NumThread = ctx.Reader.ReadInt32(); + ectx.CheckDecode(NumThread >= 0); + + NumMaxDocToken = ctx.Reader.ReadInt32(); + ectx.CheckDecode(NumMaxDocToken > 0); + + NumSummaryTermPerTopic = ctx.Reader.ReadInt32(); + ectx.CheckDecode(NumSummaryTermPerTopic > 0); + + NumBurninIter = ctx.Reader.ReadInt32(); + ectx.CheckDecode(NumBurninIter >= 0); + + ResetRandomGenerator = ctx.Reader.ReadBoolByte(); + } + + internal void Save(ModelSaveContext ctx) + { + Contracts.AssertValue(ctx); + + // *** Binary format *** + // int NumTopic; + // float AlphaSum; + // float Beta; + // int MHStep; + // int NumIter; + // int LikelihoodInterval; + // int NumThread; + // int NumMaxDocToken; + // int NumSummaryTermPerTopic; + // int NumBurninIter; + // byte ResetRandomGenerator; + + ctx.Writer.Write(NumTopic); + ctx.Writer.Write(AlphaSum); + ctx.Writer.Write(Beta); + ctx.Writer.Write(MHStep); + ctx.Writer.Write(NumIter); + ctx.Writer.Write(LikelihoodInterval); + ctx.Writer.Write(NumThread); + ctx.Writer.Write(NumMaxDocToken); + ctx.Writer.Write(NumSummaryTermPerTopic); + ctx.Writer.Write(NumBurninIter); + ctx.Writer.WriteBoolByte(ResetRandomGenerator); + } + } + /// /// 
Returns the schema that would be produced by the transformation. /// diff --git a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs index 176ece41ba..657113065d 100644 --- a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs +++ b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs @@ -622,7 +622,7 @@ public static LatentDirichletAllocationEstimator LatentDirichletAllocation(this /// Describes the parameters of LDA for each column pair. public static LatentDirichletAllocationEstimator LatentDirichletAllocation( this TransformsCatalog.TextTransforms catalog, - params LatentDirichletAllocationTransformer.ColumnInfo[] columns) + params LatentDirichletAllocationEstimator.ColumnInfo[] columns) => new LatentDirichletAllocationEstimator(CatalogUtils.GetEnvironment(catalog), columns); } } diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 8525c932d4..4742e033ec 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -100,7 +100,7 @@ Transforms.KeyToTextConverter KeyToValueTransform utilizes KeyValues metadata to Transforms.LabelColumnKeyBooleanConverter Transforms the label to either key or bool (if needed) to make it suitable for classification. Microsoft.ML.EntryPoints.FeatureCombiner PrepareClassificationLabel Microsoft.ML.EntryPoints.FeatureCombiner+ClassificationLabelInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.LabelIndicator Label remapper used by OVA Microsoft.ML.Transforms.LabelIndicatorTransform LabelIndicator Microsoft.ML.Transforms.LabelIndicatorTransform+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.LabelToFloatConverter Transforms the label to float to make it suitable for regression. 
Microsoft.ML.EntryPoints.FeatureCombiner PrepareRegressionLabel Microsoft.ML.EntryPoints.FeatureCombiner+RegressionLabelInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -Transforms.LightLda The LDA transform implements LightLDA, a state-of-the-art implementation of Latent Dirichlet Allocation. Microsoft.ML.Transforms.Text.TextAnalytics LightLda Microsoft.ML.Transforms.Text.LatentDirichletAllocationTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.LightLda The LDA transform implements LightLDA, a state-of-the-art implementation of Latent Dirichlet Allocation. Microsoft.ML.Transforms.Text.TextAnalytics LightLda Microsoft.ML.Transforms.Text.LatentDirichletAllocationTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.LogMeanVarianceNormalizer Normalizes the data based on the computed mean and variance of the logarithm of the data. Microsoft.ML.Data.Normalize LogMeanVar Microsoft.ML.Transforms.Normalizers.NormalizeTransform+LogMeanVarArguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.LpNormalizer Normalize vectors (rows) individually by rescaling them to unit norm (L2, L1 or LInf). Performs the following operation on a vector X: Y = (X - M) / D, where M is mean and D is either L2 norm, L1 norm or LInf norm. Microsoft.ML.Transforms.Projections.LpNormalization Normalize Microsoft.ML.Transforms.Projections.LpNormalizingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.ManyHeterogeneousModelCombiner Combines a sequence of TransformModels and a PredictorModel into a single PredictorModel. 
Microsoft.ML.EntryPoints.ModelOperations CombineModels Microsoft.ML.EntryPoints.ModelOperations+PredictorModelInput Microsoft.ML.EntryPoints.ModelOperations+PredictorModelOutput diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index 95af6fbd73..2489efaaaf 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -1324,7 +1324,7 @@ public void TestLDATransform() builder.AddColumn("F1V", NumberType.Float, data); var srcView = builder.GetDataView(); - var est = new LatentDirichletAllocationEstimator(Env, "F1V", numTopic: 3, numSummaryTermPerTopic: 3, alphaSum: 3, numThreads: 1, resetRandomGenerator: true); + var est = ML.Transforms.Text.LatentDirichletAllocation("F1V", numTopic: 3, numSummaryTermPerTopic: 3, alphaSum: 3, numThreads: 1, resetRandomGenerator: true); var ldaTransformer = est.Fit(srcView); var transformedData = ldaTransformer.Transform(srcView); @@ -1376,14 +1376,10 @@ public void TestLdaTransformerEmptyDocumentException() { Source = "Zeros", }; - var args = new LatentDirichletAllocationTransformer.Arguments() - { - Columns = new[] { col } - }; try { - var lda = new LatentDirichletAllocationEstimator(Env, "Zeros").Fit(srcView).Transform(srcView); + var lda = ML.Transforms.Text.LatentDirichletAllocation("Zeros").Fit(srcView).Transform(srcView); } catch (InvalidOperationException ex) { From e4543fed721911ecce836351827bc62c8d547c1e Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Mon, 4 Feb 2019 19:30:58 +0000 Subject: [PATCH 02/10] WordEmbeddingsExtractingEstimator --- .../WordEmbeddingsStaticExtensions.cs | 4 +- .../EntryPoints/TextAnalytics.cs | 2 +- .../Text/TextCatalog.cs | 2 +- .../Text/WordEmbeddingsExtractor.cs | 88 +++++++++---------- .../Common/EntryPoints/core_ep-list.tsv | 2 +- .../UnitTests/TestEntryPoints.cs | 2 +- 6 files changed, 50 insertions(+), 50 deletions(-) diff --git 
a/src/Microsoft.ML.StaticPipe/WordEmbeddingsStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/WordEmbeddingsStaticExtensions.cs index 1c0908d9ae..42ab3ea3a7 100644 --- a/src/Microsoft.ML.StaticPipe/WordEmbeddingsStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/WordEmbeddingsStaticExtensions.cs @@ -72,11 +72,11 @@ public override IEstimator Reconcile(IHostEnvironment env, { Contracts.Assert(toOutput.Length == 1); - var cols = new WordEmbeddingsExtractingTransformer.ColumnInfo[toOutput.Length]; + var cols = new WordEmbeddingsExtractingEstimator.ColumnInfo[toOutput.Length]; for (int i = 0; i < toOutput.Length; ++i) { var outCol = (OutColumn)toOutput[i]; - cols[i] = new WordEmbeddingsExtractingTransformer.ColumnInfo(outputNames[outCol], inputNames[outCol.Input]); + cols[i] = new WordEmbeddingsExtractingEstimator.ColumnInfo(outputNames[outCol], inputNames[outCol.Input]); } bool customLookup = !string.IsNullOrWhiteSpace(_customLookupTable); diff --git a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs index f5abedc092..2eed444b58 100644 --- a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs +++ b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs @@ -135,7 +135,7 @@ public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, Laten Desc = WordEmbeddingsExtractingTransformer.Summary, UserName = WordEmbeddingsExtractingTransformer.UserName, ShortName = WordEmbeddingsExtractingTransformer.ShortName)] - public static CommonOutputs.TransformOutput WordEmbeddings(IHostEnvironment env, WordEmbeddingsExtractingTransformer.Arguments input) + public static CommonOutputs.TransformOutput WordEmbeddings(IHostEnvironment env, WordEmbeddingsExtractingTransformer.Options input) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(input, nameof(input)); diff --git a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs index 
657113065d..84daeff936 100644 --- a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs +++ b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs @@ -146,7 +146,7 @@ public static WordEmbeddingsExtractingEstimator ExtractWordEmbeddings(this Trans /// public static WordEmbeddingsExtractingEstimator ExtractWordEmbeddings(this TransformsCatalog.TextTransforms catalog, WordEmbeddingsExtractingTransformer.PretrainedModelKind modelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe, - params WordEmbeddingsExtractingTransformer.ColumnInfo[] columns) + params WordEmbeddingsExtractingEstimator.ColumnInfo[] columns) => new WordEmbeddingsExtractingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), modelKind, columns); /// diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs index 07b2fb23f4..c6874c36e2 100644 --- a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs +++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs @@ -21,7 +21,7 @@ using Microsoft.ML.Model.Onnx; using Microsoft.ML.Transforms.Text; -[assembly: LoadableClass(WordEmbeddingsExtractingTransformer.Summary, typeof(IDataTransform), typeof(WordEmbeddingsExtractingTransformer), typeof(WordEmbeddingsExtractingTransformer.Arguments), +[assembly: LoadableClass(WordEmbeddingsExtractingTransformer.Summary, typeof(IDataTransform), typeof(WordEmbeddingsExtractingTransformer), typeof(WordEmbeddingsExtractingTransformer.Options), typeof(SignatureDataTransform), WordEmbeddingsExtractingTransformer.UserName, "WordEmbeddingsTransform", WordEmbeddingsExtractingTransformer.ShortName, DocName = "transform/WordEmbeddingsTransform.md")] [assembly: LoadableClass(WordEmbeddingsExtractingTransformer.Summary, typeof(IDataTransform), typeof(WordEmbeddingsExtractingTransformer), null, typeof(SignatureLoadDataTransform), @@ -57,7 +57,7 @@ internal bool TryUnparse(StringBuilder sb) } } - public sealed class 
Arguments : TransformInputBase + internal sealed class Options : TransformInputBase { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 0)] public Column[] Columns; @@ -148,23 +148,6 @@ public List GetWordLabels() } - /// - /// Information for each column pair. - /// - public sealed class ColumnInfo - { - public readonly string Name; - public readonly string InputColumnName; - - public ColumnInfo(string name, string inputColumnName = null) - { - Contracts.CheckNonEmpty(name, nameof(name)); - - Name = name; - InputColumnName = inputColumnName ?? name; - } - } - private const string RegistrationName = "WordEmbeddings"; private const int Timeout = 10 * 60 * 1000; @@ -176,9 +159,9 @@ public ColumnInfo(string name, string inputColumnName = null) /// Name of the column resulting from the transformation of . /// Name of the column to transform. If set to , the value of the will be used as source. /// The pretrained word embedding model. - public WordEmbeddingsExtractingTransformer(IHostEnvironment env, string outputColumnName, string inputColumnName = null, + internal WordEmbeddingsExtractingTransformer(IHostEnvironment env, string outputColumnName, string inputColumnName = null, PretrainedModelKind modelKind = PretrainedModelKind.Sswe) - : this(env, modelKind, new ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName)) + : this(env, modelKind, new WordEmbeddingsExtractingEstimator.ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName)) { } @@ -189,8 +172,8 @@ public WordEmbeddingsExtractingTransformer(IHostEnvironment env, string outputCo /// Name of the column resulting from the transformation of . /// Filename for custom word embedding model. /// Name of the column to transform. If set to , the value of the will be used as source. 
- public WordEmbeddingsExtractingTransformer(IHostEnvironment env, string outputColumnName, string customModelFile, string inputColumnName = null) - : this(env, customModelFile, new ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName)) + internal WordEmbeddingsExtractingTransformer(IHostEnvironment env, string outputColumnName, string customModelFile, string inputColumnName = null) + : this(env, customModelFile, new WordEmbeddingsExtractingEstimator.ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName)) { } @@ -200,7 +183,7 @@ public WordEmbeddingsExtractingTransformer(IHostEnvironment env, string outputCo /// Host Environment. /// The pretrained word embedding model. /// Input/Output columns. - public WordEmbeddingsExtractingTransformer(IHostEnvironment env, PretrainedModelKind modelKind, params ColumnInfo[] columns) + internal WordEmbeddingsExtractingTransformer(IHostEnvironment env, PretrainedModelKind modelKind, params WordEmbeddingsExtractingEstimator.ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) { env.CheckUserArg(Enum.IsDefined(typeof(PretrainedModelKind), modelKind), nameof(modelKind)); @@ -216,7 +199,7 @@ public WordEmbeddingsExtractingTransformer(IHostEnvironment env, PretrainedModel /// Host Environment. /// Filename for custom word embedding model. /// Input/Output columns. 
- public WordEmbeddingsExtractingTransformer(IHostEnvironment env, string customModelFile, params ColumnInfo[] columns) + internal WordEmbeddingsExtractingTransformer(IHostEnvironment env, string customModelFile, params WordEmbeddingsExtractingEstimator.ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) { env.CheckValue(customModelFile, nameof(customModelFile)); @@ -228,39 +211,39 @@ public WordEmbeddingsExtractingTransformer(IHostEnvironment env, string customMo _currentVocab = GetVocabularyDictionary(env); } - private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(WordEmbeddingsExtractingEstimator.ColumnInfo[] columns) { Contracts.CheckValue(columns, nameof(columns)); return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); } // Factory method for SignatureDataTransform. - internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - if (args.ModelKind == null) - args.ModelKind = PretrainedModelKind.Sswe; - env.CheckUserArg(!args.ModelKind.HasValue || Enum.IsDefined(typeof(PretrainedModelKind), args.ModelKind), nameof(args.ModelKind)); + if (options.ModelKind == null) + options.ModelKind = PretrainedModelKind.Sswe; + env.CheckUserArg(!options.ModelKind.HasValue || Enum.IsDefined(typeof(PretrainedModelKind), options.ModelKind), nameof(options.ModelKind)); - env.CheckValue(args.Columns, nameof(args.Columns)); + env.CheckValue(options.Columns, nameof(options.Columns)); - var cols = new ColumnInfo[args.Columns.Length]; + var cols = new 
WordEmbeddingsExtractingEstimator.ColumnInfo[options.Columns.Length]; for (int i = 0; i < cols.Length; i++) { - var item = args.Columns[i]; - cols[i] = new ColumnInfo( + var item = options.Columns[i]; + cols[i] = new WordEmbeddingsExtractingEstimator.ColumnInfo( item.Name, item.Source ?? item.Name); } - bool customLookup = !string.IsNullOrWhiteSpace(args.CustomLookupTable); + bool customLookup = !string.IsNullOrWhiteSpace(options.CustomLookupTable); if (customLookup) - return new WordEmbeddingsExtractingTransformer(env, args.CustomLookupTable, cols).MakeDataTransform(input); + return new WordEmbeddingsExtractingTransformer(env, options.CustomLookupTable, cols).MakeDataTransform(input); else - return new WordEmbeddingsExtractingTransformer(env, args.ModelKind.Value, cols).MakeDataTransform(input); + return new WordEmbeddingsExtractingTransformer(env, options.ModelKind.Value, cols).MakeDataTransform(input); } private WordEmbeddingsExtractingTransformer(IHost host, ModelLoadContext ctx) @@ -284,7 +267,7 @@ private WordEmbeddingsExtractingTransformer(IHost host, ModelLoadContext ctx) _currentVocab = GetVocabularyDictionary(host); } - public static WordEmbeddingsExtractingTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + internal static WordEmbeddingsExtractingTransformer Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); IHost h = env.Register(RegistrationName); @@ -786,7 +769,7 @@ private static ParallelOptions GetParallelOptions(IHostEnvironment hostEnvironme public sealed class WordEmbeddingsExtractingEstimator : IEstimator { private readonly IHost _host; - private readonly WordEmbeddingsExtractingTransformer.ColumnInfo[] _columns; + private readonly ColumnInfo[] _columns; private readonly WordEmbeddingsExtractingTransformer.PretrainedModelKind? _modelKind; private readonly string _customLookupTable; @@ -802,7 +785,7 @@ public sealed class WordEmbeddingsExtractingEstimator : IEstimatorThe embeddings to use. 
internal WordEmbeddingsExtractingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, WordEmbeddingsExtractingTransformer.PretrainedModelKind modelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe) - : this(env, modelKind, new WordEmbeddingsExtractingTransformer.ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName)) + : this(env, modelKind, new ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName)) { } @@ -817,7 +800,7 @@ internal WordEmbeddingsExtractingEstimator(IHostEnvironment env, string outputCo /// The path of the pre-trained embeedings model to use. /// Name of the column to transform. internal WordEmbeddingsExtractingEstimator(IHostEnvironment env, string outputColumnName, string customModelFile, string inputColumnName = null) - : this(env, customModelFile, new WordEmbeddingsExtractingTransformer.ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName)) + : this(env, customModelFile, new ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName)) { } @@ -832,7 +815,7 @@ internal WordEmbeddingsExtractingEstimator(IHostEnvironment env, string outputCo /// The array columns, and per-column configurations to extract embeedings from. 
internal WordEmbeddingsExtractingEstimator(IHostEnvironment env, WordEmbeddingsExtractingTransformer.PretrainedModelKind modelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe, - params WordEmbeddingsExtractingTransformer.ColumnInfo[] columns) + params ColumnInfo[] columns) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(WordEmbeddingsExtractingEstimator)); @@ -841,7 +824,7 @@ internal WordEmbeddingsExtractingEstimator(IHostEnvironment env, _columns = columns; } - internal WordEmbeddingsExtractingEstimator(IHostEnvironment env, string customModelFile, params WordEmbeddingsExtractingTransformer.ColumnInfo[] columns) + internal WordEmbeddingsExtractingEstimator(IHostEnvironment env, string customModelFile, params ColumnInfo[] columns) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(WordEmbeddingsExtractingEstimator)); @@ -850,6 +833,23 @@ internal WordEmbeddingsExtractingEstimator(IHostEnvironment env, string customMo _columns = columns; } + /// + /// Information for each column pair. + /// + public sealed class ColumnInfo + { + public readonly string Name; + public readonly string InputColumnName; + + public ColumnInfo(string name, string inputColumnName = null) + { + Contracts.CheckNonEmpty(name, nameof(name)); + + Name = name; + InputColumnName = inputColumnName ?? 
name; + } + } + public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 4742e033ec..3db7170d9b 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -133,5 +133,5 @@ Transforms.TrainTestDatasetSplitter Split the dataset into train and test sets M Transforms.TreeLeafFeaturizer Trains a tree ensemble, or loads it from a file, then maps a numeric feature vector to three outputs: 1. A vector containing the individual tree outputs of the tree ensemble. 2. A vector indicating the leaves that the feature vector falls on in the tree ensemble. 3. A vector indicating the paths that the feature vector falls on in the tree ensemble. If a both a model file and a trainer are specified - will use the model file. If neither are specified, will train a default FastTree model. This can handle key labels by training a regression model towards their optionally permuted indices. Microsoft.ML.Data.TreeFeaturize Featurizer Microsoft.ML.Data.TreeEnsembleFeaturizerTransform+ArgumentsForEntryPoint Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.TwoHeterogeneousModelCombiner Combines a TransformModel and a PredictorModel into a single PredictorModel. Microsoft.ML.EntryPoints.ModelOperations CombineTwoModels Microsoft.ML.EntryPoints.ModelOperations+SimplePredictorModelInput Microsoft.ML.EntryPoints.ModelOperations+PredictorModelOutput Transforms.VectorToImage Converts vector array into image type. 
Microsoft.ML.ImageAnalytics.EntryPoints.ImageAnalytics VectorToImage Microsoft.ML.ImageAnalytics.VectorToImageTransform+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -Transforms.WordEmbeddings Word Embeddings transform is a text featurizer which converts vectors of text tokens into sentence vectors using a pre-trained model Microsoft.ML.Transforms.Text.TextAnalytics WordEmbeddings Microsoft.ML.Transforms.Text.WordEmbeddingsExtractingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.WordEmbeddings Word Embeddings transform is a text featurizer which converts vectors of text tokens into sentence vectors using a pre-trained model Microsoft.ML.Transforms.Text.TextAnalytics WordEmbeddings Microsoft.ML.Transforms.Text.WordEmbeddingsExtractingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.WordTokenizer The input to this transform is text, and the output is a vector of text containing the words (tokens) in the original text. The separator is space, but can be specified as any other character (or multiple characters) if needed. 
Microsoft.ML.Transforms.Text.TextAnalytics DelimitedTokenizeTransform Microsoft.ML.Transforms.Text.WordTokenizingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index a75ac92051..6335259029 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -3629,7 +3629,7 @@ public void EntryPointWordEmbeddings() }, InputFile = inputFile, }).Data; - var embedding = Transforms.Text.TextAnalytics.WordEmbeddings(Env, new WordEmbeddingsExtractingTransformer.Arguments() + var embedding = Transforms.Text.TextAnalytics.WordEmbeddings(Env, new WordEmbeddingsExtractingTransformer.Options() { Data = dataView, Columns = new[] { new WordEmbeddingsExtractingTransformer.Column { Name = "Features", Source = "Text" } }, From 02ee301e331bed7dce23a728544517c16dfc12da Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Mon, 4 Feb 2019 20:29:46 +0000 Subject: [PATCH 03/10] TokenizingByCharactersEstimator --- .../EntryPoints/TextAnalytics.cs | 2 +- .../Text/TokenizingByCharacters.cs | 22 +++++++++---------- .../Common/EntryPoints/core_ep-list.tsv | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs index 2eed444b58..0c56c88a73 100644 --- a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs +++ b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs @@ -96,7 +96,7 @@ public static CommonOutputs.TransformOutput AnalyzeSentiment(IHostEnvironment en Desc = TokenizingByCharactersTransformer.Summary, UserName = TokenizingByCharactersTransformer.UserName, ShortName = TokenizingByCharactersTransformer.LoaderSignature)] - public static CommonOutputs.TransformOutput CharTokenize(IHostEnvironment env, 
TokenizingByCharactersTransformer.Arguments input) + public static CommonOutputs.TransformOutput CharTokenize(IHostEnvironment env, TokenizingByCharactersTransformer.Options input) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(input, nameof(input)); diff --git a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs index 9903c6e07c..64898a92f5 100644 --- a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs +++ b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs @@ -17,7 +17,7 @@ using Microsoft.ML.Model; using Microsoft.ML.Transforms.Text; -[assembly: LoadableClass(TokenizingByCharactersTransformer.Summary, typeof(IDataTransform), typeof(TokenizingByCharactersTransformer), typeof(TokenizingByCharactersTransformer.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(TokenizingByCharactersTransformer.Summary, typeof(IDataTransform), typeof(TokenizingByCharactersTransformer), typeof(TokenizingByCharactersTransformer.Options), typeof(SignatureDataTransform), TokenizingByCharactersTransformer.UserName, "CharTokenize", TokenizingByCharactersTransformer.LoaderSignature)] [assembly: LoadableClass(typeof(IDataTransform), typeof(TokenizingByCharactersTransformer), null, typeof(SignatureLoadDataTransform), @@ -53,7 +53,7 @@ internal bool TryUnparse(StringBuilder sb) } } - public sealed class Arguments : TransformInputBase + internal sealed class Options : TransformInputBase { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 1)] public Column[] Columns; @@ -106,7 +106,7 @@ private static VersionInfo GetVersionInfo() /// The environment. /// Whether to use marker characters to separate words. /// Pairs of columns to run the tokenization on. 
- public TokenizingByCharactersTransformer(IHostEnvironment env, bool useMarkerCharacters = TokenizingByCharactersEstimator.Defaults.UseMarkerCharacters, + internal TokenizingByCharactersTransformer(IHostEnvironment env, bool useMarkerCharacters = TokenizingByCharactersEstimator.Defaults.UseMarkerCharacters, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), columns) { @@ -161,20 +161,20 @@ private static TokenizingByCharactersTransformer Create(IHostEnvironment env, Mo } // Factory method for SignatureDataTransform. - internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - env.CheckValue(args.Columns, nameof(args.Columns)); - var cols = new (string outputColumnName, string inputColumnName)[args.Columns.Length]; + env.CheckValue(options.Columns, nameof(options.Columns)); + var cols = new (string outputColumnName, string inputColumnName)[options.Columns.Length]; for (int i = 0; i < cols.Length; i++) { - var item = args.Columns[i]; + var item = options.Columns[i]; cols[i] = (item.Name, item.Source ?? item.Name); } - return new TokenizingByCharactersTransformer(env, args.UseMarkerChars, cols).MakeDataTransform(input); + return new TokenizingByCharactersTransformer(env, options.UseMarkerChars, cols).MakeDataTransform(input); } // Factory method for SignatureLoadRowMapper. @@ -565,7 +565,7 @@ internal static class Defaults /// Name of the column resulting from the transformation of . /// Name of the column to transform. If set to , the value of the will be used as source. /// Whether to use marker characters to separate words. 
- public TokenizingByCharactersEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, + internal TokenizingByCharactersEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, bool useMarkerCharacters = Defaults.UseMarkerCharacters) : this(env, useMarkerCharacters, new[] { (outputColumnName, inputColumnName ?? outputColumnName) }) { @@ -578,7 +578,7 @@ public TokenizingByCharactersEstimator(IHostEnvironment env, string outputColumn /// Whether to use marker characters to separate words. /// Pairs of columns to run the tokenization on. - public TokenizingByCharactersEstimator(IHostEnvironment env, bool useMarkerCharacters = Defaults.UseMarkerCharacters, + internal TokenizingByCharactersEstimator(IHostEnvironment env, bool useMarkerCharacters = Defaults.UseMarkerCharacters, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TokenizingByCharactersEstimator)), new TokenizingByCharactersTransformer(env, useMarkerCharacters, columns)) { diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 3db7170d9b..2a26303b8b 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -75,7 +75,7 @@ Transforms.BinaryPredictionScoreColumnsRenamer For binary prediction, it renames Transforms.BinNormalizer The values are assigned into equidensity bins and a value is mapped to its bin_number/number_of_bins. Microsoft.ML.Data.Normalize Bin Microsoft.ML.Transforms.Normalizers.NormalizeTransform+BinArguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.CategoricalHashOneHotVectorizer Converts the categorical value into an indicator array by hashing the value and using the hash as an index in the bag. If the input column is a vector, a single indicator bag is returned for it. 
Microsoft.ML.Transforms.Categorical.Categorical CatTransformHash Microsoft.ML.Transforms.Categorical.OneHotHashEncoding+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.CategoricalOneHotVectorizer Converts the categorical value into an indicator array by building a dictionary of categories based on the data and using the id in the dictionary as the index in the array. Microsoft.ML.Transforms.Categorical.Categorical CatTransformDict Microsoft.ML.Transforms.Categorical.OneHotEncodingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -Transforms.CharacterTokenizer Character-oriented tokenizer where text is considered a sequence of characters. Microsoft.ML.Transforms.Text.TextAnalytics CharTokenize Microsoft.ML.Transforms.Text.TokenizingByCharactersTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.CharacterTokenizer Character-oriented tokenizer where text is considered a sequence of characters. Microsoft.ML.Transforms.Text.TextAnalytics CharTokenize Microsoft.ML.Transforms.Text.TokenizingByCharactersTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.ColumnConcatenator Concatenates one or more columns of the same item type. 
Microsoft.ML.EntryPoints.SchemaManipulation ConcatColumns Microsoft.ML.Data.ColumnConcatenatingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.ColumnCopier Duplicates columns from the dataset Microsoft.ML.EntryPoints.SchemaManipulation CopyColumns Microsoft.ML.Transforms.ColumnCopyingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.ColumnSelector Selects a set of columns, dropping all others Microsoft.ML.EntryPoints.SchemaManipulation SelectColumns Microsoft.ML.Transforms.ColumnSelectingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput From ed579f4f95106a1e9a4451abac08c17f93a5ba6e Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Mon, 4 Feb 2019 22:50:26 +0000 Subject: [PATCH 04/10] WordTokenizingEstimator --- .../Dynamic/KeyToValueValueToKey.cs | 4 +- .../EntryPoints/TextAnalytics.cs | 2 +- .../Text/TextCatalog.cs | 2 +- .../Text/TextFeaturizingEstimator.cs | 4 +- .../Text/WordBagTransform.cs | 4 +- .../Text/WordHashBagProducingTransform.cs | 4 +- .../Text/WordTokenizing.cs | 80 +++++++++---------- .../Common/EntryPoints/core_ep-list.tsv | 2 +- .../Transformers/TextFeaturizerTests.cs | 2 +- .../Transformers/ValueMappingTests.cs | 4 +- .../Transformers/WordTokenizeTests.cs | 8 +- 11 files changed, 56 insertions(+), 60 deletions(-) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValueValueToKey.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValueValueToKey.cs index ceca716c34..02ff540c0d 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValueValueToKey.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValueValueToKey.cs @@ -31,7 +31,7 @@ public static void KeyToValueValueToKey() // making use of default settings. 
string defaultColumnName = "DefaultKeys"; // REVIEW create through the catalog extension - var default_pipeline = new WordTokenizingEstimator(ml, "Review") + var default_pipeline = ml.Transforms.Text.TokenizeWords("Review") .Append(ml.Transforms.Conversion.MapValueToKey(defaultColumnName, "Review")); // Another pipeline, that customizes the advanced settings of the ValueToKeyMappingEstimator. @@ -39,7 +39,7 @@ public static void KeyToValueValueToKey() // and condition the order in which they get evaluated by changing sort from the default Occurence (order in which they get encountered) // to value/alphabetically. string customizedColumnName = "CustomizedKeys"; - var customized_pipeline = new WordTokenizingEstimator(ml, "Review") + var customized_pipeline = ml.Transforms.Text.TokenizeWords("Review") .Append(ml.Transforms.Conversion.MapValueToKey(customizedColumnName, "Review", maxNumKeys: 10, sort: ValueToKeyMappingEstimator.SortOrder.Value)); // The transformed data. diff --git a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs index 0c56c88a73..aa3f445594 100644 --- a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs +++ b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs @@ -36,7 +36,7 @@ public static CommonOutputs.TransformOutput TextTransform(IHostEnvironment env, Desc = ML.Transforms.Text.WordTokenizingTransformer.Summary, UserName = ML.Transforms.Text.WordTokenizingTransformer.UserName, ShortName = ML.Transforms.Text.WordTokenizingTransformer.LoaderSignature)] - public static CommonOutputs.TransformOutput DelimitedTokenizeTransform(IHostEnvironment env, WordTokenizingTransformer.Arguments input) + public static CommonOutputs.TransformOutput DelimitedTokenizeTransform(IHostEnvironment env, WordTokenizingTransformer.Options input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "DelimitedTokenizeTransform", input); var xf = 
ML.Transforms.Text.WordTokenizingTransformer.Create(h, input, input.Data); diff --git a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs index 84daeff936..624f9ee7a9 100644 --- a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs +++ b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs @@ -180,7 +180,7 @@ public static WordTokenizingEstimator TokenizeWords(this TransformsCatalog.TextT /// The text-related transform's catalog. /// Pairs of columns to run the tokenization on. public static WordTokenizingEstimator TokenizeWords(this TransformsCatalog.TextTransforms catalog, - params WordTokenizingTransformer.ColumnInfo[] columns) + params WordTokenizingEstimator.ColumnInfo[] columns) => new WordTokenizingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), columns); /// diff --git a/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs b/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs index f936ab9ad1..3b7c09a56f 100644 --- a/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs +++ b/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs @@ -328,11 +328,11 @@ public ITransformer Fit(IDataView input) if (tparams.NeedsWordTokenizationTransform) { - var xfCols = new WordTokenizingTransformer.ColumnInfo[textCols.Length]; + var xfCols = new WordTokenizingEstimator.ColumnInfo[textCols.Length]; wordTokCols = new string[textCols.Length]; for (int i = 0; i < textCols.Length; i++) { - var col = new WordTokenizingTransformer.ColumnInfo(GenerateColumnName(view.Schema, textCols[i], "WordTokenizer"), textCols[i]); + var col = new WordTokenizingEstimator.ColumnInfo(GenerateColumnName(view.Schema, textCols[i], "WordTokenizer"), textCols[i]); xfCols[i] = col; wordTokCols[i] = col.Name; tempCols.Add(col.Name); diff --git a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs index b90aa03371..ab07436d91 100644 --- 
a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs @@ -124,7 +124,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV // REVIEW: In order to make it possible to output separate bags for different columns // using the same dictionary, we need to find a way to make ConcatTransform remember the boundaries. - var tokenizeColumns = new WordTokenizingTransformer.ColumnInfo[args.Columns.Length]; + var tokenizeColumns = new WordTokenizingEstimator.ColumnInfo[args.Columns.Length]; var extractorArgs = new NgramExtractorTransform.Arguments() @@ -144,7 +144,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV h.CheckUserArg(Utils.Size(column.Source) > 0, nameof(column.Source)); h.CheckUserArg(column.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(column.Source)); - tokenizeColumns[iinfo] = new WordTokenizingTransformer.ColumnInfo(column.Name, column.Source.Length > 1 ? column.Name : column.Source[0]); + tokenizeColumns[iinfo] = new WordTokenizingEstimator.ColumnInfo(column.Name, column.Source.Length > 1 ? 
column.Name : column.Source[0]); extractorArgs.Columns[iinfo] = new NgramExtractorTransform.Column() diff --git a/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs b/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs index 49aae725d4..5cf84b615d 100644 --- a/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs @@ -103,7 +103,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV var uniqueSourceNames = NgramExtractionUtils.GenerateUniqueSourceNames(h, args.Columns, view.Schema); Contracts.Assert(uniqueSourceNames.Length == args.Columns.Length); - var tokenizeColumns = new List(); + var tokenizeColumns = new List(); var extractorCols = new NgramHashExtractingTransformer.Column[args.Columns.Length]; var colCount = args.Columns.Length; List tmpColNames = new List(); @@ -114,7 +114,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV var curTmpNames = new string[srcCount]; Contracts.Assert(uniqueSourceNames[iinfo].Length == args.Columns[iinfo].Source.Length); for (int isrc = 0; isrc < srcCount; isrc++) - tokenizeColumns.Add(new WordTokenizingTransformer.ColumnInfo(curTmpNames[isrc] = uniqueSourceNames[iinfo][isrc], args.Columns[iinfo].Source[isrc])); + tokenizeColumns.Add(new WordTokenizingEstimator.ColumnInfo(curTmpNames[isrc] = uniqueSourceNames[iinfo][isrc], args.Columns[iinfo].Source[isrc])); tmpColNames.AddRange(curTmpNames); extractorCols[iinfo] = diff --git a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs index c1c7820afa..ca5157fe46 100644 --- a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs +++ b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs @@ -19,7 +19,7 @@ using Microsoft.ML.Transforms.Text; using Newtonsoft.Json.Linq; -[assembly: LoadableClass(WordTokenizingTransformer.Summary, typeof(IDataTransform), 
typeof(WordTokenizingTransformer), typeof(WordTokenizingTransformer.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(WordTokenizingTransformer.Summary, typeof(IDataTransform), typeof(WordTokenizingTransformer), typeof(WordTokenizingTransformer.Options), typeof(SignatureDataTransform), "Word Tokenizer Transform", "WordTokenizeTransform", "DelimitedTokenizeTransform", "WordToken", "DelimitedTokenize", "Token")] [assembly: LoadableClass(WordTokenizingTransformer.Summary, typeof(IDataTransform), typeof(WordTokenizingTransformer), null, typeof(SignatureLoadDataTransform), @@ -81,16 +81,12 @@ public abstract class ArgumentsBase : TransformInputBase public char[] CharArrayTermSeparators; } - public sealed class Arguments : ArgumentsBase + internal sealed class Options : ArgumentsBase { [Argument(ArgumentType.Multiple, HelpText = "New column definition(s)", Name = "Column", ShortName = "col", SortOrder = 1)] public Column[] Columns; } - public sealed class TokenizeArguments : ArgumentsBase - { - } - internal const string Summary = "The input to this transform is text, and the output is a vector of text containing the words (tokens) in the original text. " + "The separator is space, but can be specified as any other character (or multiple characters) if needed."; @@ -111,35 +107,16 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = "DelimitedTokenize"; - public sealed class ColumnInfo - { - public readonly string Name; - public readonly string InputColumnName; - public readonly char[] Separators; + public IReadOnlyCollection Columns => _columns.AsReadOnly(); + private readonly WordTokenizingEstimator.ColumnInfo[] _columns; - /// - /// Describes how the transformer handles one column pair. - /// - /// Name of the column resulting from the transformation of . - /// Name of column to transform. If set to , the value of the will be used as source. - /// Casing text using the rules of the invariant culture. 
If not specified, space will be used as separator. - public ColumnInfo(string name, string inputColumnName = null, char[] separators = null) - { - Name = name; - InputColumnName = inputColumnName ?? name; - Separators = separators ?? new[] { ' ' }; - } - } - public IReadOnlyCollection Columns => _columns.AsReadOnly(); - private readonly ColumnInfo[] _columns; - - private static (string name, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) + private static (string name, string inputColumnName)[] GetColumnPairs(WordTokenizingEstimator.ColumnInfo[] columns) { Contracts.CheckNonEmpty(columns, nameof(columns)); return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); } - public WordTokenizingTransformer(IHostEnvironment env, params ColumnInfo[] columns) : + internal WordTokenizingTransformer(IHostEnvironment env, params WordTokenizingEstimator.ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) { _columns = columns.ToArray(); @@ -156,7 +133,7 @@ private WordTokenizingTransformer(IHost host, ModelLoadContext ctx) : base(host, ctx) { var columnsLength = ColumnPairs.Length; - _columns = new ColumnInfo[columnsLength]; + _columns = new WordTokenizingEstimator.ColumnInfo[columnsLength]; // *** Binary format *** // // for each added column @@ -165,7 +142,7 @@ private WordTokenizingTransformer(IHost host, ModelLoadContext ctx) : { var separators = ctx.Reader.ReadCharArray(); Contracts.CheckDecode(Utils.Size(separators) > 0); - _columns[i] = new ColumnInfo(ColumnPairs[i].outputColumnName, ColumnPairs[i].inputColumnName, separators); + _columns[i] = new WordTokenizingEstimator.ColumnInfo(ColumnPairs[i].outputColumnName, ColumnPairs[i].inputColumnName, separators); } } @@ -199,19 +176,19 @@ private static WordTokenizingTransformer Create(IHostEnvironment env, ModelLoadC } // Factory method for SignatureDataTransform. 
- internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - env.CheckValue(args.Columns, nameof(args.Columns)); - var cols = new ColumnInfo[args.Columns.Length]; + env.CheckValue(options.Columns, nameof(options.Columns)); + var cols = new WordTokenizingEstimator.ColumnInfo[options.Columns.Length]; for (int i = 0; i < cols.Length; i++) { - var item = args.Columns[i]; - var separators = args.CharArrayTermSeparators ?? PredictionUtil.SeparatorFromString(item.TermSeparators ?? args.TermSeparators); - cols[i] = new ColumnInfo(item.Name, item.Source ?? item.Name, separators); + var item = options.Columns[i]; + var separators = options.CharArrayTermSeparators ?? PredictionUtil.SeparatorFromString(item.TermSeparators ?? options.TermSeparators); + cols[i] = new WordTokenizingEstimator.ColumnInfo(item.Name, item.Source ?? item.Name, separators); } return new WordTokenizingTransformer(env, cols).MakeDataTransform(input); @@ -439,7 +416,7 @@ public sealed class WordTokenizingEstimator : TrivialEstimatorName of the column resulting from the transformation of . /// Name of the column to transform. If set to , the value of the will be used as source. /// The separators to use (uses space character by default). - public WordTokenizingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, char[] separators = null) + internal WordTokenizingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, char[] separators = null) : this(env, new[] { (outputColumnName, inputColumnName ?? 
outputColumnName) }, separators) { } @@ -450,8 +427,8 @@ public WordTokenizingEstimator(IHostEnvironment env, string outputColumnName, st /// The environment. /// Pairs of columns to run the tokenization on. /// The separators to use (uses space character by default). - public WordTokenizingEstimator(IHostEnvironment env, (string outputColumnName, string inputColumnName)[] columns, char[] separators = null) - : this(env, columns.Select(x => new WordTokenizingTransformer.ColumnInfo(x.outputColumnName, x.inputColumnName, separators)).ToArray()) + internal WordTokenizingEstimator(IHostEnvironment env, (string outputColumnName, string inputColumnName)[] columns, char[] separators = null) + : this(env, columns.Select(x => new ColumnInfo(x.outputColumnName, x.inputColumnName, separators)).ToArray()) { } @@ -460,10 +437,29 @@ public WordTokenizingEstimator(IHostEnvironment env, (string outputColumnName, s /// /// The environment. /// Pairs of columns to run the tokenization on. - public WordTokenizingEstimator(IHostEnvironment env, params WordTokenizingTransformer.ColumnInfo[] columns) + internal WordTokenizingEstimator(IHostEnvironment env, params ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(WordTokenizingEstimator)), new WordTokenizingTransformer(env, columns)) { } + public sealed class ColumnInfo + { + public readonly string Name; + public readonly string InputColumnName; + public readonly char[] Separators; + + /// + /// Describes how the transformer handles one column pair. + /// + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. + /// Casing text using the rules of the invariant culture. If not specified, space will be used as separator. + public ColumnInfo(string name, string inputColumnName = null, char[] separators = null) + { + Name = name; + InputColumnName = inputColumnName ?? name; + Separators = separators ?? 
new[] { ' ' }; + } + } public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 2a26303b8b..f823ae05c2 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -134,4 +134,4 @@ Transforms.TreeLeafFeaturizer Trains a tree ensemble, or loads it from a file, t Transforms.TwoHeterogeneousModelCombiner Combines a TransformModel and a PredictorModel into a single PredictorModel. Microsoft.ML.EntryPoints.ModelOperations CombineTwoModels Microsoft.ML.EntryPoints.ModelOperations+SimplePredictorModelInput Microsoft.ML.EntryPoints.ModelOperations+PredictorModelOutput Transforms.VectorToImage Converts vector array into image type. Microsoft.ML.ImageAnalytics.EntryPoints.ImageAnalytics VectorToImage Microsoft.ML.ImageAnalytics.VectorToImageTransform+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.WordEmbeddings Word Embeddings transform is a text featurizer which converts vectors of text tokens into sentence vectors using a pre-trained model Microsoft.ML.Transforms.Text.TextAnalytics WordEmbeddings Microsoft.ML.Transforms.Text.WordEmbeddingsExtractingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -Transforms.WordTokenizer The input to this transform is text, and the output is a vector of text containing the words (tokens) in the original text. The separator is space, but can be specified as any other character (or multiple characters) if needed. Microsoft.ML.Transforms.Text.TextAnalytics DelimitedTokenizeTransform Microsoft.ML.Transforms.Text.WordTokenizingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.WordTokenizer The input to this transform is text, and the output is a vector of text containing the words (tokens) in the original text. 
The separator is space, but can be specified as any other character (or multiple characters) if needed. Microsoft.ML.Transforms.Text.TextAnalytics DelimitedTokenizeTransform Microsoft.ML.Transforms.Text.WordTokenizingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs index e4b4410844..7c8ef48d22 100644 --- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs @@ -183,7 +183,7 @@ public void StopWordsRemoverFromFactory() var tokenized = new WordTokenizingTransformer(ML, new[] { - new WordTokenizingTransformer.ColumnInfo("Text", "Text") + new WordTokenizingEstimator.ColumnInfo("Text", "Text") }).Transform(data); var xf = factory.CreateComponent(ML, tokenized, diff --git a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs index ae42eda606..bfb2881973 100644 --- a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs @@ -87,7 +87,7 @@ public void ValueMapInputIsVectorTest() var values = new List() { 1, 2, 3, 4 }; var estimator = new WordTokenizingEstimator(Env, new[]{ - new WordTokenizingTransformer.ColumnInfo("TokenizeA", "A") + new WordTokenizingEstimator.ColumnInfo("TokenizeA", "A") }).Append(new ValueMappingEstimator, int>(Env, keys, values, new[] { ("VecD", "TokenizeA"), ("E", "B"), ("F", "C") })); var t = estimator.Fit(dataView); @@ -120,7 +120,7 @@ public void ValueMapInputIsVectorAndValueAsStringKeyTypeTest() var values = new List>() { "a".AsMemory(), "b".AsMemory(), "c".AsMemory(), "d".AsMemory() }; var estimator = new WordTokenizingEstimator(Env, new[]{ - new WordTokenizingTransformer.ColumnInfo("TokenizeA", "A") + new WordTokenizingEstimator.ColumnInfo("TokenizeA", "A") }).Append(new 
ValueMappingEstimator, ReadOnlyMemory>(Env, keys, values, true, new[] { ("VecD", "TokenizeA"), ("E", "B"), ("F", "C") })); var t = estimator.Fit(dataView); diff --git a/test/Microsoft.ML.Tests/Transformers/WordTokenizeTests.cs b/test/Microsoft.ML.Tests/Transformers/WordTokenizeTests.cs index 5f6723526d..397dbf50ae 100644 --- a/test/Microsoft.ML.Tests/Transformers/WordTokenizeTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/WordTokenizeTests.cs @@ -57,8 +57,8 @@ public void WordTokenizeWorkout() var invalidData = new[] { new TestWrong() { A =1, B = new float[2] { 2,3 } } }; var invalidDataView = ML.Data.ReadFromEnumerable(invalidData); var pipe = new WordTokenizingEstimator(Env, new[]{ - new WordTokenizingTransformer.ColumnInfo("TokenizeA", "A"), - new WordTokenizingTransformer.ColumnInfo("TokenizeB", "B"), + new WordTokenizingEstimator.ColumnInfo("TokenizeA", "A"), + new WordTokenizingEstimator.ColumnInfo("TokenizeB", "B"), }); TestEstimatorCore(pipe, dataView, invalidInput: invalidDataView); @@ -99,8 +99,8 @@ public void TestOldSavingAndLoading() var dataView = ML.Data.ReadFromEnumerable(data); var pipe = new WordTokenizingEstimator(Env, new[]{ - new WordTokenizingTransformer.ColumnInfo("TokenizeA", "A"), - new WordTokenizingTransformer.ColumnInfo("TokenizeB", "B"), + new WordTokenizingEstimator.ColumnInfo("TokenizeA", "A"), + new WordTokenizingEstimator.ColumnInfo("TokenizeB", "B"), }); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); From 60e83a77189ad4b3bc1cb1ed6e552993a6b79c57 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Tue, 5 Feb 2019 00:08:21 +0000 Subject: [PATCH 05/10] WordBagEstimator, WordHashBagEstimator --- .../Text/WordBagTransform.cs | 34 ++++---- .../Text/WordHashBagProducingTransform.cs | 82 +++++++++---------- .../Text/WrappedTextTransformers.cs | 16 ++-- .../UnitTests/TestEntryPoints.cs | 2 +- 4 files changed, 67 insertions(+), 67 deletions(-) diff --git 
a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs index ab07436d91..013e99ed0e 100644 --- a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs @@ -14,7 +14,7 @@ using Microsoft.ML.Transforms.Conversions; using Microsoft.ML.Transforms.Text; -[assembly: LoadableClass(WordBagBuildingTransformer.Summary, typeof(IDataTransform), typeof(WordBagBuildingTransformer), typeof(WordBagBuildingTransformer.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(WordBagBuildingTransformer.Summary, typeof(IDataTransform), typeof(WordBagBuildingTransformer), typeof(WordBagBuildingTransformer.Options), typeof(SignatureDataTransform), "Word Bag Transform", "WordBagTransform", "WordBag")] [assembly: LoadableClass(NgramExtractorTransform.Summary, typeof(INgramExtractorFactory), typeof(NgramExtractorTransform), typeof(NgramExtractorTransform.NgramExtractorArguments), @@ -94,7 +94,7 @@ internal bool TryUnparse(StringBuilder sb) /// public sealed class TokenizeColumn : OneToOneColumn { } - public sealed class Arguments : NgramExtractorTransform.ArgumentsBase + internal sealed class Options : NgramExtractorTransform.ArgumentsBase { [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:srcs)", Name = "Column", ShortName = "col", SortOrder = 1)] public Column[] Columns; @@ -105,13 +105,13 @@ public sealed class Arguments : NgramExtractorTransform.ArgumentsBase internal const string Summary = "Produces a bag of counts of ngrams (sequences of consecutive words of length 1-n) in a given text. 
It does so by building " + "a dictionary of ngrams and using the id in the dictionary as the index in the bag."; - public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); var h = env.Register(RegistrationName); - h.CheckValue(args, nameof(args)); + h.CheckValue(options, nameof(options)); h.CheckValue(input, nameof(input)); - h.CheckUserArg(Utils.Size(args.Columns) > 0, nameof(args.Columns), "Columns must be specified"); + h.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified"); // Compose the WordBagTransform from a tokenize transform, // followed by a NgramExtractionTransform. @@ -124,22 +124,22 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV // REVIEW: In order to make it possible to output separate bags for different columns // using the same dictionary, we need to find a way to make ConcatTransform remember the boundaries. 
- var tokenizeColumns = new WordTokenizingEstimator.ColumnInfo[args.Columns.Length]; + var tokenizeColumns = new WordTokenizingEstimator.ColumnInfo[options.Columns.Length]; var extractorArgs = new NgramExtractorTransform.Arguments() { - MaxNumTerms = args.MaxNumTerms, - NgramLength = args.NgramLength, - SkipLength = args.SkipLength, - AllLengths = args.AllLengths, - Weighting = args.Weighting, - Columns = new NgramExtractorTransform.Column[args.Columns.Length] + MaxNumTerms = options.MaxNumTerms, + NgramLength = options.NgramLength, + SkipLength = options.SkipLength, + AllLengths = options.AllLengths, + Weighting = options.Weighting, + Columns = new NgramExtractorTransform.Column[options.Columns.Length] }; - for (int iinfo = 0; iinfo < args.Columns.Length; iinfo++) + for (int iinfo = 0; iinfo < options.Columns.Length; iinfo++) { - var column = args.Columns[iinfo]; + var column = options.Columns[iinfo]; h.CheckUserArg(!string.IsNullOrWhiteSpace(column.Name), nameof(column.Name)); h.CheckUserArg(Utils.Size(column.Source) > 0, nameof(column.Source)); h.CheckUserArg(column.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(column.Source)); @@ -160,7 +160,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV } IDataView view = input; - view = NgramExtractionUtils.ApplyConcatOnSources(h, args.Columns, view); + view = NgramExtractionUtils.ApplyConcatOnSources(h, options.Columns, view); view = new WordTokenizingEstimator(env, tokenizeColumns).Fit(view).Transform(view); return NgramExtractorTransform.Create(h, extractorArgs, view); } @@ -517,7 +517,7 @@ public static IDataView ApplyConcatOnSources(IHostEnvironment env, ManyToOneColu var concatColumns = new List(); foreach (var col in columns) { - env.CheckUserArg(col != null, nameof(WordBagBuildingTransformer.Arguments.Columns)); + env.CheckUserArg(col != null, nameof(WordBagBuildingTransformer.Options.Columns)); env.CheckUserArg(!string.IsNullOrWhiteSpace(col.Name), 
nameof(col.Name)); env.CheckUserArg(Utils.Size(col.Source) > 0, nameof(col.Source)); env.CheckUserArg(col.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(col.Source)); @@ -545,7 +545,7 @@ public static string[][] GenerateUniqueSourceNames(IHostEnvironment env, ManyToO for (int iinfo = 0; iinfo < columns.Length; iinfo++) { var col = columns[iinfo]; - env.CheckUserArg(col != null, nameof(WordHashBagProducingTransformer.Arguments.Columns)); + env.CheckUserArg(col != null, nameof(WordHashBagProducingTransformer.Options.Columns)); env.CheckUserArg(!string.IsNullOrWhiteSpace(col.Name), nameof(col.Name)); env.CheckUserArg(Utils.Size(col.Source) > 0 && col.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(col.Source)); diff --git a/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs b/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs index 5cf84b615d..6056e41f2d 100644 --- a/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs @@ -14,7 +14,7 @@ using Microsoft.ML.Transforms.Conversions; using Microsoft.ML.Transforms.Text; -[assembly: LoadableClass(WordHashBagProducingTransformer.Summary, typeof(IDataTransform), typeof(WordHashBagProducingTransformer), typeof(WordHashBagProducingTransformer.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(WordHashBagProducingTransformer.Summary, typeof(IDataTransform), typeof(WordHashBagProducingTransformer), typeof(WordHashBagProducingTransformer.Options), typeof(SignatureDataTransform), "Word Hash Bag Transform", "WordHashBagTransform", "WordHashBag")] [assembly: LoadableClass(NgramHashExtractingTransformer.Summary, typeof(INgramExtractorFactory), typeof(NgramHashExtractingTransformer), typeof(NgramHashExtractingTransformer.NgramHashExtractorArguments), @@ -73,7 +73,7 @@ internal bool TryUnparse(StringBuilder sb) } } - public sealed class Arguments : 
NgramHashExtractingTransformer.ArgumentsBase + internal sealed class Options : NgramHashExtractingTransformer.ArgumentsBase { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:hashBits:srcs)", Name = "Column", ShortName = "col", SortOrder = 1)] @@ -84,13 +84,13 @@ public sealed class Arguments : NgramHashExtractingTransformer.ArgumentsBase internal const string Summary = "Produces a bag of counts of ngrams (sequences of consecutive words of length 1-n) in a given text. " + "It does so by hashing each ngram and using the hash value as the index in the bag."; - public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); var h = env.Register(RegistrationName); - h.CheckValue(args, nameof(args)); + h.CheckValue(options, nameof(options)); h.CheckValue(input, nameof(input)); - h.CheckUserArg(Utils.Size(args.Columns) > 0, nameof(args.Columns), "Columns must be specified"); + h.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified"); // To each input column to the WordHashBagTransform, a tokenize transform is applied, // followed by applying WordHashVectorizeTransform. @@ -100,21 +100,21 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV // The intermediate columns are dropped at the end of using a DropColumnsTransform. 
IDataView view = input; - var uniqueSourceNames = NgramExtractionUtils.GenerateUniqueSourceNames(h, args.Columns, view.Schema); - Contracts.Assert(uniqueSourceNames.Length == args.Columns.Length); + var uniqueSourceNames = NgramExtractionUtils.GenerateUniqueSourceNames(h, options.Columns, view.Schema); + Contracts.Assert(uniqueSourceNames.Length == options.Columns.Length); var tokenizeColumns = new List(); - var extractorCols = new NgramHashExtractingTransformer.Column[args.Columns.Length]; - var colCount = args.Columns.Length; + var extractorCols = new NgramHashExtractingTransformer.Column[options.Columns.Length]; + var colCount = options.Columns.Length; List tmpColNames = new List(); for (int iinfo = 0; iinfo < colCount; iinfo++) { - var column = args.Columns[iinfo]; + var column = options.Columns[iinfo]; int srcCount = column.Source.Length; var curTmpNames = new string[srcCount]; - Contracts.Assert(uniqueSourceNames[iinfo].Length == args.Columns[iinfo].Source.Length); + Contracts.Assert(uniqueSourceNames[iinfo].Length == options.Columns[iinfo].Source.Length); for (int isrc = 0; isrc < srcCount; isrc++) - tokenizeColumns.Add(new WordTokenizingEstimator.ColumnInfo(curTmpNames[isrc] = uniqueSourceNames[iinfo][isrc], args.Columns[iinfo].Source[isrc])); + tokenizeColumns.Add(new WordTokenizingEstimator.ColumnInfo(curTmpNames[isrc] = uniqueSourceNames[iinfo][isrc], options.Columns[iinfo].Source[isrc])); tmpColNames.AddRange(curTmpNames); extractorCols[iinfo] = @@ -128,7 +128,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV SkipLength = column.SkipLength, Ordered = column.Ordered, InvertHash = column.InvertHash, - FriendlyNames = args.Columns[iinfo].Source, + FriendlyNames = options.Columns[iinfo].Source, AllLengths = column.AllLengths }; } @@ -136,16 +136,16 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV view = new WordTokenizingEstimator(env, 
tokenizeColumns.ToArray()).Fit(view).Transform(view); var featurizeArgs = - new NgramHashExtractingTransformer.Arguments + new NgramHashExtractingTransformer.Options { - AllLengths = args.AllLengths, - HashBits = args.HashBits, - NgramLength = args.NgramLength, - SkipLength = args.SkipLength, - Ordered = args.Ordered, - Seed = args.Seed, + AllLengths = options.AllLengths, + HashBits = options.HashBits, + NgramLength = options.NgramLength, + SkipLength = options.SkipLength, + Ordered = options.Ordered, + Seed = options.Seed, Columns = extractorCols.ToArray(), - InvertHash = args.InvertHash + InvertHash = options.InvertHash }; view = NgramHashExtractingTransformer.Create(h, featurizeArgs, view); @@ -296,7 +296,7 @@ internal static class DefaultArguments [TlcModule.Component(Name = "NGramHash", FriendlyName = "NGram Hash Extractor Transform", Alias = "NGramHashExtractorTransform,NGramHashExtractor", Desc = "Extracts NGrams from text and convert them to vector using hashing trick.")] - public sealed class NgramHashExtractorArguments : ArgumentsBase, INgramExtractorFactoryFactory + internal sealed class NgramHashExtractorArguments : ArgumentsBase, INgramExtractorFactoryFactory { public INgramExtractorFactory CreateComponent(IHostEnvironment env, TermLoaderArguments loaderArgs) { @@ -304,7 +304,7 @@ public INgramExtractorFactory CreateComponent(IHostEnvironment env, TermLoaderAr } } - public sealed class Arguments : ArgumentsBase + internal sealed class Options : ArgumentsBase { [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:srcs)", Name = "Column", ShortName = "col", SortOrder = 1)] public Column[] Columns; @@ -314,14 +314,14 @@ public sealed class Arguments : ArgumentsBase internal const string LoaderSignature = "NgramHashExtractor"; - public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input, 
TermLoaderArguments termLoaderArgs = null) { Contracts.CheckValue(env, nameof(env)); var h = env.Register(LoaderSignature); - h.CheckValue(args, nameof(args)); + h.CheckValue(options, nameof(options)); h.CheckValue(input, nameof(input)); - h.CheckUserArg(Utils.Size(args.Columns) > 0, nameof(args.Columns), "Columns must be specified"); + h.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified"); // To each input column to the NgramHashExtractorArguments, a HashTransform using 31 // bits (to minimize collisions) is applied first, followed by an NgramHashTransform. @@ -331,15 +331,15 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV if (termLoaderArgs != null) termCols = new List(); var hashColumns = new List(); - var ngramHashColumns = new NgramHashingTransformer.ColumnInfo[args.Columns.Length]; + var ngramHashColumns = new NgramHashingTransformer.ColumnInfo[options.Columns.Length]; - var colCount = args.Columns.Length; + var colCount = options.Columns.Length; // The NGramHashExtractor has a ManyToOne column type. To avoid stepping over the source // column name when a 'name' destination column name was specified, we use temporary column names. string[][] tmpColNames = new string[colCount][]; for (int iinfo = 0; iinfo < colCount; iinfo++) { - var column = args.Columns[iinfo]; + var column = options.Columns[iinfo]; h.CheckUserArg(!string.IsNullOrWhiteSpace(column.Name), nameof(column.Name)); h.CheckUserArg(Utils.Size(column.Source) > 0 && column.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(column.Source)); @@ -361,18 +361,18 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV } hashColumns.Add(new HashingTransformer.ColumnInfo(tmpName, termLoaderArgs == null ? column.Source[isrc] : tmpName, - 30, column.Seed ?? args.Seed, false, column.InvertHash ?? args.InvertHash)); + 30, column.Seed ?? options.Seed, false, column.InvertHash ?? 
options.InvertHash)); } ngramHashColumns[iinfo] = new NgramHashingTransformer.ColumnInfo(column.Name, tmpColNames[iinfo], - column.NgramLength ?? args.NgramLength, - column.SkipLength ?? args.SkipLength, - column.AllLengths ?? args.AllLengths, - column.HashBits ?? args.HashBits, - column.Seed ?? args.Seed, - column.Ordered ?? args.Ordered, - column.InvertHash ?? args.InvertHash); + column.NgramLength ?? options.NgramLength, + column.SkipLength ?? options.SkipLength, + column.AllLengths ?? options.AllLengths, + column.HashBits ?? options.HashBits, + column.Seed ?? options.Seed, + column.Ordered ?? options.Ordered, + column.InvertHash ?? options.InvertHash); ngramHashColumns[iinfo].FriendlyNames = column.FriendlyNames; } @@ -406,7 +406,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV return ColumnSelectingTransformer.CreateDrop(h, view, tmpColNames.SelectMany(cols => cols).ToArray()); } - public static IDataTransform Create(NgramHashExtractorArguments extractorArgs, IHostEnvironment env, IDataView input, + internal static IDataTransform Create(NgramHashExtractorArguments extractorArgs, IHostEnvironment env, IDataView input, ExtractorColumn[] cols, TermLoaderArguments termLoaderArgs = null) { Contracts.CheckValue(env, nameof(env)); @@ -414,7 +414,7 @@ public static IDataTransform Create(NgramHashExtractorArguments extractorArgs, I h.CheckValue(extractorArgs, nameof(extractorArgs)); h.CheckValue(input, nameof(input)); h.CheckUserArg(extractorArgs.SkipLength < extractorArgs.NgramLength, nameof(extractorArgs.SkipLength), "Should be less than " + nameof(extractorArgs.NgramLength)); - h.CheckUserArg(Utils.Size(cols) > 0, nameof(Arguments.Columns), "Must be specified"); + h.CheckUserArg(Utils.Size(cols) > 0, nameof(Options.Columns), "Must be specified"); h.AssertValueOrNull(termLoaderArgs); var extractorCols = new Column[cols.Length]; @@ -429,7 +429,7 @@ public static IDataTransform Create(NgramHashExtractorArguments extractorArgs, I }; 
} - var args = new Arguments + var args = new Options { Columns = extractorCols, NgramLength = extractorArgs.NgramLength, @@ -444,7 +444,7 @@ public static IDataTransform Create(NgramHashExtractorArguments extractorArgs, I return Create(h, args, input, termLoaderArgs); } - public static INgramExtractorFactory Create(IHostEnvironment env, NgramHashExtractorArguments extractorArgs, + internal static INgramExtractorFactory Create(IHostEnvironment env, NgramHashExtractorArguments extractorArgs, TermLoaderArguments termLoaderArgs) { Contracts.CheckValue(env, nameof(env)); diff --git a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs index 1bda6b77f7..653a398844 100644 --- a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs +++ b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs @@ -35,7 +35,7 @@ public sealed class WordBagEstimator : TrainedWrapperEstimatorBase /// Whether to include all ngram lengths up to or only . /// Maximum number of ngrams to store in the dictionary. /// Statistical measure used to evaluate how important a word is to a document in a corpus. - public WordBagEstimator(IHostEnvironment env, + internal WordBagEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, int ngramLength = 1, @@ -59,7 +59,7 @@ public WordBagEstimator(IHostEnvironment env, /// Whether to include all ngram lengths up to or only . /// Maximum number of ngrams to store in the dictionary. /// Statistical measure used to evaluate how important a word is to a document in a corpus. - public WordBagEstimator(IHostEnvironment env, + internal WordBagEstimator(IHostEnvironment env, string outputColumnName, string[] inputColumnNames, int ngramLength = 1, @@ -82,7 +82,7 @@ public WordBagEstimator(IHostEnvironment env, /// Whether to include all ngram lengths up to or only . /// Maximum number of ngrams to store in the dictionary. 
/// Statistical measure used to evaluate how important a word is to a document in a corpus. - public WordBagEstimator(IHostEnvironment env, + internal WordBagEstimator(IHostEnvironment env, (string outputColumnName, string[] inputColumnNames)[] columns, int ngramLength = 1, int skipLength = 0, @@ -108,7 +108,7 @@ public WordBagEstimator(IHostEnvironment env, public override TransformWrapper Fit(IDataView input) { // Create arguments. - var args = new WordBagBuildingTransformer.Arguments + var args = new WordBagBuildingTransformer.Options { Columns = _columns.Select(x => new WordBagBuildingTransformer.Column { Name = x.outputColumnName, Source = x.sourceColumnsNames }).ToArray(), NgramLength = _ngramLength, @@ -154,7 +154,7 @@ public sealed class WordHashBagEstimator : TrainedWrapperEstimatorBase /// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. - public WordHashBagEstimator(IHostEnvironment env, + internal WordHashBagEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, int hashBits = 16, @@ -185,7 +185,7 @@ public WordHashBagEstimator(IHostEnvironment env, /// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. 
- public WordHashBagEstimator(IHostEnvironment env, + internal WordHashBagEstimator(IHostEnvironment env, string outputColumnName, string[] inputColumnNames, int hashBits = 16, @@ -215,7 +215,7 @@ public WordHashBagEstimator(IHostEnvironment env, /// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. - public WordHashBagEstimator(IHostEnvironment env, + internal WordHashBagEstimator(IHostEnvironment env, (string outputColumnName, string[] inputColumnNames)[] columns, int hashBits = 16, int ngramLength = 1, @@ -245,7 +245,7 @@ public WordHashBagEstimator(IHostEnvironment env, public override TransformWrapper Fit(IDataView input) { // Create arguments. - var args = new WordHashBagProducingTransformer.Arguments + var args = new WordHashBagProducingTransformer.Options { Columns = _columns.Select(x => new WordHashBagProducingTransformer.Column { Name = x.outputColumnName ,Source = x.inputColumnNames}).ToArray(), HashBits = _hashBits, diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 6335259029..9e05ad0663 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -1007,7 +1007,7 @@ public void EntryPointPipelineEnsembleText() else { data = WordHashBagProducingTransformer.Create(Env, - new WordHashBagProducingTransformer.Arguments() + new WordHashBagProducingTransformer.Options() { Columns = new[] { new WordHashBagProducingTransformer.Column() { Name = "Features", Source = new[] { "Text" } }, } From 49915069148398ff6d1719b1c39a9d8dd9bc4a54 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Tue, 5 Feb 
2019 01:06:54 +0000 Subject: [PATCH 06/10] NgramExtractingEstimator, NgramHashingEstimator --- .../TextStaticExtensions.cs | 4 +- .../EntryPoints/TextAnalytics.cs | 2 +- .../Text/NgramHashingTransformer.cs | 402 +++++++++--------- .../Text/NgramTransform.cs | 194 ++++----- .../Text/TextCatalog.cs | 2 +- .../Text/WordBagTransform.cs | 6 +- .../Text/WordHashBagProducingTransform.cs | 6 +- .../Common/EntryPoints/core_ep-list.tsv | 2 +- 8 files changed, 309 insertions(+), 309 deletions(-) diff --git a/src/Microsoft.ML.StaticPipe/TextStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/TextStaticExtensions.cs index a2155330c6..e3a0c5a462 100644 --- a/src/Microsoft.ML.StaticPipe/TextStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/TextStaticExtensions.cs @@ -559,9 +559,9 @@ public override IEstimator Reconcile(IHostEnvironment env, IReadOnlyCollection usedNames) { Contracts.Assert(toOutput.Length == 1); - var columns = new List(); + var columns = new List(); foreach (var outCol in toOutput) - columns.Add(new NgramHashingTransformer.ColumnInfo(outputNames[outCol], new[] { inputNames[((OutPipelineColumn)outCol).Input] }, + columns.Add(new NgramHashingEstimator.ColumnInfo(outputNames[outCol], new[] { inputNames[((OutPipelineColumn)outCol).Input] }, _ngramLength, _skipLength, _allLengths, _hashBits, _seed, _ordered, _invertHash)); return new NgramHashingEstimator(env, columns.ToArray()); diff --git a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs index aa3f445594..697b39601e 100644 --- a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs +++ b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs @@ -51,7 +51,7 @@ public static CommonOutputs.TransformOutput DelimitedTokenizeTransform(IHostEnvi Desc = NgramExtractingTransformer.Summary, UserName = NgramExtractingTransformer.UserName, ShortName = NgramExtractingTransformer.LoaderSignature)] - public static CommonOutputs.TransformOutput 
NGramTransform(IHostEnvironment env, NgramExtractingTransformer.Arguments input) + public static CommonOutputs.TransformOutput NGramTransform(IHostEnvironment env, NgramExtractingTransformer.Options input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "NGramTransform", input); var xf = NgramExtractingTransformer.Create(h, input, input.Data); diff --git a/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs b/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs index cb76b256e7..bb6490d283 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs @@ -17,7 +17,7 @@ using Microsoft.ML.Model; using Microsoft.ML.Transforms.Text; -[assembly: LoadableClass(NgramHashingTransformer.Summary, typeof(IDataTransform), typeof(NgramHashingTransformer), typeof(NgramHashingTransformer.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(NgramHashingTransformer.Summary, typeof(IDataTransform), typeof(NgramHashingTransformer), typeof(NgramHashingTransformer.Options), typeof(SignatureDataTransform), "Ngram Hash Transform", "NgramHashTransform", "NgramHash")] [assembly: LoadableClass(NgramHashingTransformer.Summary, typeof(IDataTransform), typeof(NgramHashingTransformer), null, typeof(SignatureLoadDataTransform), @@ -112,7 +112,7 @@ internal bool TryUnparse(StringBuilder sb) } } - public sealed class Arguments + internal sealed class Options { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:hashBits:src)", ShortName = "col", @@ -173,180 +173,7 @@ private static VersionInfo GetVersionInfo() private const int VersionTransformer = 0x00010003; - /// - /// Describes how the transformer handles one pair of mulitple inputs - singular output columns. 
- /// - public sealed class ColumnInfo - { - public readonly string Name; - public readonly string[] InputColumnNames; - public readonly int NgramLength; - public readonly int SkipLength; - public readonly bool AllLengths; - public readonly int HashBits; - public readonly uint Seed; - public readonly bool Ordered; - public readonly int InvertHash; - public readonly bool RehashUnigrams; - // For all source columns, use these friendly names for the source - // column names instead of the real column names. - internal string[] FriendlyNames; - - /// - /// Describes how the transformer handles one column pair. - /// - /// Name of the column resulting from the transformation of . - /// Name of the columns to transform. - /// Maximum ngram length. - /// Maximum number of tokens to skip when constructing an ngram. - /// "Whether to store all ngram lengths up to ngramLength, or only ngramLength. - /// Number of bits to hash into. Must be between 1 and 31, inclusive. - /// Hashing seed. - /// Whether the position of each term should be included in the hash. - /// During hashing we constuct mappings between original values and the produced hash values. - /// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. - /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. - /// 0 does not retain any input values. -1 retains all input values mapping to each hash. - /// Whether to rehash unigrams. 
- public ColumnInfo(string name, - string[] inputColumnNames, - int ngramLength = NgramHashingEstimator.Defaults.NgramLength, - int skipLength = NgramHashingEstimator.Defaults.SkipLength, - bool allLengths = NgramHashingEstimator.Defaults.AllLengths, - int hashBits = NgramHashingEstimator.Defaults.HashBits, - uint seed = NgramHashingEstimator.Defaults.Seed, - bool ordered = NgramHashingEstimator.Defaults.Ordered, - int invertHash = NgramHashingEstimator.Defaults.InvertHash, - bool rehashUnigrams = NgramHashingEstimator.Defaults.RehashUnigrams) - { - Contracts.CheckValue(name, nameof(name)); - Contracts.CheckValue(inputColumnNames, nameof(inputColumnNames)); - Contracts.CheckParam(!inputColumnNames.Any(r => string.IsNullOrWhiteSpace(r)), nameof(inputColumnNames), - "Contained some null or empty items"); - if (invertHash < -1) - throw Contracts.ExceptParam(nameof(invertHash), "Value too small, must be -1 or larger"); - // If the bits is 31 or higher, we can't declare a KeyValues of the appropriate length, - // this requiring a VBuffer of length 1u << 31 which exceeds int.MaxValue. - if (invertHash != 0 && hashBits >= 31) - throw Contracts.ExceptParam(nameof(hashBits), $"Cannot support invertHash for a {0} bit hash. 
30 is the maximum possible.", hashBits); - - if (NgramLength + SkipLength > NgramBufferBuilder.MaxSkipNgramLength) - { - throw Contracts.ExceptUserArg(nameof(skipLength), - $"The sum of skipLength and ngramLength must be less than or equal to {NgramBufferBuilder.MaxSkipNgramLength}"); - } - FriendlyNames = null; - Name = name; - InputColumnNames = inputColumnNames; - NgramLength = ngramLength; - SkipLength = skipLength; - AllLengths = allLengths; - HashBits = hashBits; - Seed = seed; - Ordered = ordered; - InvertHash = invertHash; - RehashUnigrams = rehashUnigrams; - } - - internal ColumnInfo(ModelLoadContext ctx) - { - Contracts.AssertValue(ctx); - - // *** Binary format *** - // size of Inputs - // string[] Inputs; - // string Output; - // int: NgramLength - // int: SkipLength - // int: HashBits - // uint: Seed - // byte: Rehash - // byte: Ordered - // byte: AllLengths - var inputsLength = ctx.Reader.ReadInt32(); - InputColumnNames = new string[inputsLength]; - for (int i = 0; i < InputColumnNames.Length; i++) - InputColumnNames[i] = ctx.LoadNonEmptyString(); - Name = ctx.LoadNonEmptyString(); - NgramLength = ctx.Reader.ReadInt32(); - Contracts.CheckDecode(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength); - SkipLength = ctx.Reader.ReadInt32(); - Contracts.CheckDecode(0 <= SkipLength && SkipLength <= NgramBufferBuilder.MaxSkipNgramLength); - Contracts.CheckDecode(SkipLength <= NgramBufferBuilder.MaxSkipNgramLength - NgramLength); - HashBits = ctx.Reader.ReadInt32(); - Contracts.CheckDecode(1 <= HashBits && HashBits <= 30); - Seed = ctx.Reader.ReadUInt32(); - RehashUnigrams = ctx.Reader.ReadBoolByte(); - Ordered = ctx.Reader.ReadBoolByte(); - AllLengths = ctx.Reader.ReadBoolByte(); - } - - internal ColumnInfo(ModelLoadContext ctx, string name, string[] inputColumnNames) - { - Contracts.AssertValue(ctx); - Contracts.CheckValue(inputColumnNames, nameof(inputColumnNames)); - Contracts.CheckParam(!inputColumnNames.Any(r => 
string.IsNullOrWhiteSpace(r)), nameof(inputColumnNames), - "Contained some null or empty items"); - InputColumnNames = inputColumnNames; - Name = name; - // *** Binary format *** - // string Output; - // int: NgramLength - // int: SkipLength - // int: HashBits - // uint: Seed - // byte: Rehash - // byte: Ordered - // byte: AllLengths - NgramLength = ctx.Reader.ReadInt32(); - Contracts.CheckDecode(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength); - SkipLength = ctx.Reader.ReadInt32(); - Contracts.CheckDecode(0 <= SkipLength && SkipLength <= NgramBufferBuilder.MaxSkipNgramLength); - Contracts.CheckDecode(SkipLength <= NgramBufferBuilder.MaxSkipNgramLength - NgramLength); - HashBits = ctx.Reader.ReadInt32(); - Contracts.CheckDecode(1 <= HashBits && HashBits <= 30); - Seed = ctx.Reader.ReadUInt32(); - RehashUnigrams = ctx.Reader.ReadBoolByte(); - Ordered = ctx.Reader.ReadBoolByte(); - AllLengths = ctx.Reader.ReadBoolByte(); - } - - internal void Save(ModelSaveContext ctx) - { - Contracts.AssertValue(ctx); - - // *** Binary format *** - // size of Inputs - // string[] Inputs; - // string Output; - // int: NgramLength - // int: SkipLength - // int: HashBits - // uint: Seed - // byte: Rehash - // byte: Ordered - // byte: AllLengths - Contracts.Assert(InputColumnNames.Length > 0); - ctx.Writer.Write(InputColumnNames.Length); - for (int i = 0; i < InputColumnNames.Length; i++) - ctx.SaveNonEmptyString(InputColumnNames[i]); - ctx.SaveNonEmptyString(Name); - - Contracts.Assert(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength); - ctx.Writer.Write(NgramLength); - Contracts.Assert(0 <= SkipLength && SkipLength <= NgramBufferBuilder.MaxSkipNgramLength); - Contracts.Assert(NgramLength + SkipLength <= NgramBufferBuilder.MaxSkipNgramLength); - ctx.Writer.Write(SkipLength); - Contracts.Assert(1 <= HashBits && HashBits <= 30); - ctx.Writer.Write(HashBits); - ctx.Writer.Write(Seed); - ctx.Writer.WriteBoolByte(RehashUnigrams); - 
ctx.Writer.WriteBoolByte(Ordered); - ctx.Writer.WriteBoolByte(AllLengths); - } - } - - private readonly ImmutableArray _columns; + private readonly ImmutableArray _columns; private readonly VBuffer>[] _slotNames; private readonly VectorType[] _slotNamesTypes; @@ -355,7 +182,7 @@ internal void Save(ModelSaveContext ctx) /// /// Host Environment. /// Description of dataset columns and how to process them. - public NgramHashingTransformer(IHostEnvironment env, params ColumnInfo[] columns) : + internal NgramHashingTransformer(IHostEnvironment env, params NgramHashingEstimator.ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NgramHashingTransformer))) { _columns = columns.ToImmutableArray(); @@ -366,7 +193,7 @@ public NgramHashingTransformer(IHostEnvironment env, params ColumnInfo[] columns } } - internal NgramHashingTransformer(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) : + internal NgramHashingTransformer(IHostEnvironment env, IDataView input, params NgramHashingEstimator.ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NgramHashingTransformer))) { Contracts.CheckValue(columns, nameof(columns)); @@ -463,14 +290,14 @@ private NgramHashingTransformer(IHostEnvironment env, ModelLoadContext ctx, bool } var columnsLength = ctx.Reader.ReadInt32(); Contracts.CheckDecode(columnsLength > 0); - var columns = new ColumnInfo[columnsLength]; + var columns = new NgramHashingEstimator.ColumnInfo[columnsLength]; if (!loadLegacy) { // *** Binary format *** // int number of columns // columns for (int i = 0; i < columnsLength; i++) - columns[i] = new ColumnInfo(ctx); + columns[i] = new NgramHashingEstimator.ColumnInfo(ctx); } else { @@ -500,38 +327,38 @@ private NgramHashingTransformer(IHostEnvironment env, ModelLoadContext ctx, bool // int number of columns // columns for (int i = 0; i < columnsLength; i++) - columns[i] = new ColumnInfo(ctx, outputs[i], inputs[i]); + columns[i] = new 
NgramHashingEstimator.ColumnInfo(ctx, outputs[i], inputs[i]); } _columns = columns.ToImmutableArray(); TextModelHelper.LoadAll(Host, ctx, columnsLength, out _slotNames, out _slotNamesTypes); } // Factory method for SignatureDataTransform. - private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - env.CheckValue(args.Column, nameof(args.Column)); - var cols = new ColumnInfo[args.Column.Length]; + env.CheckValue(options.Column, nameof(options.Column)); + var cols = new NgramHashingEstimator.ColumnInfo[options.Column.Length]; using (var ch = env.Start("ValidateArgs")) { for (int i = 0; i < cols.Length; i++) { - var item = args.Column[i]; - cols[i] = new ColumnInfo( + var item = options.Column[i]; + cols[i] = new NgramHashingEstimator.ColumnInfo( item.Name, item.Source ?? new string[] { item.Name }, - item.NgramLength ?? args.NgramLength, - item.SkipLength ?? args.SkipLength, - item.AllLengths ?? args.AllLengths, - item.HashBits ?? args.HashBits, - item.Seed ?? args.Seed, - item.Ordered ?? args.Ordered, - item.InvertHash ?? args.InvertHash, - item.RehashUnigrams ?? args.RehashUnigrams + item.NgramLength ?? options.NgramLength, + item.SkipLength ?? options.SkipLength, + item.AllLengths ?? options.AllLengths, + item.HashBits ?? options.HashBits, + item.Seed ?? options.Seed, + item.Ordered ?? options.Ordered, + item.InvertHash ?? options.InvertHash, + item.RehashUnigrams ?? options.RehashUnigrams ); }; } @@ -1044,6 +871,179 @@ public VBuffer>[] SlotNamesMetadata(out VectorType[] types) /// public sealed class NgramHashingEstimator : IEstimator { + /// + /// Describes how the transformer handles one pair of mulitple inputs - singular output columns. 
+ /// + public sealed class ColumnInfo + { + public readonly string Name; + public readonly string[] InputColumnNames; + public readonly int NgramLength; + public readonly int SkipLength; + public readonly bool AllLengths; + public readonly int HashBits; + public readonly uint Seed; + public readonly bool Ordered; + public readonly int InvertHash; + public readonly bool RehashUnigrams; + // For all source columns, use these friendly names for the source + // column names instead of the real column names. + internal string[] FriendlyNames; + + /// + /// Describes how the transformer handles one column pair. + /// + /// Name of the column resulting from the transformation of . + /// Name of the columns to transform. + /// Maximum ngram length. + /// Maximum number of tokens to skip when constructing an ngram. + /// "Whether to store all ngram lengths up to ngramLength, or only ngramLength. + /// Number of bits to hash into. Must be between 1 and 31, inclusive. + /// Hashing seed. + /// Whether the position of each term should be included in the hash. + /// During hashing we constuct mappings between original values and the produced hash values. + /// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. + /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. + /// 0 does not retain any input values. -1 retains all input values mapping to each hash. + /// Whether to rehash unigrams. 
+ public ColumnInfo(string name, + string[] inputColumnNames, + int ngramLength = NgramHashingEstimator.Defaults.NgramLength, + int skipLength = NgramHashingEstimator.Defaults.SkipLength, + bool allLengths = NgramHashingEstimator.Defaults.AllLengths, + int hashBits = NgramHashingEstimator.Defaults.HashBits, + uint seed = NgramHashingEstimator.Defaults.Seed, + bool ordered = NgramHashingEstimator.Defaults.Ordered, + int invertHash = NgramHashingEstimator.Defaults.InvertHash, + bool rehashUnigrams = NgramHashingEstimator.Defaults.RehashUnigrams) + { + Contracts.CheckValue(name, nameof(name)); + Contracts.CheckValue(inputColumnNames, nameof(inputColumnNames)); + Contracts.CheckParam(!inputColumnNames.Any(r => string.IsNullOrWhiteSpace(r)), nameof(inputColumnNames), + "Contained some null or empty items"); + if (invertHash < -1) + throw Contracts.ExceptParam(nameof(invertHash), "Value too small, must be -1 or larger"); + // If the bits is 31 or higher, we can't declare a KeyValues of the appropriate length, + // this requiring a VBuffer of length 1u << 31 which exceeds int.MaxValue. + if (invertHash != 0 && hashBits >= 31) + throw Contracts.ExceptParam(nameof(hashBits), $"Cannot support invertHash for a {0} bit hash. 
30 is the maximum possible.", hashBits); + + if (NgramLength + SkipLength > NgramBufferBuilder.MaxSkipNgramLength) + { + throw Contracts.ExceptUserArg(nameof(skipLength), + $"The sum of skipLength and ngramLength must be less than or equal to {NgramBufferBuilder.MaxSkipNgramLength}"); + } + FriendlyNames = null; + Name = name; + InputColumnNames = inputColumnNames; + NgramLength = ngramLength; + SkipLength = skipLength; + AllLengths = allLengths; + HashBits = hashBits; + Seed = seed; + Ordered = ordered; + InvertHash = invertHash; + RehashUnigrams = rehashUnigrams; + } + + internal ColumnInfo(ModelLoadContext ctx) + { + Contracts.AssertValue(ctx); + + // *** Binary format *** + // size of Inputs + // string[] Inputs; + // string Output; + // int: NgramLength + // int: SkipLength + // int: HashBits + // uint: Seed + // byte: Rehash + // byte: Ordered + // byte: AllLengths + var inputsLength = ctx.Reader.ReadInt32(); + InputColumnNames = new string[inputsLength]; + for (int i = 0; i < InputColumnNames.Length; i++) + InputColumnNames[i] = ctx.LoadNonEmptyString(); + Name = ctx.LoadNonEmptyString(); + NgramLength = ctx.Reader.ReadInt32(); + Contracts.CheckDecode(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength); + SkipLength = ctx.Reader.ReadInt32(); + Contracts.CheckDecode(0 <= SkipLength && SkipLength <= NgramBufferBuilder.MaxSkipNgramLength); + Contracts.CheckDecode(SkipLength <= NgramBufferBuilder.MaxSkipNgramLength - NgramLength); + HashBits = ctx.Reader.ReadInt32(); + Contracts.CheckDecode(1 <= HashBits && HashBits <= 30); + Seed = ctx.Reader.ReadUInt32(); + RehashUnigrams = ctx.Reader.ReadBoolByte(); + Ordered = ctx.Reader.ReadBoolByte(); + AllLengths = ctx.Reader.ReadBoolByte(); + } + + internal ColumnInfo(ModelLoadContext ctx, string name, string[] inputColumnNames) + { + Contracts.AssertValue(ctx); + Contracts.CheckValue(inputColumnNames, nameof(inputColumnNames)); + Contracts.CheckParam(!inputColumnNames.Any(r => 
string.IsNullOrWhiteSpace(r)), nameof(inputColumnNames), + "Contained some null or empty items"); + InputColumnNames = inputColumnNames; + Name = name; + // *** Binary format *** + // string Output; + // int: NgramLength + // int: SkipLength + // int: HashBits + // uint: Seed + // byte: Rehash + // byte: Ordered + // byte: AllLengths + NgramLength = ctx.Reader.ReadInt32(); + Contracts.CheckDecode(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength); + SkipLength = ctx.Reader.ReadInt32(); + Contracts.CheckDecode(0 <= SkipLength && SkipLength <= NgramBufferBuilder.MaxSkipNgramLength); + Contracts.CheckDecode(SkipLength <= NgramBufferBuilder.MaxSkipNgramLength - NgramLength); + HashBits = ctx.Reader.ReadInt32(); + Contracts.CheckDecode(1 <= HashBits && HashBits <= 30); + Seed = ctx.Reader.ReadUInt32(); + RehashUnigrams = ctx.Reader.ReadBoolByte(); + Ordered = ctx.Reader.ReadBoolByte(); + AllLengths = ctx.Reader.ReadBoolByte(); + } + + internal void Save(ModelSaveContext ctx) + { + Contracts.AssertValue(ctx); + + // *** Binary format *** + // size of Inputs + // string[] Inputs; + // string Output; + // int: NgramLength + // int: SkipLength + // int: HashBits + // uint: Seed + // byte: Rehash + // byte: Ordered + // byte: AllLengths + Contracts.Assert(InputColumnNames.Length > 0); + ctx.Writer.Write(InputColumnNames.Length); + for (int i = 0; i < InputColumnNames.Length; i++) + ctx.SaveNonEmptyString(InputColumnNames[i]); + ctx.SaveNonEmptyString(Name); + + Contracts.Assert(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength); + ctx.Writer.Write(NgramLength); + Contracts.Assert(0 <= SkipLength && SkipLength <= NgramBufferBuilder.MaxSkipNgramLength); + Contracts.Assert(NgramLength + SkipLength <= NgramBufferBuilder.MaxSkipNgramLength); + ctx.Writer.Write(SkipLength); + Contracts.Assert(1 <= HashBits && HashBits <= 30); + ctx.Writer.Write(HashBits); + ctx.Writer.Write(Seed); + ctx.Writer.WriteBoolByte(RehashUnigrams); + 
ctx.Writer.WriteBoolByte(Ordered); + ctx.Writer.WriteBoolByte(AllLengths); + } + } + internal static class Defaults { internal const int NgramLength = 2; @@ -1057,7 +1057,7 @@ internal static class Defaults } private readonly IHost _host; - private readonly NgramHashingTransformer.ColumnInfo[] _columns; + private readonly ColumnInfo[] _columns; /// /// Produces a bag of counts of hashed ngrams in @@ -1079,7 +1079,7 @@ internal static class Defaults /// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. - public NgramHashingEstimator(IHostEnvironment env, + internal NgramHashingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, int hashBits = 16, @@ -1113,7 +1113,7 @@ public NgramHashingEstimator(IHostEnvironment env, /// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. - public NgramHashingEstimator(IHostEnvironment env, + internal NgramHashingEstimator(IHostEnvironment env, string outputColumnName, string[] inputColumnNames, int hashBits = 16, @@ -1146,7 +1146,7 @@ public NgramHashingEstimator(IHostEnvironment env, /// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. 
-1 retains all input values mapping to each hash. - public NgramHashingEstimator(IHostEnvironment env, + internal NgramHashingEstimator(IHostEnvironment env, (string outputColumnName, string[] inputColumnName)[] columns, int hashBits = 16, int ngramLength = 2, @@ -1155,7 +1155,7 @@ public NgramHashingEstimator(IHostEnvironment env, uint seed = 314489979, bool ordered = true, int invertHash = 0) - : this(env, columns.Select(x => new NgramHashingTransformer.ColumnInfo(x.outputColumnName, x.inputColumnName, ngramLength, skipLength, allLengths, hashBits, seed, ordered, invertHash)).ToArray()) + : this(env, columns.Select(x => new ColumnInfo(x.outputColumnName, x.inputColumnName, ngramLength, skipLength, allLengths, hashBits, seed, ordered, invertHash)).ToArray()) { } @@ -1169,7 +1169,7 @@ public NgramHashingEstimator(IHostEnvironment env, /// /// The environment. /// Array of columns which specifies the behavior of the transformation. - public NgramHashingEstimator(IHostEnvironment env, params NgramHashingTransformer.ColumnInfo[] columns) + internal NgramHashingEstimator(IHostEnvironment env, params ColumnInfo[] columns) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(NgramHashingEstimator)); diff --git a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs index 3e3371ef44..593941f3df 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs @@ -18,7 +18,7 @@ using Microsoft.ML.Model; using Microsoft.ML.Transforms.Text; -[assembly: LoadableClass(NgramExtractingTransformer.Summary, typeof(IDataTransform), typeof(NgramExtractingTransformer), typeof(NgramExtractingTransformer.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(NgramExtractingTransformer.Summary, typeof(IDataTransform), typeof(NgramExtractingTransformer), typeof(NgramExtractingTransformer.Options), typeof(SignatureDataTransform), "Ngram Transform", 
"NgramTransform", "Ngram")] [assembly: LoadableClass(NgramExtractingTransformer.Summary, typeof(IDataTransform), typeof(NgramExtractingTransformer), null, typeof(SignatureLoadDataTransform), @@ -77,7 +77,7 @@ internal bool TryUnparse(StringBuilder sb) } } - public sealed class Arguments : TransformInputBase + internal sealed class Options : TransformInputBase { [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 1)] public Column[] Columns; @@ -122,81 +122,6 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(NgramExtractingTransformer).Assembly.FullName); } - /// - /// Describes how the transformer handles one column pair. - /// - public sealed class ColumnInfo - { - public readonly string Name; - public readonly string InputColumnName; - public readonly int NgramLength; - public readonly int SkipLength; - public readonly bool AllLengths; - public readonly NgramExtractingEstimator.WeightingCriteria Weighting; - /// - /// Contains the maximum number of grams to store in the dictionary, for each level of ngrams, - /// from 1 (in position 0) up to ngramLength (in position ngramLength-1) - /// - public readonly ImmutableArray Limits; - - /// - /// Describes how the transformer handles one Gcn column pair. - /// - /// Name of the column resulting from the transformation of . - /// Name of column to transform. If set to , the value of the will be used as source. - /// Maximum ngram length. - /// Maximum number of tokens to skip when constructing an ngram. - /// "Whether to store all ngram lengths up to ngramLength, or only ngramLength. - /// The weighting criteria. - /// Maximum number of ngrams to store in the dictionary. 
- public ColumnInfo(string name, string inputColumnName = null, - int ngramLength = NgramExtractingEstimator.Defaults.NgramLength, - int skipLength = NgramExtractingEstimator.Defaults.SkipLength, - bool allLengths = NgramExtractingEstimator.Defaults.AllLengths, - NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.Defaults.Weighting, - int maxNumTerms = NgramExtractingEstimator.Defaults.MaxNumTerms) - : this(name, ngramLength, skipLength, allLengths, weighting, new int[] { maxNumTerms }, inputColumnName ?? name) - { - } - - internal ColumnInfo(string name, - int ngramLength, - int skipLength, - bool allLengths, - NgramExtractingEstimator.WeightingCriteria weighting, - int[] maxNumTerms, - string inputColumnName = null) - { - Name = name; - InputColumnName = inputColumnName ?? name; - NgramLength = ngramLength; - Contracts.CheckUserArg(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength, nameof(ngramLength)); - SkipLength = skipLength; - if (NgramLength + SkipLength > NgramBufferBuilder.MaxSkipNgramLength) - { - throw Contracts.ExceptUserArg(nameof(skipLength), - $"The sum of skipLength and ngramLength must be less than or equal to {NgramBufferBuilder.MaxSkipNgramLength}"); - } - AllLengths = allLengths; - Weighting = weighting; - var limits = new int[ngramLength]; - if (!AllLengths) - { - Contracts.CheckUserArg(Utils.Size(maxNumTerms) == 0 || - Utils.Size(maxNumTerms) == 1 && maxNumTerms[0] > 0, nameof(maxNumTerms)); - limits[ngramLength - 1] = Utils.Size(maxNumTerms) == 0 ? NgramExtractingEstimator.Defaults.MaxNumTerms : maxNumTerms[0]; - } - else - { - Contracts.CheckUserArg(Utils.Size(maxNumTerms) <= ngramLength, nameof(maxNumTerms)); - Contracts.CheckUserArg(Utils.Size(maxNumTerms) == 0 || maxNumTerms.All(i => i >= 0) && maxNumTerms[maxNumTerms.Length - 1] > 0, nameof(maxNumTerms)); - var extend = Utils.Size(maxNumTerms) == 0 ? 
NgramExtractingEstimator.Defaults.MaxNumTerms : maxNumTerms[maxNumTerms.Length - 1]; - limits = Utils.BuildArray(ngramLength, i => i < Utils.Size(maxNumTerms) ? maxNumTerms[i] : extend); - } - Limits = ImmutableArray.Create(limits); - } - } - private sealed class TransformInfo { // Position i, indicates whether the pool contains any (i+1)-grams @@ -207,7 +132,7 @@ private sealed class TransformInfo public bool RequireIdf => Weighting == NgramExtractingEstimator.WeightingCriteria.Idf || Weighting == NgramExtractingEstimator.WeightingCriteria.TfIdf; - public TransformInfo(ColumnInfo info) + public TransformInfo(NgramExtractingEstimator.ColumnInfo info) { NgramLength = info.NgramLength; SkipLength = info.SkipLength; @@ -267,7 +192,7 @@ public void Save(ModelSaveContext ctx) // Ngram inverse document frequencies private readonly double[][] _invDocFreqs; - private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(NgramExtractingEstimator.ColumnInfo[] columns) { Contracts.CheckValue(columns, nameof(columns)); return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); @@ -280,7 +205,7 @@ protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, NgramExtractingEstimator.ExpectedColumnType, type.ToString()); } - internal NgramExtractingTransformer(IHostEnvironment env, IDataView input, ColumnInfo[] columns) + internal NgramExtractingTransformer(IHostEnvironment env, IDataView input, NgramExtractingEstimator.ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NgramExtractingTransformer)), GetColumnPairs(columns)) { var transformInfos = new TransformInfo[columns.Length]; @@ -294,7 +219,7 @@ internal NgramExtractingTransformer(IHostEnvironment env, IDataView input, Colum _ngramMaps = Train(Host, 
columns, _transformInfos, input, out _invDocFreqs); } - private static SequencePool[] Train(IHostEnvironment env, ColumnInfo[] columns, ImmutableArray transformInfos, IDataView trainingData, out double[][] invDocFreqs) + private static SequencePool[] Train(IHostEnvironment env, NgramExtractingEstimator.ColumnInfo[] columns, ImmutableArray transformInfos, IDataView trainingData, out double[][] invDocFreqs) { var helpers = new NgramBufferBuilder[columns.Length]; var getters = new ValueGetter>[columns.Length]; @@ -482,27 +407,27 @@ private NgramExtractingTransformer(IHost host, ModelLoadContext ctx) : } // Factory method for SignatureDataTransform. - internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - env.CheckValue(args.Columns, nameof(args.Columns)); - var cols = new ColumnInfo[args.Columns.Length]; + env.CheckValue(options.Columns, nameof(options.Columns)); + var cols = new NgramExtractingEstimator.ColumnInfo[options.Columns.Length]; using (var ch = env.Start("ValidateArgs")) { for (int i = 0; i < cols.Length; i++) { - var item = args.Columns[i]; - var maxNumTerms = Utils.Size(item.MaxNumTerms) > 0 ? item.MaxNumTerms : args.MaxNumTerms; - cols[i] = new ColumnInfo( + var item = options.Columns[i]; + var maxNumTerms = Utils.Size(item.MaxNumTerms) > 0 ? item.MaxNumTerms : options.MaxNumTerms; + cols[i] = new NgramExtractingEstimator.ColumnInfo( item.Name, - item.NgramLength ?? args.NgramLength, - item.SkipLength ?? args.SkipLength, - item.AllLengths ?? args.AllLengths, - item.Weighting ?? args.Weighting, + item.NgramLength ?? options.NgramLength, + item.SkipLength ?? options.SkipLength, + item.AllLengths ?? options.AllLengths, + item.Weighting ?? 
options.Weighting, maxNumTerms, item.Source ?? item.Name); }; @@ -777,7 +702,7 @@ internal static class Defaults } private readonly IHost _host; - private readonly NgramExtractingTransformer.ColumnInfo[] _columns; + private readonly ColumnInfo[] _columns; /// /// Produces a bag of counts of ngrams (sequences of consecutive words) in @@ -791,7 +716,7 @@ internal static class Defaults /// Whether to include all ngram lengths up to or only . /// Maximum number of ngrams to store in the dictionary. /// Statistical measure used to evaluate how important a word is to a document in a corpus. - public NgramExtractingEstimator(IHostEnvironment env, + internal NgramExtractingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, int ngramLength = Defaults.NgramLength, int skipLength = Defaults.SkipLength, @@ -813,14 +738,14 @@ public NgramExtractingEstimator(IHostEnvironment env, /// Whether to include all ngram lengths up to or only . /// Maximum number of ngrams to store in the dictionary. /// Statistical measure used to evaluate how important a word is to a document in a corpus. - public NgramExtractingEstimator(IHostEnvironment env, + internal NgramExtractingEstimator(IHostEnvironment env, (string outputColumnName, string inputColumnName)[] columns, int ngramLength = Defaults.NgramLength, int skipLength = Defaults.SkipLength, bool allLengths = Defaults.AllLengths, int maxNumTerms = Defaults.MaxNumTerms, WeightingCriteria weighting = Defaults.Weighting) - : this(env, columns.Select(x => new NgramExtractingTransformer.ColumnInfo(x.outputColumnName, x.inputColumnName, ngramLength, skipLength, allLengths, weighting, maxNumTerms)).ToArray()) + : this(env, columns.Select(x => new ColumnInfo(x.outputColumnName, x.inputColumnName, ngramLength, skipLength, allLengths, weighting, maxNumTerms)).ToArray()) { } @@ -830,7 +755,7 @@ public NgramExtractingEstimator(IHostEnvironment env, /// /// The environment. 
/// Array of columns with information how to transform data. - public NgramExtractingEstimator(IHostEnvironment env, params NgramExtractingTransformer.ColumnInfo[] columns) + internal NgramExtractingEstimator(IHostEnvironment env, params ColumnInfo[] columns) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(NgramExtractingEstimator)); @@ -865,6 +790,81 @@ internal static bool IsSchemaColumnValid(SchemaShape.Column col) internal const string ExpectedColumnType = "Expected vector of Key type, and Key is convertible to U4"; + /// + /// Describes how the transformer handles one column pair. + /// + public sealed class ColumnInfo + { + public readonly string Name; + public readonly string InputColumnName; + public readonly int NgramLength; + public readonly int SkipLength; + public readonly bool AllLengths; + public readonly NgramExtractingEstimator.WeightingCriteria Weighting; + /// + /// Contains the maximum number of grams to store in the dictionary, for each level of ngrams, + /// from 1 (in position 0) up to ngramLength (in position ngramLength-1) + /// + public readonly ImmutableArray Limits; + + /// + /// Describes how the transformer handles one Gcn column pair. + /// + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. + /// Maximum ngram length. + /// Maximum number of tokens to skip when constructing an ngram. + /// "Whether to store all ngram lengths up to ngramLength, or only ngramLength. + /// The weighting criteria. + /// Maximum number of ngrams to store in the dictionary. 
+ public ColumnInfo(string name, string inputColumnName = null, + int ngramLength = NgramExtractingEstimator.Defaults.NgramLength, + int skipLength = NgramExtractingEstimator.Defaults.SkipLength, + bool allLengths = NgramExtractingEstimator.Defaults.AllLengths, + NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.Defaults.Weighting, + int maxNumTerms = NgramExtractingEstimator.Defaults.MaxNumTerms) + : this(name, ngramLength, skipLength, allLengths, weighting, new int[] { maxNumTerms }, inputColumnName ?? name) + { + } + + internal ColumnInfo(string name, + int ngramLength, + int skipLength, + bool allLengths, + NgramExtractingEstimator.WeightingCriteria weighting, + int[] maxNumTerms, + string inputColumnName = null) + { + Name = name; + InputColumnName = inputColumnName ?? name; + NgramLength = ngramLength; + Contracts.CheckUserArg(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength, nameof(ngramLength)); + SkipLength = skipLength; + if (NgramLength + SkipLength > NgramBufferBuilder.MaxSkipNgramLength) + { + throw Contracts.ExceptUserArg(nameof(skipLength), + $"The sum of skipLength and ngramLength must be less than or equal to {NgramBufferBuilder.MaxSkipNgramLength}"); + } + AllLengths = allLengths; + Weighting = weighting; + var limits = new int[ngramLength]; + if (!AllLengths) + { + Contracts.CheckUserArg(Utils.Size(maxNumTerms) == 0 || + Utils.Size(maxNumTerms) == 1 && maxNumTerms[0] > 0, nameof(maxNumTerms)); + limits[ngramLength - 1] = Utils.Size(maxNumTerms) == 0 ? NgramExtractingEstimator.Defaults.MaxNumTerms : maxNumTerms[0]; + } + else + { + Contracts.CheckUserArg(Utils.Size(maxNumTerms) <= ngramLength, nameof(maxNumTerms)); + Contracts.CheckUserArg(Utils.Size(maxNumTerms) == 0 || maxNumTerms.All(i => i >= 0) && maxNumTerms[maxNumTerms.Length - 1] > 0, nameof(maxNumTerms)); + var extend = Utils.Size(maxNumTerms) == 0 ? 
NgramExtractingEstimator.Defaults.MaxNumTerms : maxNumTerms[maxNumTerms.Length - 1]; + limits = Utils.BuildArray(ngramLength, i => i < Utils.Size(maxNumTerms) ? maxNumTerms[i] : extend); + } + Limits = ImmutableArray.Create(limits); + } + } + public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); diff --git a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs index 624f9ee7a9..e04571e0dc 100644 --- a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs +++ b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs @@ -241,7 +241,7 @@ public static NgramExtractingEstimator ProduceNgrams(this TransformsCatalog.Text /// The text-related transform's catalog. /// Pairs of columns to run the ngram process on. public static NgramExtractingEstimator ProduceNgrams(this TransformsCatalog.TextTransforms catalog, - params NgramExtractingTransformer.ColumnInfo[] columns) + params NgramExtractingEstimator.ColumnInfo[] columns) => new NgramExtractingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), columns); /// diff --git a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs index 013e99ed0e..c489e6ae32 100644 --- a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs @@ -220,7 +220,7 @@ internal bool TryUnparse(StringBuilder sb) /// /// This class is a merger of and - /// , with the allLength option removed. + /// , with the allLength option removed. 
/// public abstract class ArgumentsBase { @@ -349,11 +349,11 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV view = new MissingValueDroppingTransformer(h, missingDropColumns.Select(x => (x, x)).ToArray()).Transform(view); } - var ngramColumns = new NgramExtractingTransformer.ColumnInfo[args.Columns.Length]; + var ngramColumns = new NgramExtractingEstimator.ColumnInfo[args.Columns.Length]; for (int iinfo = 0; iinfo < args.Columns.Length; iinfo++) { var column = args.Columns[iinfo]; - ngramColumns[iinfo] = new NgramExtractingTransformer.ColumnInfo(column.Name, + ngramColumns[iinfo] = new NgramExtractingEstimator.ColumnInfo(column.Name, column.NgramLength ?? args.NgramLength, column.SkipLength ?? args.SkipLength, column.AllLengths ?? args.AllLengths, diff --git a/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs b/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs index 6056e41f2d..8825724b18 100644 --- a/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs @@ -246,7 +246,7 @@ internal bool TryUnparse(StringBuilder sb) /// /// This class is a merger of and - /// , with the ordered option, + /// , with the ordered option, /// the rehashUnigrams option and the allLength option removed. /// public abstract class ArgumentsBase @@ -331,7 +331,7 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa if (termLoaderArgs != null) termCols = new List(); var hashColumns = new List(); - var ngramHashColumns = new NgramHashingTransformer.ColumnInfo[options.Columns.Length]; + var ngramHashColumns = new NgramHashingEstimator.ColumnInfo[options.Columns.Length]; var colCount = options.Columns.Length; // The NGramHashExtractor has a ManyToOne column type. 
To avoid stepping over the source @@ -365,7 +365,7 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa } ngramHashColumns[iinfo] = - new NgramHashingTransformer.ColumnInfo(column.Name, tmpColNames[iinfo], + new NgramHashingEstimator.ColumnInfo(column.Name, tmpColNames[iinfo], column.NgramLength ?? options.NgramLength, column.SkipLength ?? options.SkipLength, column.AllLengths ?? options.AllLengths, diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index f823ae05c2..4cb4f9e638 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -112,7 +112,7 @@ Transforms.MissingValuesDropper Removes NAs from vector columns. Microsoft.ML.Tr Transforms.MissingValuesRowDropper Filters out rows that contain missing values. Microsoft.ML.Transforms.NAHandling Filter Microsoft.ML.Transforms.NAFilter+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.MissingValueSubstitutor Create an output column of the same type and size of the input column, where missing values are replaced with either the default value or the mean/min/max value (for non-text columns only). Microsoft.ML.Transforms.NAHandling Replace Microsoft.ML.Transforms.MissingValueReplacingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.ModelCombiner Combines a sequence of TransformModels into a single model Microsoft.ML.EntryPoints.ModelOperations CombineTransformModels Microsoft.ML.EntryPoints.ModelOperations+CombineTransformModelsInput Microsoft.ML.EntryPoints.ModelOperations+CombineTransformModelsOutput -Transforms.NGramTranslator Produces a bag of counts of ngrams (sequences of consecutive values of length 1-n) in a given vector of keys. It does so by building a dictionary of ngrams and using the id in the dictionary as the index in the bag. 
Microsoft.ML.Transforms.Text.TextAnalytics NGramTransform Microsoft.ML.Transforms.Text.NgramExtractingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.NGramTranslator Produces a bag of counts of ngrams (sequences of consecutive values of length 1-n) in a given vector of keys. It does so by building a dictionary of ngrams and using the id in the dictionary as the index in the bag. Microsoft.ML.Transforms.Text.TextAnalytics NGramTransform Microsoft.ML.Transforms.Text.NgramExtractingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.NoOperation Does nothing. Microsoft.ML.Data.NopTransform Nop Microsoft.ML.Data.NopTransform+NopInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.OptionalColumnCreator If the source column does not exist after deserialization, create a column with the right type and default values. Microsoft.ML.Transforms.OptionalColumnTransform MakeOptional Microsoft.ML.Transforms.OptionalColumnTransform+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.PcaCalculator PCA is a dimensionality-reduction transform which computes the projection of a numeric vector onto a low-rank subspace. 
Microsoft.ML.Transforms.Projections.PcaTransformer Calculate Microsoft.ML.Transforms.Projections.PcaTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput From 42cb7837ec157dce411f58f2ea739675353ef180 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Tue, 5 Feb 2019 03:18:38 +0000 Subject: [PATCH 07/10] StopWordsRemovingEstimator; CustomStopWordsRemovingEstimator; TextNormalizingEstimator --- .../TextStaticExtensions.cs | 4 +- .../Text/StopWordsRemovingTransformer.cs | 139 +++++++++--------- .../Text/TextFeaturizingEstimator.cs | 4 +- .../Text/TextNormalizing.cs | 22 +-- 4 files changed, 84 insertions(+), 85 deletions(-) diff --git a/src/Microsoft.ML.StaticPipe/TextStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/TextStaticExtensions.cs index e3a0c5a462..9002c7742a 100644 --- a/src/Microsoft.ML.StaticPipe/TextStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/TextStaticExtensions.cs @@ -151,9 +151,9 @@ public override IEstimator Reconcile(IHostEnvironment env, { Contracts.Assert(toOutput.Length == 1); - var columns = new List(); + var columns = new List(); foreach (var outCol in toOutput) - columns.Add(new StopWordsRemovingTransformer.ColumnInfo(outputNames[outCol], inputNames[((OutPipelineColumn)outCol).Input], _language)); + columns.Add(new StopWordsRemovingEstimator.ColumnInfo(outputNames[outCol], inputNames[((OutPipelineColumn)outCol).Input], _language)); return new StopWordsRemovingEstimator(env, columns.ToArray()); } diff --git a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs index 9b6ae91cca..2b2349e529 100644 --- a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs @@ -20,7 +20,7 @@ using Microsoft.ML.Model; using Microsoft.ML.Transforms.Text; -[assembly: LoadableClass(StopWordsRemovingTransformer.Summary, typeof(IDataTransform), 
typeof(StopWordsRemovingTransformer), typeof(StopWordsRemovingTransformer.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(StopWordsRemovingTransformer.Summary, typeof(IDataTransform), typeof(StopWordsRemovingTransformer), typeof(StopWordsRemovingTransformer.Options), typeof(SignatureDataTransform), "Stopwords Remover Transform", "StopWordsRemoverTransform", "StopWordsRemover", "StopWords")] [assembly: LoadableClass(StopWordsRemovingTransformer.Summary, typeof(IDataTransform), typeof(StopWordsRemovingTransformer), null, typeof(SignatureLoadDataTransform), @@ -32,7 +32,7 @@ [assembly: LoadableClass(typeof(IRowMapper), typeof(StopWordsRemovingTransformer), null, typeof(SignatureLoadRowMapper), "Stopwords Remover Transform", StopWordsRemovingTransformer.LoaderSignature)] -[assembly: LoadableClass(CustomStopWordsRemovingTransformer.Summary, typeof(IDataTransform), typeof(CustomStopWordsRemovingTransformer), typeof(CustomStopWordsRemovingTransformer.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(CustomStopWordsRemovingTransformer.Summary, typeof(IDataTransform), typeof(CustomStopWordsRemovingTransformer), typeof(CustomStopWordsRemovingTransformer.Options), typeof(SignatureDataTransform), "Custom Stopwords Remover Transform", "CustomStopWordsRemoverTransform", "CustomStopWords")] [assembly: LoadableClass(CustomStopWordsRemovingTransformer.Summary, typeof(IDataTransform), typeof(CustomStopWordsRemovingTransformer), null, typeof(SignatureLoadDataTransform), @@ -58,7 +58,7 @@ public sealed class PredefinedStopWordsRemoverFactory : IStopWordsRemoverFactory { public IDataTransform CreateComponent(IHostEnvironment env, IDataView input, OneToOneColumn[] columns) { - return new StopWordsRemovingEstimator(env, columns.Select(x => new StopWordsRemovingTransformer.ColumnInfo(x.Name, x.Source)).ToArray()).Fit(input).Transform(input) as IDataTransform; + return new StopWordsRemovingEstimator(env, columns.Select(x => new 
StopWordsRemovingEstimator.ColumnInfo(x.Name, x.Source)).ToArray()).Fit(input).Transform(input) as IDataTransform; } } @@ -99,7 +99,7 @@ internal bool TryUnparse(StringBuilder sb) } } - public sealed class Arguments + internal sealed class Options { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s)", Name = "Column", ShortName = "col", SortOrder = 1)] public Column[] Columns; @@ -131,9 +131,9 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(StopWordsRemovingTransformer).Assembly.FullName); } - public IReadOnlyCollection Columns => _columns.AsReadOnly(); + public IReadOnlyCollection Columns => _columns.AsReadOnly(); - private readonly ColumnInfo[] _columns; + private readonly StopWordsRemovingEstimator.ColumnInfo[] _columns; private static volatile NormStr.Pool[] _stopWords; private static volatile Dictionary, StopWordsRemovingEstimator.Language> _langsDictionary; @@ -179,38 +179,7 @@ private static NormStr.Pool[] StopWords } } - /// - /// Describes how the transformer handles one column pair. - /// - public sealed class ColumnInfo - { - public readonly string Name; - public readonly string InputColumnName; - public readonly StopWordsRemovingEstimator.Language Language; - public readonly string LanguageColumn; - - /// - /// Describes how the transformer handles one column pair. - /// - /// Name of the column resulting from the transformation of . - /// Name of the column to transform. If set to , the value of the will be used as source. - /// Language-specific stop words list. - /// Optional column to use for languages. This overrides language value. - public ColumnInfo(string name, - string inputColumnName = null, - StopWordsRemovingEstimator.Language language = StopWordsRemovingEstimator.Defaults.DefaultLanguage, - string languageColumn = null) - { - Contracts.CheckNonWhiteSpace(name, nameof(name)); - - Name = name; - InputColumnName = inputColumnName ?? 
name; - Language = language; - LanguageColumn = languageColumn; - } - } - - private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(StopWordsRemovingEstimator.ColumnInfo[] columns) { Contracts.CheckValue(columns, nameof(columns)); return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); @@ -228,7 +197,7 @@ protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol /// /// The environment. /// Pairs of columns to remove stop words from. - public StopWordsRemovingTransformer(IHostEnvironment env, params ColumnInfo[] columns) : + internal StopWordsRemovingTransformer(IHostEnvironment env, params StopWordsRemovingEstimator.ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) { _columns = columns; @@ -262,13 +231,13 @@ private StopWordsRemovingTransformer(IHost host, ModelLoadContext ctx) : // foreach column: // int: the stopwords list language // string: the id of languages column name - _columns = new ColumnInfo[columnsLength]; + _columns = new StopWordsRemovingEstimator.ColumnInfo[columnsLength]; for (int i = 0; i < columnsLength; i++) { var lang = (StopWordsRemovingEstimator.Language)ctx.Reader.ReadInt32(); Contracts.CheckDecode(Enum.IsDefined(typeof(StopWordsRemovingEstimator.Language), lang)); var langColName = ctx.LoadStringOrNull(); - _columns[i] = new ColumnInfo(ColumnPairs[i].outputColumnName, ColumnPairs[i].inputColumnName, lang, langColName); + _columns[i] = new StopWordsRemovingEstimator.ColumnInfo(ColumnPairs[i].outputColumnName, ColumnPairs[i].inputColumnName, lang, langColName); } } @@ -283,22 +252,22 @@ private static StopWordsRemovingTransformer Create(IHostEnvironment env, ModelLo } // Factory method for SignatureDataTransform. 
- internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - env.CheckValue(args.Columns, nameof(args.Columns)); - var cols = new ColumnInfo[args.Columns.Length]; + env.CheckValue(options.Columns, nameof(options.Columns)); + var cols = new StopWordsRemovingEstimator.ColumnInfo[options.Columns.Length]; for (int i = 0; i < cols.Length; i++) { - var item = args.Columns[i]; - cols[i] = new ColumnInfo( + var item = options.Columns[i]; + cols[i] = new StopWordsRemovingEstimator.ColumnInfo( item.Name, item.Source ?? item.Name, - item.Language ?? args.Language, - item.LanguagesColumn ?? args.LanguagesColumn); + item.Language ?? options.Language, + item.LanguagesColumn ?? options.LanguagesColumn); } return new StopWordsRemovingTransformer(env, cols).MakeDataTransform(input); } @@ -519,6 +488,36 @@ private protected override Func GetDependenciesCore(Func a /// public sealed class StopWordsRemovingEstimator : TrivialEstimator { + /// + /// Describes how the transformer handles one column pair. + /// + public sealed class ColumnInfo + { + public readonly string Name; + public readonly string InputColumnName; + public readonly StopWordsRemovingEstimator.Language Language; + public readonly string LanguageColumn; + + /// + /// Describes how the transformer handles one column pair. + /// + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. + /// Language-specific stop words list. + /// Optional column to use for languages. This overrides language value. 
+ public ColumnInfo(string name, + string inputColumnName = null, + StopWordsRemovingEstimator.Language language = StopWordsRemovingEstimator.Defaults.DefaultLanguage, + string languageColumn = null) + { + Contracts.CheckNonWhiteSpace(name, nameof(name)); + + Name = name; + InputColumnName = inputColumnName ?? name; + Language = language; + LanguageColumn = languageColumn; + } + } /// /// Stopwords language. This enumeration is serialized. /// @@ -562,7 +561,7 @@ public static bool IsColumnTypeValid(ColumnType type) => /// Name of the column resulting from the transformation of . /// Name of the column to transform. If set to , the value of the will be used as source. /// Langauge of the input text column . - public StopWordsRemovingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, Language language = Language.English) + internal StopWordsRemovingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, Language language = Language.English) : this(env, new[] { (outputColumnName, inputColumnName ?? outputColumnName) }, language) { } @@ -574,12 +573,12 @@ public StopWordsRemovingEstimator(IHostEnvironment env, string outputColumnName, /// The environment. /// Pairs of columns to remove stop words on. /// Langauge of the input text columns . 
- public StopWordsRemovingEstimator(IHostEnvironment env, (string outputColumnName, string inputColumnName)[] columns, Language language = Language.English) - : this(env, columns.Select(x => new StopWordsRemovingTransformer.ColumnInfo(x.outputColumnName, x.inputColumnName, language)).ToArray()) + internal StopWordsRemovingEstimator(IHostEnvironment env, (string outputColumnName, string inputColumnName)[] columns, Language language = Language.English) + : this(env, columns.Select(x => new ColumnInfo(x.outputColumnName, x.inputColumnName, language)).ToArray()) { } - public StopWordsRemovingEstimator(IHostEnvironment env, params StopWordsRemovingTransformer.ColumnInfo[] columns) + internal StopWordsRemovingEstimator(IHostEnvironment env, params ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(StopWordsRemovingEstimator)), new StopWordsRemovingTransformer(env, columns)) { } @@ -642,7 +641,7 @@ public abstract class ArgumentsBase public string StopwordsColumn; } - public sealed class Arguments : ArgumentsBase + internal sealed class Options : ArgumentsBase { [Argument(ArgumentType.Multiple, HelpText = "New column definition(s)", Name = "Column", ShortName = "col", SortOrder = 1)] public Column[] Columns; @@ -713,7 +712,7 @@ private IDataLoader GetLoaderForStopwords(IChannel ch, string dataFile, if (isBinary || isTranspose) { ch.Assert(isBinary != isTranspose); - ch.CheckUserArg(!string.IsNullOrWhiteSpace(stopwordsCol), nameof(Arguments.StopwordsColumn), + ch.CheckUserArg(!string.IsNullOrWhiteSpace(stopwordsCol), nameof(Options.StopwordsColumn), "stopwordsColumn should be specified"); if (isBinary) dataLoader = new BinaryLoader(Host, new BinaryLoader.Arguments(), fileSource); @@ -772,7 +771,7 @@ private void LoadStopWords(IChannel ch, ReadOnlyMemory stopwords, string d warnEmpty = false; } } - ch.CheckUserArg(stopWordsMap.Count > 0, nameof(Arguments.Stopword), "stopwords is empty"); + ch.CheckUserArg(stopWordsMap.Count > 0, 
nameof(Options.Stopword), "stopwords is empty"); } else { @@ -780,9 +779,9 @@ private void LoadStopWords(IChannel ch, ReadOnlyMemory stopwords, string d var loader = GetLoaderForStopwords(ch, dataFile, loaderFactory, ref srcCol); if (!loader.Schema.TryGetColumnIndex(srcCol, out int colSrcIndex)) - throw ch.ExceptUserArg(nameof(Arguments.StopwordsColumn), "Unknown column '{0}'", srcCol); + throw ch.ExceptUserArg(nameof(Options.StopwordsColumn), "Unknown column '{0}'", srcCol); var typeSrc = loader.Schema[colSrcIndex].Type; - ch.CheckUserArg(typeSrc is TextType, nameof(Arguments.StopwordsColumn), "Must be a scalar text column"); + ch.CheckUserArg(typeSrc is TextType, nameof(Options.StopwordsColumn), "Must be a scalar text column"); // Accumulate the stopwords. using (var cursor = loader.GetRowCursor(loader.Schema[srcCol])) @@ -805,7 +804,7 @@ private void LoadStopWords(IChannel ch, ReadOnlyMemory stopwords, string d } } } - ch.CheckUserArg(stopWordsMap.Count > 0, nameof(Arguments.DataFile), "dataFile is empty"); + ch.CheckUserArg(stopWordsMap.Count > 0, nameof(Options.DataFile), "dataFile is empty"); } } @@ -817,7 +816,7 @@ private void LoadStopWords(IChannel ch, ReadOnlyMemory stopwords, string d /// The environment. /// Array of words to remove. /// Pairs of columns to remove stop words from. - public CustomStopWordsRemovingTransformer(IHostEnvironment env, string[] stopwords, params (string outputColumnName, string inputColumnName)[] columns) : + internal CustomStopWordsRemovingTransformer(IHostEnvironment env, string[] stopwords, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), columns) { _stopWordsMap = new NormStr.Pool(); @@ -938,24 +937,24 @@ private static CustomStopWordsRemovingTransformer Create(IHostEnvironment env, M } // Factory method for SignatureDataTransform. 
- internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - env.CheckValue(args.Columns, nameof(args.Columns)); - var cols = new (string outputColumnName, string inputColumnName)[args.Columns.Length]; + env.CheckValue(options.Columns, nameof(options.Columns)); + var cols = new (string outputColumnName, string inputColumnName)[options.Columns.Length]; for (int i = 0; i < cols.Length; i++) { - var item = args.Columns[i]; + var item = options.Columns[i]; cols[i] = (item.Name, item.Source ?? item.Name); } CustomStopWordsRemovingTransformer transfrom = null; - if (Utils.Size(args.Stopwords) > 0) - transfrom = new CustomStopWordsRemovingTransformer(env, args.Stopwords, cols); + if (Utils.Size(options.Stopwords) > 0) + transfrom = new CustomStopWordsRemovingTransformer(env, options.Stopwords, cols); else - transfrom = new CustomStopWordsRemovingTransformer(env, args.Stopword, args.DataFile, args.StopwordsColumn, args.Loader, cols); + transfrom = new CustomStopWordsRemovingTransformer(env, options.Stopword, options.DataFile, options.StopwordsColumn, options.Loader, cols); return transfrom.MakeDataTransform(input); } @@ -1057,7 +1056,7 @@ public sealed class CustomStopWordsRemovingEstimator : TrivialEstimatorName of the column resulting from the transformation of . /// Name of the column to transform. If set to , the value of the will be used as source. /// Array of words to remove. 
- public CustomStopWordsRemovingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, params string[] stopwords) + internal CustomStopWordsRemovingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, params string[] stopwords) : this(env, new[] { (outputColumnName, inputColumnName ?? outputColumnName) }, stopwords) { } @@ -1069,7 +1068,7 @@ public CustomStopWordsRemovingEstimator(IHostEnvironment env, string outputColum /// The environment. /// Pairs of columns to remove stop words on. /// Array of words to remove. - public CustomStopWordsRemovingEstimator(IHostEnvironment env, (string outputColumnName, string inputColumnName)[] columns, string[] stopwords) : + internal CustomStopWordsRemovingEstimator(IHostEnvironment env, (string outputColumnName, string inputColumnName)[] columns, string[] stopwords) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CustomStopWordsRemovingEstimator)), new CustomStopWordsRemovingTransformer(env, stopwords, columns)) { } diff --git a/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs b/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs index 3b7c09a56f..ec8610f988 100644 --- a/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs +++ b/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs @@ -344,12 +344,12 @@ public ITransformer Fit(IDataView input) if (tparams.UsePredefinedStopWordRemover) { Contracts.Assert(wordTokCols != null, "StopWords transform requires that word tokenization has been applied to the input text."); - var xfCols = new StopWordsRemovingTransformer.ColumnInfo[wordTokCols.Length]; + var xfCols = new StopWordsRemovingEstimator.ColumnInfo[wordTokCols.Length]; var dstCols = new string[wordTokCols.Length]; for (int i = 0; i < wordTokCols.Length; i++) { var tempName = GenerateColumnName(view.Schema, wordTokCols[i], "StopWordsRemoverTransform"); - var col = new 
StopWordsRemovingTransformer.ColumnInfo(tempName, wordTokCols[i], tparams.StopwordsLanguage); + var col = new StopWordsRemovingEstimator.ColumnInfo(tempName, wordTokCols[i], tparams.StopwordsLanguage); dstCols[i] = tempName; tempCols.Add(tempName); diff --git a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs index 4e1df85671..b3d20583a2 100644 --- a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs +++ b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs @@ -16,7 +16,7 @@ using Microsoft.ML.Model; using Microsoft.ML.Transforms.Text; -[assembly: LoadableClass(TextNormalizingTransformer.Summary, typeof(IDataTransform), typeof(TextNormalizingTransformer), typeof(TextNormalizingTransformer.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(TextNormalizingTransformer.Summary, typeof(IDataTransform), typeof(TextNormalizingTransformer), typeof(TextNormalizingTransformer.Options), typeof(SignatureDataTransform), "Text Normalizer Transform", "TextNormalizerTransform", "TextNormalizer", "TextNorm")] [assembly: LoadableClass(TextNormalizingTransformer.Summary, typeof(IDataTransform), typeof(TextNormalizingTransformer), null, typeof(SignatureLoadDataTransform), @@ -53,7 +53,7 @@ internal bool TryUnparse(StringBuilder sb) } } - public sealed class Arguments + internal sealed class Options { [Argument(ArgumentType.Multiple, HelpText = "New column definition(s)", Name = "Column", ShortName = "col", SortOrder = 1)] public Column[] Columns; @@ -96,7 +96,7 @@ private static VersionInfo GetVersionInfo() private readonly bool _keepPunctuations; private readonly bool _keepNumbers; - public TextNormalizingTransformer(IHostEnvironment env, + internal TextNormalizingTransformer(IHostEnvironment env, TextNormalizingEstimator.CaseNormalizationMode textCase = TextNormalizingEstimator.Defaults.TextCase, bool keepDiacritics = TextNormalizingEstimator.Defaults.KeepDiacritics, bool keepPunctuations = 
TextNormalizingEstimator.Defaults.KeepPunctuations, @@ -167,20 +167,20 @@ private TextNormalizingTransformer(IHost host, ModelLoadContext ctx) } // Factory method for SignatureDataTransform. - private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - env.CheckValue(args.Columns, nameof(args.Columns)); - var cols = new (string outputColumnName, string inputColumnName)[args.Columns.Length]; + env.CheckValue(options.Columns, nameof(options.Columns)); + var cols = new (string outputColumnName, string inputColumnName)[options.Columns.Length]; for (int i = 0; i < cols.Length; i++) { - var item = args.Columns[i]; + var item = options.Columns[i]; cols[i] = (item.Name, item.Source ?? item.Name); } - return new TextNormalizingTransformer(env, args.TextCase, args.KeepDiacritics, args.KeepPunctuations, args.KeepNumbers, cols).MakeDataTransform(input); + return new TextNormalizingTransformer(env, options.TextCase, options.KeepDiacritics, options.KeepPunctuations, options.KeepNumbers, cols).MakeDataTransform(input); } // Factory method for SignatureLoadDataTransform. @@ -465,7 +465,7 @@ internal static class Defaults /// Whether to keep diacritical marks or remove them. /// Whether to keep punctuation marks or remove them. /// Whether to keep numbers or remove them. - public TextNormalizingEstimator(IHostEnvironment env, + internal TextNormalizingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, CaseNormalizationMode textCase = Defaults.TextCase, @@ -486,7 +486,7 @@ public TextNormalizingEstimator(IHostEnvironment env, /// Whether to keep punctuation marks or remove them. /// Whether to keep numbers or remove them. 
/// Pairs of columns to run the text normalization on. - public TextNormalizingEstimator(IHostEnvironment env, + internal TextNormalizingEstimator(IHostEnvironment env, CaseNormalizationMode textCase = Defaults.TextCase, bool keepDiacritics = Defaults.KeepDiacritics, bool keepPunctuations = Defaults.KeepPunctuations, From c23c46494ecbb7f711a902ddb09432b115398d97 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Tue, 5 Feb 2019 18:48:16 +0000 Subject: [PATCH 08/10] making Column internal --- .../Text/LdaTransform.cs | 2 +- .../Text/NgramHashingTransformer.cs | 2 +- .../Text/NgramTransform.cs | 2 +- .../Text/StopWordsRemovingTransformer.cs | 4 +- .../Text/TextNormalizing.cs | 2 +- .../Text/TokenizingByCharacters.cs | 2 +- .../Text/WordBagTransform.cs | 42 +++++++++---------- .../Text/WordEmbeddingsExtractor.cs | 2 +- .../Text/WordHashBagProducingTransform.cs | 4 +- .../DataPipe/TestDataPipe.cs | 10 ++--- 10 files changed, 34 insertions(+), 38 deletions(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index c6520c6d99..2ff4437409 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -106,7 +106,7 @@ internal sealed class Options : TransformInputBase public bool OutputTopicWordSummary; } - public sealed class Column : OneToOneColumn + internal sealed class Column : OneToOneColumn { [Argument(ArgumentType.AtMostOnce, HelpText = "The number of topics")] public int? 
NumTopic; diff --git a/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs b/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs index bb6490d283..0c857cd2c0 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs @@ -37,7 +37,7 @@ namespace Microsoft.ML.Transforms.Text /// public sealed class NgramHashingTransformer : RowToRowTransformerBase { - public sealed class Column : ManyToOneColumn + internal sealed class Column : ManyToOneColumn { [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum ngram length", ShortName = "ngram")] public int? NgramLength; diff --git a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs index 593941f3df..b4c7939efb 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs @@ -38,7 +38,7 @@ namespace Microsoft.ML.Transforms.Text /// public sealed class NgramExtractingTransformer : OneToOneTransformerBase { - public sealed class Column : OneToOneColumn + internal sealed class Column : OneToOneColumn { [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum ngram length", ShortName = "ngram")] public int? NgramLength; diff --git a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs index 2b2349e529..4510d6be05 100644 --- a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs @@ -70,7 +70,7 @@ public IDataTransform CreateComponent(IHostEnvironment env, IDataView input, One /// public sealed class StopWordsRemovingTransformer : OneToOneTransformerBase { - public sealed class Column : OneToOneColumn + internal sealed class Column : OneToOneColumn { [Argument(ArgumentType.AtMostOnce, HelpText = "Optional column to use for languages. 
This overrides sentence separator language value.", @@ -606,7 +606,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) /// public sealed class CustomStopWordsRemovingTransformer : OneToOneTransformerBase { - public sealed class Column : OneToOneColumn + internal sealed class Column : OneToOneColumn { internal static Column Parse(string str) { diff --git a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs index b3d20583a2..6bafe87230 100644 --- a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs +++ b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs @@ -36,7 +36,7 @@ namespace Microsoft.ML.Transforms.Text /// public sealed class TextNormalizingTransformer : OneToOneTransformerBase { - public sealed class Column : OneToOneColumn + internal sealed class Column : OneToOneColumn { internal static Column Parse(string str) { diff --git a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs index 64898a92f5..17c116ad92 100644 --- a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs +++ b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs @@ -36,7 +36,7 @@ namespace Microsoft.ML.Transforms.Text /// public sealed class TokenizingByCharactersTransformer : OneToOneTransformerBase { - public sealed class Column : OneToOneColumn + internal sealed class Column : OneToOneColumn { internal static Column Parse(string str) { diff --git a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs index c489e6ae32..7b5db1f0b2 100644 --- a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs @@ -42,7 +42,7 @@ public sealed class ExtractorColumn : ManyToOneColumn public static class WordBagBuildingTransformer { - public sealed class Column : ManyToOneColumn + internal sealed class Column : ManyToOneColumn { 
[Argument(ArgumentType.AtMostOnce, HelpText = "Ngram length", ShortName = "ngram")] public int? NgramLength; @@ -127,7 +127,7 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa var tokenizeColumns = new WordTokenizingEstimator.ColumnInfo[options.Columns.Length]; var extractorArgs = - new NgramExtractorTransform.Arguments() + new NgramExtractorTransform.Options() { MaxNumTerms = options.MaxNumTerms, NgramLength = options.NgramLength, @@ -173,7 +173,7 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa /// public static class NgramExtractorTransform { - public sealed class Column : OneToOneColumn + internal sealed class Column : OneToOneColumn { [Argument(ArgumentType.AtMostOnce, HelpText = "Ngram length (stores all lengths up to the specified Ngram length)", ShortName = "ngram")] public int? NgramLength; @@ -254,7 +254,7 @@ public INgramExtractorFactory CreateComponent(IHostEnvironment env, TermLoaderAr } } - public sealed class Arguments : ArgumentsBase + internal sealed class Options : ArgumentsBase { [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 1)] public Column[] Columns; @@ -265,22 +265,22 @@ public sealed class Arguments : ArgumentsBase internal const string LoaderSignature = "NgramExtractor"; - public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input, TermLoaderArguments termLoaderArgs = null) { Contracts.CheckValue(env, nameof(env)); var h = env.Register(LoaderSignature); - h.CheckValue(args, nameof(args)); + h.CheckValue(options, nameof(options)); h.CheckValue(input, nameof(input)); - h.CheckUserArg(Utils.Size(args.Columns) > 0, nameof(args.Columns), "Columns must be specified"); + h.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must 
be specified"); IDataView view = input; var termCols = new List(); - var isTermCol = new bool[args.Columns.Length]; + var isTermCol = new bool[options.Columns.Length]; - for (int i = 0; i < args.Columns.Length; i++) + for (int i = 0; i < options.Columns.Length; i++) { - var col = args.Columns[i]; + var col = options.Columns[i]; h.CheckNonWhiteSpace(col.Name, nameof(col.Name)); h.CheckNonWhiteSpace(col.Source, nameof(col.Source)); @@ -324,7 +324,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV termArgs = new ValueToKeyMappingTransformer.Options() { - MaxNumTerms = Utils.Size(args.MaxNumTerms) > 0 ? args.MaxNumTerms[0] : NgramExtractingEstimator.Defaults.MaxNumTerms, + MaxNumTerms = Utils.Size(options.MaxNumTerms) > 0 ? options.MaxNumTerms[0] : NgramExtractingEstimator.Defaults.MaxNumTerms, Columns = new ValueToKeyMappingTransformer.Column[termCols.Count] }; } @@ -349,16 +349,16 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV view = new MissingValueDroppingTransformer(h, missingDropColumns.Select(x => (x, x)).ToArray()).Transform(view); } - var ngramColumns = new NgramExtractingEstimator.ColumnInfo[args.Columns.Length]; - for (int iinfo = 0; iinfo < args.Columns.Length; iinfo++) + var ngramColumns = new NgramExtractingEstimator.ColumnInfo[options.Columns.Length]; + for (int iinfo = 0; iinfo < options.Columns.Length; iinfo++) { - var column = args.Columns[iinfo]; + var column = options.Columns[iinfo]; ngramColumns[iinfo] = new NgramExtractingEstimator.ColumnInfo(column.Name, - column.NgramLength ?? args.NgramLength, - column.SkipLength ?? args.SkipLength, - column.AllLengths ?? args.AllLengths, - column.Weighting ?? args.Weighting, - column.MaxNumTerms ?? args.MaxNumTerms, + column.NgramLength ?? options.NgramLength, + column.SkipLength ?? options.SkipLength, + column.AllLengths ?? options.AllLengths, + column.Weighting ?? options.Weighting, + column.MaxNumTerms ?? 
options.MaxNumTerms, isTermCol[iinfo] ? column.Name : column.Source ); } @@ -374,7 +374,7 @@ public static IDataTransform Create(IHostEnvironment env, NgramExtractorArgument h.CheckValue(extractorArgs, nameof(extractorArgs)); h.CheckValue(input, nameof(input)); h.CheckUserArg(extractorArgs.SkipLength < extractorArgs.NgramLength, nameof(extractorArgs.SkipLength), "Should be less than " + nameof(extractorArgs.NgramLength)); - h.CheckUserArg(Utils.Size(cols) > 0, nameof(Arguments.Columns), "Must be specified"); + h.CheckUserArg(Utils.Size(cols) > 0, nameof(Options.Columns), "Must be specified"); h.CheckValueOrNull(termLoaderArgs); var extractorCols = new Column[cols.Length]; @@ -384,7 +384,7 @@ public static IDataTransform Create(IHostEnvironment env, NgramExtractorArgument extractorCols[i] = new Column { Name = cols[i].Name, Source = cols[i].Source[0] }; } - var args = new Arguments + var args = new Options { Columns = extractorCols, NgramLength = extractorArgs.NgramLength, diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs index c6874c36e2..043f315fe2 100644 --- a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs +++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs @@ -38,7 +38,7 @@ namespace Microsoft.ML.Transforms.Text /// public sealed class WordEmbeddingsExtractingTransformer : OneToOneTransformerBase { - public sealed class Column : OneToOneColumn + internal sealed class Column : OneToOneColumn { internal static Column Parse(string str) { diff --git a/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs b/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs index 8825724b18..6c9f42d4b1 100644 --- a/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs @@ -26,7 +26,7 @@ namespace Microsoft.ML.Transforms.Text { public static class 
WordHashBagProducingTransformer { - public sealed class Column : NgramHashExtractingTransformer.ColumnBase + internal sealed class Column : NgramHashExtractingTransformer.ColumnBase { internal static Column Parse(string str) { @@ -193,7 +193,7 @@ public abstract class ColumnBase : ManyToOneColumn public bool? AllLengths; } - public sealed class Column : ColumnBase + internal sealed class Column : ColumnBase { // For all source columns, use these friendly names for the source // column names instead of the real column names. diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index 2489efaaaf..c893c3042e 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -1362,6 +1362,7 @@ public void TestLDATransform() public void TestLdaTransformerEmptyDocumentException() { var builder = new ArrayDataViewBuilder(Env); + string colName = "Zeros"; var data = new[] { new[] { (Float)0.0, (Float)0.0, (Float)0.0 }, @@ -1369,21 +1370,16 @@ public void TestLdaTransformerEmptyDocumentException() new[] { (Float)0.0, (Float)0.0, (Float)0.0 }, }; - builder.AddColumn("Zeros", NumberType.Float, data); + builder.AddColumn(colName, NumberType.Float, data); var srcView = builder.GetDataView(); - var col = new LatentDirichletAllocationTransformer.Column() - { - Source = "Zeros", - }; - try { var lda = ML.Transforms.Text.LatentDirichletAllocation("Zeros").Fit(srcView).Transform(srcView); } catch (InvalidOperationException ex) { - Assert.Equal(ex.Message, string.Format("The specified documents are all empty in column '{0}'.", col.Source)); + Assert.Equal(ex.Message, string.Format("The specified documents are all empty in column '{0}'.", colName)); return; } From a81c3d23707b61fd97d685e36d6f929deceffb13 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Tue, 5 Feb 2019 19:54:42 +0000 Subject: [PATCH 09/10] review comments regarding adding 
summary comments --- .../Text/LdaTransform.cs | 48 ++++++++++++++++++- .../Text/NgramHashingTransformer.cs | 7 +++ .../Text/NgramTransform.cs | 7 +++ .../Text/WordEmbeddingsExtractor.cs | 7 +++ 4 files changed, 68 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 2ff4437409..cd7e0bd675 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -980,20 +980,62 @@ internal LatentDirichletAllocationEstimator(IHostEnvironment env, params ColumnI _columns = columns.ToImmutableArray(); } + /// + /// Describes how the transformer handles one column pair. + /// public sealed class ColumnInfo { + /// + /// Name of the column resulting from the transformation of . + /// public readonly string Name; + /// + /// Name of column to transform. If set to , the value of the will be used as source. + /// public readonly string InputColumnName; + /// + /// The number of topics. + /// public readonly int NumTopic; + /// + /// Dirichlet prior on document-topic vectors. + /// public readonly float AlphaSum; + /// + /// Dirichlet prior on vocab-topic vectors. + /// public readonly float Beta; + /// + /// Number of Metropolis Hasting step. + /// public readonly int MHStep; + /// + /// Number of iterations. + /// public readonly int NumIter; + /// + /// Compute log likelihood over local dataset on this iteration interval. + /// public readonly int LikelihoodInterval; + /// + /// The number of training threads. + /// public readonly int NumThread; + /// + /// The threshold of maximum count of tokens per doc. + /// public readonly int NumMaxDocToken; + /// + /// The number of words to summarize the topic. + /// public readonly int NumSummaryTermPerTopic; + /// + /// The number of burn-in iterations. + /// public readonly int NumBurninIter; + /// + /// Reset the random number generator for each document. 
+ /// public readonly bool ResetRandomGenerator; /// @@ -1150,7 +1192,8 @@ internal void Save(ModelSaveContext ctx) } /// - /// Returns the schema that would be produced by the transformation. + /// Returns the of the schema which will be produced by the transformer. + /// Used for schema propagation and verification in a pipeline. /// public SchemaShape GetOutputSchema(SchemaShape inputSchema) { @@ -1169,6 +1212,9 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) return new SchemaShape(result.Values); } + /// + /// Trains and returns a . + /// public LatentDirichletAllocationTransformer Fit(IDataView input) { return LatentDirichletAllocationTransformer.TrainLdaTransformer(_host, input, _columns.ToArray()); diff --git a/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs b/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs index 0c857cd2c0..7d01048a44 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs @@ -1202,6 +1202,10 @@ internal static bool IsSchemaColumnValid(SchemaShape.Column col) internal const string ExpectedColumnType = "Expected vector of Key type, and Key is convertible to U4"; + /// + /// Returns the of the schema which will be produced by the transformer. + /// Used for schema propagation and verification in a pipeline. + /// public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); @@ -1222,6 +1226,9 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) return new SchemaShape(result.Values); } + /// + /// Trains and returns a . 
+ /// public NgramHashingTransformer Fit(IDataView input) => new NgramHashingTransformer(_host, input, _columns); } } diff --git a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs index b4c7939efb..b32a96aea9 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs @@ -762,6 +762,9 @@ internal NgramExtractingEstimator(IHostEnvironment env, params ColumnInfo[] colu _columns = columns; } + /// + /// Trains and returns a . + /// public NgramExtractingTransformer Fit(IDataView input) => new NgramExtractingTransformer(_host, input, _columns); internal static bool IsColumnTypeValid(ColumnType type) @@ -865,6 +868,10 @@ internal ColumnInfo(string name, } } + /// + /// Returns the of the schema which will be produced by the transformer. + /// Used for schema propagation and verification in a pipeline. + /// public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs index 043f315fe2..0174fceb8c 100644 --- a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs +++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs @@ -850,6 +850,10 @@ public ColumnInfo(string name, string inputColumnName = null) } } + /// + /// Returns the of the schema which will be produced by the transformer. + /// Used for schema propagation and verification in a pipeline. + /// public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); @@ -867,6 +871,9 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) return new SchemaShape(result.Values); } + /// + /// Trains and returns a . 
+ /// public WordEmbeddingsExtractingTransformer Fit(IDataView input) { bool customLookup = !string.IsNullOrWhiteSpace(_customLookupTable); From 7df1ed05a57f1654b668dd7227727da9c6628339 Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Wed, 6 Feb 2019 00:24:15 +0000 Subject: [PATCH 10/10] fixed CookBook + summary comment on InputColumnName --- docs/code/MlNetCookBook.md | 4 ++-- src/Microsoft.ML.Transforms/Text/LdaTransform.cs | 2 +- src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/code/MlNetCookBook.md b/docs/code/MlNetCookBook.md index 61d6858cc1..1a4b423e82 100644 --- a/docs/code/MlNetCookBook.md +++ b/docs/code/MlNetCookBook.md @@ -782,7 +782,7 @@ var pipeline = // NLP pipeline 4: word embeddings. .Append(mlContext.Transforms.Text.TokenizeWords("TokenizedMessage", "NormalizedMessage")) .Append(mlContext.Transforms.Text.ExtractWordEmbeddings("Embeddings", "TokenizedMessage", - WordEmbeddingsExtractingTransformer.PretrainedModelKind.GloVeTwitter25D)); + WordEmbeddingsExtractingEstimator.PretrainedModelKind.GloVeTwitter25D)); // Let's train our pipeline, and then apply it to the same data. // Note that even on a small dataset of 70KB the pipeline above can take up to a minute to completely train. @@ -1020,4 +1020,4 @@ newContext.CompositionContainer = new CompositionContainer(new TypeCatalog(typeo ITransformer loadedModel; using (var fs = File.OpenRead(modelPath)) loadedModel = newContext.Model.Load(fs); -``` \ No newline at end of file +``` diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index cd7e0bd675..7141d697ec 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -990,7 +990,7 @@ public sealed class ColumnInfo /// public readonly string Name; /// - /// Name of column to transform. If set to , the value of the will be used as source. 
+ /// Name of column to transform. /// public readonly string InputColumnName; /// diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs index 8946d34f6d..8a6c8108b5 100644 --- a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs +++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs @@ -842,7 +842,7 @@ public sealed class ColumnInfo /// public readonly string Name; /// - /// Name of column to transform. If set to , the value of the will be used as source. + /// Name of column to transform. /// public readonly string InputColumnName;