diff --git a/docs/code/MlNetCookBook.md b/docs/code/MlNetCookBook.md
index 61d6858cc1..1a4b423e82 100644
--- a/docs/code/MlNetCookBook.md
+++ b/docs/code/MlNetCookBook.md
@@ -782,7 +782,7 @@ var pipeline =
// NLP pipeline 4: word embeddings.
.Append(mlContext.Transforms.Text.TokenizeWords("TokenizedMessage", "NormalizedMessage"))
.Append(mlContext.Transforms.Text.ExtractWordEmbeddings("Embeddings", "TokenizedMessage",
- WordEmbeddingsExtractingTransformer.PretrainedModelKind.GloVeTwitter25D));
+ WordEmbeddingsExtractingEstimator.PretrainedModelKind.GloVeTwitter25D));
// Let's train our pipeline, and then apply it to the same data.
// Note that even on a small dataset of 70KB, the pipeline above can take up to a minute to train completely.
@@ -1020,4 +1020,4 @@ newContext.CompositionContainer = new CompositionContainer(new TypeCatalog(typeo
ITransformer loadedModel;
using (var fs = File.OpenRead(modelPath))
loadedModel = newContext.Model.Load(fs);
-```
\ No newline at end of file
+```
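For reference, the word-embedding hunk above changes only the type that owns `PretrainedModelKind`; the call shape is unchanged. A minimal sketch of the resulting cookbook pipeline, assuming the cookbook's `mlContext` and the `NormalizedMessage` column produced earlier in that example:

```csharp
// NLP pipeline 4: word embeddings, with the enum now hanging off the estimator type.
var pipeline =
    mlContext.Transforms.Text.TokenizeWords("TokenizedMessage", "NormalizedMessage")
    .Append(mlContext.Transforms.Text.ExtractWordEmbeddings("Embeddings", "TokenizedMessage",
        WordEmbeddingsExtractingEstimator.PretrainedModelKind.GloVeTwitter25D));
```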
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValueValueToKey.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValueValueToKey.cs
index ceca716c34..02ff540c0d 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValueValueToKey.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValueValueToKey.cs
@@ -31,7 +31,7 @@ public static void KeyToValueValueToKey()
// making use of default settings.
string defaultColumnName = "DefaultKeys";
// REVIEW create through the catalog extension
- var default_pipeline = new WordTokenizingEstimator(ml, "Review")
+ var default_pipeline = ml.Transforms.Text.TokenizeWords("Review")
.Append(ml.Transforms.Conversion.MapValueToKey(defaultColumnName, "Review"));
// Another pipeline that customizes the advanced settings of the ValueToKeyMappingEstimator.
@@ -39,7 +39,7 @@ public static void KeyToValueValueToKey()
// and control the order in which they get evaluated by changing sort from the default Occurrence (order in which they get encountered)
// to value/alphabetically.
string customizedColumnName = "CustomizedKeys";
- var customized_pipeline = new WordTokenizingEstimator(ml, "Review")
+ var customized_pipeline = ml.Transforms.Text.TokenizeWords("Review")
.Append(ml.Transforms.Conversion.MapValueToKey(customizedColumnName, "Review", maxNumKeys: 10, sort: ValueToKeyMappingEstimator.SortOrder.Value));
// The transformed data.
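The edit in this sample replaces direct estimator construction with the transforms-catalog extension, which addresses the `REVIEW` note above. Side by side, as a sketch reusing the sample's `ml` context and its `Review` column:

```csharp
// Before: constructing the estimator type directly.
// var pipeline = new WordTokenizingEstimator(ml, "Review");

// After: the catalog extension yields the same estimator without naming its concrete type.
var pipeline = ml.Transforms.Text.TokenizeWords("Review")
    .Append(ml.Transforms.Conversion.MapValueToKey("CustomizedKeys", "Review",
        maxNumKeys: 10, sort: ValueToKeyMappingEstimator.SortOrder.Value));
```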
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/WordEmbeddingTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/WordEmbeddingTransform.cs
index 1a47c7f5b2..7b305f47e1 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/WordEmbeddingTransform.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/WordEmbeddingTransform.cs
@@ -62,7 +62,7 @@ public static void ExtractEmbeddings()
// Let's apply the pretrained word embedding model GloVeTwitter25D.
// 25D means each word is mapped into a 25-dimensional space; that is, each word is represented by 25 float values.
var gloveWordEmbedding = ml.Transforms.Text.ExtractWordEmbeddings("GloveEmbeddings", "CleanWords",
- WordEmbeddingsExtractingTransformer.PretrainedModelKind.GloVeTwitter25D);
+ WordEmbeddingsExtractingEstimator.PretrainedModelKind.GloVeTwitter25D);
// We also have the option to apply custom word embedding models.
// Let's first create one.
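As a reminder of what this sample computes: the transform pools the per-word vectors into one fixed-length vector per row (minimum, average, and maximum pooling), so a 25-dimensional model produces 3 × 25 = 75 floats. A hedged sketch, assuming a data view that already contains the `CleanWords` column:

```csharp
// Fit and apply; each output row carries 3 * 25 = 75 values under min/avg/max pooling.
var embeddedData = gloveWordEmbedding.Fit(data).Transform(data);
```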
diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
index 8833856c30..bf2af44b4c 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
@@ -180,6 +180,10 @@ protected TrivialWrapperEstimator(IHost host, TransformWrapper transformer)
{
}
+ /// <summary>
+ /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
+ /// Used for schema propagation and verification in a pipeline.
+ /// </summary>
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
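The member documented above is the estimator-side half of the schema-propagation contract: given the `SchemaShape` of the input, it computes the shape the fitted transformer will produce, so a whole pipeline can be validated before any training runs. A minimal sketch, assuming an `estimator` and a `dataView` in scope:

```csharp
// Validate column names and types up front, without fitting anything.
var inputShape = SchemaShape.Create(dataView.Schema);
var outputShape = estimator.GetOutputSchema(inputShape);
```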
diff --git a/src/Microsoft.ML.StaticPipe/LdaStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/LdaStaticExtensions.cs
index f87e5ef4a1..22d17db2ae 100644
--- a/src/Microsoft.ML.StaticPipe/LdaStaticExtensions.cs
+++ b/src/Microsoft.ML.StaticPipe/LdaStaticExtensions.cs
@@ -101,13 +101,13 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
IReadOnlyDictionary<PipelineColumn, string> outputNames,
IReadOnlyCollection<string> usedNames)
{
- var infos = new LatentDirichletAllocationTransformer.ColumnInfo[toOutput.Length];
+ var infos = new LatentDirichletAllocationEstimator.ColumnInfo[toOutput.Length];
Action<LatentDirichletAllocationTransformer> onFit = null;
for (int i = 0; i < toOutput.Length; ++i)
{
var tcol = (ILdaCol)toOutput[i];
- infos[i] = new LatentDirichletAllocationTransformer.ColumnInfo(outputNames[toOutput[i]],
+ infos[i] = new LatentDirichletAllocationEstimator.ColumnInfo(outputNames[toOutput[i]],
inputNames[tcol.Input],
tcol.Config.NumTopic,
tcol.Config.AlphaSum,
diff --git a/src/Microsoft.ML.StaticPipe/TextStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/TextStaticExtensions.cs
index a2155330c6..9002c7742a 100644
--- a/src/Microsoft.ML.StaticPipe/TextStaticExtensions.cs
+++ b/src/Microsoft.ML.StaticPipe/TextStaticExtensions.cs
@@ -151,9 +151,9 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
{
Contracts.Assert(toOutput.Length == 1);
- var columns = new List<StopWordsRemovingTransformer.ColumnInfo>();
+ var columns = new List<StopWordsRemovingEstimator.ColumnInfo>();
foreach (var outCol in toOutput)
- columns.Add(new StopWordsRemovingTransformer.ColumnInfo(outputNames[outCol], inputNames[((OutPipelineColumn)outCol).Input], _language));
+ columns.Add(new StopWordsRemovingEstimator.ColumnInfo(outputNames[outCol], inputNames[((OutPipelineColumn)outCol).Input], _language));
return new StopWordsRemovingEstimator(env, columns.ToArray());
}
@@ -559,9 +559,9 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
IReadOnlyCollection<string> usedNames)
{
Contracts.Assert(toOutput.Length == 1);
- var columns = new List<NgramHashingTransformer.ColumnInfo>();
+ var columns = new List<NgramHashingEstimator.ColumnInfo>();
foreach (var outCol in toOutput)
- columns.Add(new NgramHashingTransformer.ColumnInfo(outputNames[outCol], new[] { inputNames[((OutPipelineColumn)outCol).Input] },
+ columns.Add(new NgramHashingEstimator.ColumnInfo(outputNames[outCol], new[] { inputNames[((OutPipelineColumn)outCol).Input] },
_ngramLength, _skipLength, _allLengths, _hashBits, _seed, _ordered, _invertHash));
return new NgramHashingEstimator(env, columns.ToArray());
diff --git a/src/Microsoft.ML.StaticPipe/WordEmbeddingsStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/WordEmbeddingsStaticExtensions.cs
index 1c0908d9ae..f0a2de00db 100644
--- a/src/Microsoft.ML.StaticPipe/WordEmbeddingsStaticExtensions.cs
+++ b/src/Microsoft.ML.StaticPipe/WordEmbeddingsStaticExtensions.cs
@@ -15,7 +15,7 @@ public static class WordEmbeddingsStaticExtensions
/// <param name="input">Vector of tokenized text.</param>
/// <param name="modelKind">The pretrained word embedding model.</param>
///
- public static Vector<float> WordEmbeddings(this VarVector<string> input, WordEmbeddingsExtractingTransformer.PretrainedModelKind modelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe)
+ public static Vector<float> WordEmbeddings(this VarVector<string> input, WordEmbeddingsExtractingEstimator.PretrainedModelKind modelKind = WordEmbeddingsExtractingEstimator.PretrainedModelKind.Sswe)
{
Contracts.CheckValue(input, nameof(input));
return new OutColumn(input, modelKind);
@@ -34,7 +34,7 @@ private sealed class OutColumn : Vector<float>
{
public PipelineColumn Input { get; }
- public OutColumn(VarVector<string> input, WordEmbeddingsExtractingTransformer.PretrainedModelKind modelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe)
+ public OutColumn(VarVector<string> input, WordEmbeddingsExtractingEstimator.PretrainedModelKind modelKind = WordEmbeddingsExtractingEstimator.PretrainedModelKind.Sswe)
: base(new Reconciler(modelKind), input)
{
Input = input;
@@ -49,10 +49,10 @@ public OutColumn(VarVector<string> input, string customModelFile = null)
private sealed class Reconciler : EstimatorReconciler
{
- private readonly WordEmbeddingsExtractingTransformer.PretrainedModelKind? _modelKind;
+ private readonly WordEmbeddingsExtractingEstimator.PretrainedModelKind? _modelKind;
private readonly string _customLookupTable;
- public Reconciler(WordEmbeddingsExtractingTransformer.PretrainedModelKind modelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe)
+ public Reconciler(WordEmbeddingsExtractingEstimator.PretrainedModelKind modelKind = WordEmbeddingsExtractingEstimator.PretrainedModelKind.Sswe)
{
_modelKind = modelKind;
_customLookupTable = null;
@@ -72,11 +72,11 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
{
Contracts.Assert(toOutput.Length == 1);
- var cols = new WordEmbeddingsExtractingTransformer.ColumnInfo[toOutput.Length];
+ var cols = new WordEmbeddingsExtractingEstimator.ColumnInfo[toOutput.Length];
for (int i = 0; i < toOutput.Length; ++i)
{
var outCol = (OutColumn)toOutput[i];
- cols[i] = new WordEmbeddingsExtractingTransformer.ColumnInfo(outputNames[outCol], inputNames[outCol.Input]);
+ cols[i] = new WordEmbeddingsExtractingEstimator.ColumnInfo(outputNames[outCol], inputNames[outCol.Input]);
}
bool customLookup = !string.IsNullOrWhiteSpace(_customLookupTable);
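In the statically-typed API, call sites only pick up the enum's new owner. A sketch of typical usage, assuming a statically-typed data view whose row type exposes a `Text` scalar column (`TokenizeText` produces the `VarVector<string>` that `WordEmbeddings` extends):

```csharp
var estimator = data.MakeNewEstimator()
    .Append(row => (
        Embeddings: row.Text.TokenizeText()
            .WordEmbeddings(WordEmbeddingsExtractingEstimator.PretrainedModelKind.GloVeTwitter25D)));
```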
diff --git a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs
index 4f5caadc87..697b39601e 100644
--- a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs
+++ b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs
@@ -36,7 +36,7 @@ public static CommonOutputs.TransformOutput TextTransform(IHostEnvironment env,
Desc = ML.Transforms.Text.WordTokenizingTransformer.Summary,
UserName = ML.Transforms.Text.WordTokenizingTransformer.UserName,
ShortName = ML.Transforms.Text.WordTokenizingTransformer.LoaderSignature)]
- public static CommonOutputs.TransformOutput DelimitedTokenizeTransform(IHostEnvironment env, WordTokenizingTransformer.Arguments input)
+ public static CommonOutputs.TransformOutput DelimitedTokenizeTransform(IHostEnvironment env, WordTokenizingTransformer.Options input)
{
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "DelimitedTokenizeTransform", input);
var xf = ML.Transforms.Text.WordTokenizingTransformer.Create(h, input, input.Data);
@@ -51,7 +51,7 @@ public static CommonOutputs.TransformOutput DelimitedTokenizeTransform(IHostEnvi
Desc = NgramExtractingTransformer.Summary,
UserName = NgramExtractingTransformer.UserName,
ShortName = NgramExtractingTransformer.LoaderSignature)]
- public static CommonOutputs.TransformOutput NGramTransform(IHostEnvironment env, NgramExtractingTransformer.Arguments input)
+ public static CommonOutputs.TransformOutput NGramTransform(IHostEnvironment env, NgramExtractingTransformer.Options input)
{
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "NGramTransform", input);
var xf = NgramExtractingTransformer.Create(h, input, input.Data);
@@ -96,7 +96,7 @@ public static CommonOutputs.TransformOutput AnalyzeSentiment(IHostEnvironment en
Desc = TokenizingByCharactersTransformer.Summary,
UserName = TokenizingByCharactersTransformer.UserName,
ShortName = TokenizingByCharactersTransformer.LoaderSignature)]
- public static CommonOutputs.TransformOutput CharTokenize(IHostEnvironment env, TokenizingByCharactersTransformer.Arguments input)
+ public static CommonOutputs.TransformOutput CharTokenize(IHostEnvironment env, TokenizingByCharactersTransformer.Options input)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(input, nameof(input));
@@ -114,13 +114,13 @@ public static CommonOutputs.TransformOutput CharTokenize(IHostEnvironment env, T
Desc = LatentDirichletAllocationTransformer.Summary,
UserName = LatentDirichletAllocationTransformer.UserName,
ShortName = LatentDirichletAllocationTransformer.ShortName)]
- public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LatentDirichletAllocationTransformer.Arguments input)
+ public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LatentDirichletAllocationTransformer.Options input)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(input, nameof(input));
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "LightLda", input);
- var cols = input.Columns.Select(colPair => new LatentDirichletAllocationTransformer.ColumnInfo(colPair, input)).ToArray();
+ var cols = input.Columns.Select(colPair => new LatentDirichletAllocationEstimator.ColumnInfo(colPair, input)).ToArray();
var est = new LatentDirichletAllocationEstimator(h, cols);
var view = est.Fit(input.Data).Transform(input.Data);
@@ -135,7 +135,7 @@ public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, Laten
Desc = WordEmbeddingsExtractingTransformer.Summary,
UserName = WordEmbeddingsExtractingTransformer.UserName,
ShortName = WordEmbeddingsExtractingTransformer.ShortName)]
- public static CommonOutputs.TransformOutput WordEmbeddings(IHostEnvironment env, WordEmbeddingsExtractingTransformer.Arguments input)
+ public static CommonOutputs.TransformOutput WordEmbeddings(IHostEnvironment env, WordEmbeddingsExtractingTransformer.Options input)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(input, nameof(input));
diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs
index e9f8e5a7a9..7141d697ec 100644
--- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs
@@ -19,7 +19,7 @@
using Microsoft.ML.TextAnalytics;
using Microsoft.ML.Transforms.Text;
-[assembly: LoadableClass(LatentDirichletAllocationTransformer.Summary, typeof(IDataTransform), typeof(LatentDirichletAllocationTransformer), typeof(LatentDirichletAllocationTransformer.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(LatentDirichletAllocationTransformer.Summary, typeof(IDataTransform), typeof(LatentDirichletAllocationTransformer), typeof(LatentDirichletAllocationTransformer.Options), typeof(SignatureDataTransform),
"Latent Dirichlet Allocation Transform", LatentDirichletAllocationTransformer.LoaderSignature, "Lda")]
[assembly: LoadableClass(LatentDirichletAllocationTransformer.Summary, typeof(IDataTransform), typeof(LatentDirichletAllocationTransformer), null, typeof(SignatureLoadDataTransform),
@@ -51,7 +51,7 @@ namespace Microsoft.ML.Transforms.Text
///
public sealed class LatentDirichletAllocationTransformer : OneToOneTransformerBase
{
- public sealed class Arguments : TransformInputBase
+ internal sealed class Options : TransformInputBase
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:srcs)", Name = "Column", ShortName = "col", SortOrder = 49)]
public Column[] Columns;
@@ -106,7 +106,7 @@ public sealed class Arguments : TransformInputBase
public bool OutputTopicWordSummary;
}
- public sealed class Column : OneToOneColumn
+ internal sealed class Column : OneToOneColumn
{
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of topics")]
public int? NumTopic;
@@ -161,175 +161,6 @@ internal bool TryUnparse(StringBuilder sb)
}
}
- public sealed class ColumnInfo
- {
- public readonly string Name;
- public readonly string InputColumnName;
- public readonly int NumTopic;
- public readonly float AlphaSum;
- public readonly float Beta;
- public readonly int MHStep;
- public readonly int NumIter;
- public readonly int LikelihoodInterval;
- public readonly int NumThread;
- public readonly int NumMaxDocToken;
- public readonly int NumSummaryTermPerTopic;
- public readonly int NumBurninIter;
- public readonly bool ResetRandomGenerator;
-
- /// <summary>
- /// Describes how the transformer handles one column pair.
- /// </summary>
- /// <param name="name">The column containing the output scores over a set of topics, represented as a vector of floats.</param>
- /// <param name="inputColumnName">The column representing the document as a vector of floats. A null value for the column means <paramref name="name"/> is replaced.</param>
- /// <param name="numTopic">The number of topics.</param>
- /// <param name="alphaSum">Dirichlet prior on document-topic vectors.</param>
- /// <param name="beta">Dirichlet prior on vocab-topic vectors.</param>
- /// <param name="mhStep">Number of Metropolis-Hastings steps.</param>
- /// <param name="numIter">Number of iterations.</param>
- /// <param name="likelihoodInterval">Compute log likelihood over local dataset on this iteration interval.</param>
- /// <param name="numThread">The number of training threads. Default value depends on number of logical processors.</param>
- /// <param name="numMaxDocToken">The threshold of maximum count of tokens per doc.</param>
- /// <param name="numSummaryTermPerTopic">The number of words to summarize the topic.</param>
- /// <param name="numBurninIter">The number of burn-in iterations.</param>
- /// <param name="resetRandomGenerator">Reset the random number generator for each document.</param>
- public ColumnInfo(string name,
- string inputColumnName = null,
- int numTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic,
- float alphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum,
- float beta = LatentDirichletAllocationEstimator.Defaults.Beta,
- int mhStep = LatentDirichletAllocationEstimator.Defaults.Mhstep,
- int numIter = LatentDirichletAllocationEstimator.Defaults.NumIterations,
- int likelihoodInterval = LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval,
- int numThread = LatentDirichletAllocationEstimator.Defaults.NumThreads,
- int numMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken,
- int numSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic,
- int numBurninIter = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations,
- bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator)
- {
- Contracts.CheckValue(name, nameof(name));
- Contracts.CheckValueOrNull(inputColumnName);
- Contracts.CheckParam(numTopic > 0, nameof(numTopic), "Must be positive.");
- Contracts.CheckParam(mhStep > 0, nameof(mhStep), "Must be positive.");
- Contracts.CheckParam(numIter > 0, nameof(numIter), "Must be positive.");
- Contracts.CheckParam(likelihoodInterval > 0, nameof(likelihoodInterval), "Must be positive.");
- Contracts.CheckParam(numThread >= 0, nameof(numThread), "Must be positive or zero.");
- Contracts.CheckParam(numMaxDocToken > 0, nameof(numMaxDocToken), "Must be positive.");
- Contracts.CheckParam(numSummaryTermPerTopic > 0, nameof(numSummaryTermPerTopic), "Must be positive");
- Contracts.CheckParam(numBurninIter >= 0, nameof(numBurninIter), "Must be non-negative.");
-
- Name = name;
- InputColumnName = inputColumnName ?? name;
- NumTopic = numTopic;
- AlphaSum = alphaSum;
- Beta = beta;
- MHStep = mhStep;
- NumIter = numIter;
- LikelihoodInterval = likelihoodInterval;
- NumThread = numThread;
- NumMaxDocToken = numMaxDocToken;
- NumSummaryTermPerTopic = numSummaryTermPerTopic;
- NumBurninIter = numBurninIter;
- ResetRandomGenerator = resetRandomGenerator;
- }
-
- internal ColumnInfo(Column item, Arguments args) :
- this(item.Name,
- item.Source ?? item.Name,
- item.NumTopic ?? args.NumTopic,
- item.AlphaSum ?? args.AlphaSum,
- item.Beta ?? args.Beta,
- item.Mhstep ?? args.Mhstep,
- item.NumIterations ?? args.NumIterations,
- item.LikelihoodInterval ?? args.LikelihoodInterval,
- item.NumThreads ?? args.NumThreads,
- item.NumMaxDocToken ?? args.NumMaxDocToken,
- item.NumSummaryTermPerTopic ?? args.NumSummaryTermPerTopic,
- item.NumBurninIterations ?? args.NumBurninIterations,
- item.ResetRandomGenerator ?? args.ResetRandomGenerator)
- {
- }
-
- internal ColumnInfo(IExceptionContext ectx, ModelLoadContext ctx)
- {
- Contracts.AssertValue(ectx);
- ectx.AssertValue(ctx);
-
- // *** Binary format ***
- // int NumTopic;
- // float AlphaSum;
- // float Beta;
- // int MHStep;
- // int NumIter;
- // int LikelihoodInterval;
- // int NumThread;
- // int NumMaxDocToken;
- // int NumSummaryTermPerTopic;
- // int NumBurninIter;
- // byte ResetRandomGenerator;
-
- NumTopic = ctx.Reader.ReadInt32();
- ectx.CheckDecode(NumTopic > 0);
-
- AlphaSum = ctx.Reader.ReadSingle();
-
- Beta = ctx.Reader.ReadSingle();
-
- MHStep = ctx.Reader.ReadInt32();
- ectx.CheckDecode(MHStep > 0);
-
- NumIter = ctx.Reader.ReadInt32();
- ectx.CheckDecode(NumIter > 0);
-
- LikelihoodInterval = ctx.Reader.ReadInt32();
- ectx.CheckDecode(LikelihoodInterval > 0);
-
- NumThread = ctx.Reader.ReadInt32();
- ectx.CheckDecode(NumThread >= 0);
-
- NumMaxDocToken = ctx.Reader.ReadInt32();
- ectx.CheckDecode(NumMaxDocToken > 0);
-
- NumSummaryTermPerTopic = ctx.Reader.ReadInt32();
- ectx.CheckDecode(NumSummaryTermPerTopic > 0);
-
- NumBurninIter = ctx.Reader.ReadInt32();
- ectx.CheckDecode(NumBurninIter >= 0);
-
- ResetRandomGenerator = ctx.Reader.ReadBoolByte();
- }
-
- internal void Save(ModelSaveContext ctx)
- {
- Contracts.AssertValue(ctx);
-
- // *** Binary format ***
- // int NumTopic;
- // float AlphaSum;
- // float Beta;
- // int MHStep;
- // int NumIter;
- // int LikelihoodInterval;
- // int NumThread;
- // int NumMaxDocToken;
- // int NumSummaryTermPerTopic;
- // int NumBurninIter;
- // byte ResetRandomGenerator;
-
- ctx.Writer.Write(NumTopic);
- ctx.Writer.Write(AlphaSum);
- ctx.Writer.Write(Beta);
- ctx.Writer.Write(MHStep);
- ctx.Writer.Write(NumIter);
- ctx.Writer.Write(LikelihoodInterval);
- ctx.Writer.Write(NumThread);
- ctx.Writer.Write(NumMaxDocToken);
- ctx.Writer.Write(NumSummaryTermPerTopic);
- ctx.Writer.Write(NumBurninIter);
- ctx.Writer.WriteBoolByte(ResetRandomGenerator);
- }
- }
-
/// <summary>
/// Provide details about the topics discovered by LightLDA.
/// </summary>
@@ -365,7 +196,7 @@ internal LdaSummary GetLdaDetails(int iinfo)
private sealed class LdaState : IDisposable
{
- internal readonly ColumnInfo InfoEx;
+ internal readonly LatentDirichletAllocationEstimator.ColumnInfo InfoEx;
private readonly int _numVocab;
private readonly object _preparationSyncRoot;
private readonly object _testSyncRoot;
@@ -378,7 +209,7 @@ private LdaState()
_testSyncRoot = new object();
}
- internal LdaState(IExceptionContext ectx, ColumnInfo ex, int numVocab)
+ internal LdaState(IExceptionContext ectx, LatentDirichletAllocationEstimator.ColumnInfo ex, int numVocab)
: this()
{
Contracts.AssertValue(ectx);
@@ -415,7 +246,7 @@ internal LdaState(IExceptionContext ectx, ModelLoadContext ctx)
// (serializing term by term, for one term)
// int: term_id, int: topic_num, KeyValuePair[]: termTopicVector
- InfoEx = new ColumnInfo(ectx, ctx);
+ InfoEx = new LatentDirichletAllocationEstimator.ColumnInfo(ectx, ctx);
_numVocab = ctx.Reader.ReadInt32();
ectx.CheckDecode(_numVocab > 0);
@@ -771,7 +602,7 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(LatentDirichletAllocationTransformer).Assembly.FullName);
}
- private readonly ColumnInfo[] _columns;
+ private readonly LatentDirichletAllocationEstimator.ColumnInfo[] _columns;
private readonly LdaState[] _ldas;
private readonly List<VBuffer<ReadOnlyMemory<char>>> _columnMappings;
@@ -781,7 +612,7 @@ private static VersionInfo GetVersionInfo()
internal const string UserName = "Latent Dirichlet Allocation Transform";
internal const string ShortName = "LightLda";
- private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns)
+ private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(LatentDirichletAllocationEstimator.ColumnInfo[] columns)
{
Contracts.CheckValue(columns, nameof(columns));
return columns.Select(x => (x.Name, x.InputColumnName)).ToArray();
@@ -797,7 +628,7 @@ private static (string outputColumnName, string inputColumnName)[] GetColumnPair
private LatentDirichletAllocationTransformer(IHostEnvironment env,
LdaState[] ldas,
List<VBuffer<ReadOnlyMemory<char>>> columnMappings,
- params ColumnInfo[] columns)
+ params LatentDirichletAllocationEstimator.ColumnInfo[] columns)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LatentDirichletAllocationTransformer)), GetColumnPairs(columns))
{
Host.AssertNonEmpty(ColumnPairs);
@@ -817,7 +648,7 @@ private LatentDirichletAllocationTransformer(IHost host, ModelLoadContext ctx) :
// Note: columnsLength would be just one in most cases.
var columnsLength = ColumnPairs.Length;
- _columns = new ColumnInfo[columnsLength];
+ _columns = new LatentDirichletAllocationEstimator.ColumnInfo[columnsLength];
_ldas = new LdaState[columnsLength];
for (int i = 0; i < _ldas.Length; i++)
{
@@ -826,7 +657,7 @@ private LatentDirichletAllocationTransformer(IHost host, ModelLoadContext ctx) :
}
}
- internal static LatentDirichletAllocationTransformer TrainLdaTransformer(IHostEnvironment env, IDataView inputData, params ColumnInfo[] columns)
+ internal static LatentDirichletAllocationTransformer TrainLdaTransformer(IHostEnvironment env, IDataView inputData, params LatentDirichletAllocationEstimator.ColumnInfo[] columns)
{
var ldas = new LdaState[columns.Length];
@@ -869,14 +700,14 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Sch
=> Create(env, ctx).MakeRowMapper(inputSchema);
// Factory method for SignatureDataTransform.
- private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
+ private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
- env.CheckValue(args, nameof(args));
+ env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
- env.CheckValue(args.Columns, nameof(args.Columns));
+ env.CheckValue(options.Columns, nameof(options.Columns));
- var cols = args.Columns.Select(colPair => new ColumnInfo(colPair, args)).ToArray();
+ var cols = options.Columns.Select(colPair => new LatentDirichletAllocationEstimator.ColumnInfo(colPair, options)).ToArray();
return TrainLdaTransformer(env, input, cols).MakeDataTransform(input);
}
@@ -929,7 +760,7 @@ private static int GetFrequency(double value)
return result;
}
- private static List<VBuffer<ReadOnlyMemory<char>>> Train(IHostEnvironment env, IChannel ch, IDataView inputData, LdaState[] states, params ColumnInfo[] columns)
+ private static List<VBuffer<ReadOnlyMemory<char>>> Train(IHostEnvironment env, IChannel ch, IDataView inputData, LdaState[] states, params LatentDirichletAllocationEstimator.ColumnInfo[] columns)
{
env.AssertValue(ch);
ch.AssertValue(inputData);
@@ -1104,7 +935,7 @@ internal static class Defaults
}
private readonly IHost _host;
- private readonly ImmutableArray<LatentDirichletAllocationTransformer.ColumnInfo> _columns;
+ private readonly ImmutableArray<ColumnInfo> _columns;
///
/// <param name="env">The environment.</param>
@@ -1121,7 +952,7 @@ internal static class Defaults
/// <param name="numSummaryTermPerTopic">The number of words to summarize the topic.</param>
/// <param name="numBurninIterations">The number of burn-in iterations.</param>
/// <param name="resetRandomGenerator">Reset the random number generator for each document.</param>
- public LatentDirichletAllocationEstimator(IHostEnvironment env,
+ internal LatentDirichletAllocationEstimator(IHostEnvironment env,
string outputColumnName, string inputColumnName = null,
int numTopic = Defaults.NumTopic,
float alphaSum = Defaults.AlphaSum,
@@ -1134,7 +965,7 @@ public LatentDirichletAllocationEstimator(IHostEnvironment env,
int numSummaryTermPerTopic = Defaults.NumSummaryTermPerTopic,
int numBurninIterations = Defaults.NumBurninIterations,
bool resetRandomGenerator = Defaults.ResetRandomGenerator)
- : this(env, new[] { new LatentDirichletAllocationTransformer.ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName,
+ : this(env, new[] { new ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName,
numTopic, alphaSum, beta, mhstep, numIterations, likelihoodInterval, numThreads, numMaxDocToken,
numSummaryTermPerTopic, numBurninIterations, resetRandomGenerator) })
{ }
@@ -1142,7 +973,7 @@ public LatentDirichletAllocationEstimator(IHostEnvironment env,
///
/// <param name="env">The environment.</param>
/// <param name="columns">Describes the parameters of the LDA process for each column pair.</param>
- public LatentDirichletAllocationEstimator(IHostEnvironment env, params LatentDirichletAllocationTransformer.ColumnInfo[] columns)
+ internal LatentDirichletAllocationEstimator(IHostEnvironment env, params ColumnInfo[] columns)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(LatentDirichletAllocationEstimator));
@@ -1150,7 +981,219 @@ public LatentDirichletAllocationEstimator(IHostEnvironment env, params LatentDir
}
/// <summary>
- /// Returns the schema that would be produced by the transformation.
+ /// Describes how the transformer handles one column pair.
+ /// </summary>
+ public sealed class ColumnInfo
+ {
+ /// <summary>
+ /// Name of the column resulting from the transformation of <see cref="InputColumnName"/>.
+ /// </summary>
+ public readonly string Name;
+ /// <summary>
+ /// Name of the column to transform.
+ /// </summary>
+ public readonly string InputColumnName;
+ /// <summary>
+ /// The number of topics.
+ /// </summary>
+ public readonly int NumTopic;
+ /// <summary>
+ /// Dirichlet prior on document-topic vectors.
+ /// </summary>
+ public readonly float AlphaSum;
+ /// <summary>
+ /// Dirichlet prior on vocab-topic vectors.
+ /// </summary>
+ public readonly float Beta;
+ /// <summary>
+ /// Number of Metropolis-Hastings steps.
+ /// </summary>
+ public readonly int MHStep;
+ /// <summary>
+ /// Number of iterations.
+ /// </summary>
+ public readonly int NumIter;
+ /// <summary>
+ /// Compute log likelihood over local dataset on this iteration interval.
+ /// </summary>
+ public readonly int LikelihoodInterval;
+ /// <summary>
+ /// The number of training threads.
+ /// </summary>
+ public readonly int NumThread;
+ /// <summary>
+ /// The threshold of maximum count of tokens per doc.
+ /// </summary>
+ public readonly int NumMaxDocToken;
+ /// <summary>
+ /// The number of words to summarize the topic.
+ /// </summary>
+ public readonly int NumSummaryTermPerTopic;
+ /// <summary>
+ /// The number of burn-in iterations.
+ /// </summary>
+ public readonly int NumBurninIter;
+ /// <summary>
+ /// Reset the random number generator for each document.
+ /// </summary>
+ public readonly bool ResetRandomGenerator;
+
+ /// <summary>
+ /// Describes how the transformer handles one column pair.
+ /// </summary>
+ /// <param name="name">The column containing the output scores over a set of topics, represented as a vector of floats.</param>
+ /// <param name="inputColumnName">The column representing the document as a vector of floats. A null value for the column means <paramref name="name"/> is replaced.</param>
+ /// <param name="numTopic">The number of topics.</param>
+ /// <param name="alphaSum">Dirichlet prior on document-topic vectors.</param>
+ /// <param name="beta">Dirichlet prior on vocab-topic vectors.</param>
+ /// <param name="mhStep">Number of Metropolis-Hastings steps.</param>
+ /// <param name="numIter">Number of iterations.</param>
+ /// <param name="likelihoodInterval">Compute log likelihood over local dataset on this iteration interval.</param>
+ /// <param name="numThread">The number of training threads. Default value depends on number of logical processors.</param>
+ /// <param name="numMaxDocToken">The threshold of maximum count of tokens per doc.</param>
+ /// <param name="numSummaryTermPerTopic">The number of words to summarize the topic.</param>
+ /// <param name="numBurninIter">The number of burn-in iterations.</param>
+ /// <param name="resetRandomGenerator">Reset the random number generator for each document.</param>
+ public ColumnInfo(string name,
+ string inputColumnName = null,
+ int numTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic,
+ float alphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum,
+ float beta = LatentDirichletAllocationEstimator.Defaults.Beta,
+ int mhStep = LatentDirichletAllocationEstimator.Defaults.Mhstep,
+ int numIter = LatentDirichletAllocationEstimator.Defaults.NumIterations,
+ int likelihoodInterval = LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval,
+ int numThread = LatentDirichletAllocationEstimator.Defaults.NumThreads,
+ int numMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken,
+ int numSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic,
+ int numBurninIter = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations,
+ bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator)
+ {
+ Contracts.CheckValue(name, nameof(name));
+ Contracts.CheckValueOrNull(inputColumnName);
+ Contracts.CheckParam(numTopic > 0, nameof(numTopic), "Must be positive.");
+ Contracts.CheckParam(mhStep > 0, nameof(mhStep), "Must be positive.");
+ Contracts.CheckParam(numIter > 0, nameof(numIter), "Must be positive.");
+ Contracts.CheckParam(likelihoodInterval > 0, nameof(likelihoodInterval), "Must be positive.");
+ Contracts.CheckParam(numThread >= 0, nameof(numThread), "Must be positive or zero.");
+ Contracts.CheckParam(numMaxDocToken > 0, nameof(numMaxDocToken), "Must be positive.");
+ Contracts.CheckParam(numSummaryTermPerTopic > 0, nameof(numSummaryTermPerTopic), "Must be positive");
+ Contracts.CheckParam(numBurninIter >= 0, nameof(numBurninIter), "Must be non-negative.");
+
+ Name = name;
+ InputColumnName = inputColumnName ?? name;
+ NumTopic = numTopic;
+ AlphaSum = alphaSum;
+ Beta = beta;
+ MHStep = mhStep;
+ NumIter = numIter;
+ LikelihoodInterval = likelihoodInterval;
+ NumThread = numThread;
+ NumMaxDocToken = numMaxDocToken;
+ NumSummaryTermPerTopic = numSummaryTermPerTopic;
+ NumBurninIter = numBurninIter;
+ ResetRandomGenerator = resetRandomGenerator;
+ }
+
+ internal ColumnInfo(LatentDirichletAllocationTransformer.Column item, LatentDirichletAllocationTransformer.Options options) :
+ this(item.Name,
+ item.Source ?? item.Name,
+ item.NumTopic ?? options.NumTopic,
+ item.AlphaSum ?? options.AlphaSum,
+ item.Beta ?? options.Beta,
+ item.Mhstep ?? options.Mhstep,
+ item.NumIterations ?? options.NumIterations,
+ item.LikelihoodInterval ?? options.LikelihoodInterval,
+ item.NumThreads ?? options.NumThreads,
+ item.NumMaxDocToken ?? options.NumMaxDocToken,
+ item.NumSummaryTermPerTopic ?? options.NumSummaryTermPerTopic,
+ item.NumBurninIterations ?? options.NumBurninIterations,
+ item.ResetRandomGenerator ?? options.ResetRandomGenerator)
+ {
+ }
+
+ internal ColumnInfo(IExceptionContext ectx, ModelLoadContext ctx)
+ {
+ Contracts.AssertValue(ectx);
+ ectx.AssertValue(ctx);
+
+ // *** Binary format ***
+ // int NumTopic;
+ // float AlphaSum;
+ // float Beta;
+ // int MHStep;
+ // int NumIter;
+ // int LikelihoodInterval;
+ // int NumThread;
+ // int NumMaxDocToken;
+ // int NumSummaryTermPerTopic;
+ // int NumBurninIter;
+ // byte ResetRandomGenerator;
+
+ NumTopic = ctx.Reader.ReadInt32();
+ ectx.CheckDecode(NumTopic > 0);
+
+ AlphaSum = ctx.Reader.ReadSingle();
+
+ Beta = ctx.Reader.ReadSingle();
+
+ MHStep = ctx.Reader.ReadInt32();
+ ectx.CheckDecode(MHStep > 0);
+
+ NumIter = ctx.Reader.ReadInt32();
+ ectx.CheckDecode(NumIter > 0);
+
+ LikelihoodInterval = ctx.Reader.ReadInt32();
+ ectx.CheckDecode(LikelihoodInterval > 0);
+
+ NumThread = ctx.Reader.ReadInt32();
+ ectx.CheckDecode(NumThread >= 0);
+
+ NumMaxDocToken = ctx.Reader.ReadInt32();
+ ectx.CheckDecode(NumMaxDocToken > 0);
+
+ NumSummaryTermPerTopic = ctx.Reader.ReadInt32();
+ ectx.CheckDecode(NumSummaryTermPerTopic > 0);
+
+ NumBurninIter = ctx.Reader.ReadInt32();
+ ectx.CheckDecode(NumBurninIter >= 0);
+
+ ResetRandomGenerator = ctx.Reader.ReadBoolByte();
+ }
+
+ internal void Save(ModelSaveContext ctx)
+ {
+ Contracts.AssertValue(ctx);
+
+ // *** Binary format ***
+ // int NumTopic;
+ // float AlphaSum;
+ // float Beta;
+ // int MHStep;
+ // int NumIter;
+ // int LikelihoodInterval;
+ // int NumThread;
+ // int NumMaxDocToken;
+ // int NumSummaryTermPerTopic;
+ // int NumBurninIter;
+ // byte ResetRandomGenerator;
+
+ ctx.Writer.Write(NumTopic);
+ ctx.Writer.Write(AlphaSum);
+ ctx.Writer.Write(Beta);
+ ctx.Writer.Write(MHStep);
+ ctx.Writer.Write(NumIter);
+ ctx.Writer.Write(LikelihoodInterval);
+ ctx.Writer.Write(NumThread);
+ ctx.Writer.Write(NumMaxDocToken);
+ ctx.Writer.Write(NumSummaryTermPerTopic);
+ ctx.Writer.Write(NumBurninIter);
+ ctx.Writer.WriteBoolByte(ResetRandomGenerator);
+ }
+ }
+
+ /// <summary>
+ /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
+ /// Used for schema propagation and verification in a pipeline.
/// </summary>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
@@ -1169,6 +1212,9 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
return new SchemaShape(result.Values);
}
+ /// <summary>
+ /// Trains and returns a <see cref="LatentDirichletAllocationTransformer"/>.
+ /// </summary>
public LatentDirichletAllocationTransformer Fit(IDataView input)
{
return LatentDirichletAllocationTransformer.TrainLdaTransformer(_host, input, _columns.ToArray());
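With both `LatentDirichletAllocationEstimator` constructors made internal above, user code is expected to reach LDA through the transforms catalog. A hedged sketch; the extension method and its parameter names are assumed from the `TextCatalog` surface of this era rather than shown in this diff:

```csharp
// Topic extraction over a column of token counts; "Topics" and "TokenCounts" are illustrative names.
var lda = mlContext.Transforms.Text.LatentDirichletAllocation("Topics", "TokenCounts", numTopic: 10);
var ldaModel = lda.Fit(data);            // trains and returns a LatentDirichletAllocationTransformer
var transformed = ldaModel.Transform(data);
```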
diff --git a/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs b/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs
index cb76b256e7..7d01048a44 100644
--- a/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs
+++ b/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs
@@ -17,7 +17,7 @@
using Microsoft.ML.Model;
using Microsoft.ML.Transforms.Text;
-[assembly: LoadableClass(NgramHashingTransformer.Summary, typeof(IDataTransform), typeof(NgramHashingTransformer), typeof(NgramHashingTransformer.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(NgramHashingTransformer.Summary, typeof(IDataTransform), typeof(NgramHashingTransformer), typeof(NgramHashingTransformer.Options), typeof(SignatureDataTransform),
"Ngram Hash Transform", "NgramHashTransform", "NgramHash")]
[assembly: LoadableClass(NgramHashingTransformer.Summary, typeof(IDataTransform), typeof(NgramHashingTransformer), null, typeof(SignatureLoadDataTransform),
@@ -37,7 +37,7 @@ namespace Microsoft.ML.Transforms.Text
///
public sealed class NgramHashingTransformer : RowToRowTransformerBase
{
- public sealed class Column : ManyToOneColumn
+ internal sealed class Column : ManyToOneColumn
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum ngram length", ShortName = "ngram")]
public int? NgramLength;
@@ -112,7 +112,7 @@ internal bool TryUnparse(StringBuilder sb)
}
}
- public sealed class Arguments
+ internal sealed class Options
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:hashBits:src)",
ShortName = "col",
@@ -173,180 +173,7 @@ private static VersionInfo GetVersionInfo()
private const int VersionTransformer = 0x00010003;
- /// <summary>
- /// Describes how the transformer handles one pair: multiple input columns mapped to a single output column.
- /// </summary>
- public sealed class ColumnInfo
- {
- public readonly string Name;
- public readonly string[] InputColumnNames;
- public readonly int NgramLength;
- public readonly int SkipLength;
- public readonly bool AllLengths;
- public readonly int HashBits;
- public readonly uint Seed;
- public readonly bool Ordered;
- public readonly int InvertHash;
- public readonly bool RehashUnigrams;
- // For all source columns, use these friendly names for the source
- // column names instead of the real column names.
- internal string[] FriendlyNames;
-
- /// <summary>
- /// Describes how the transformer handles one column pair.
- /// </summary>
- /// <param name="name">Name of the column resulting from the transformation of <paramref name="inputColumnNames"/>.</param>
- /// <param name="inputColumnNames">Name of the columns to transform.</param>
- /// <param name="ngramLength">Maximum ngram length.</param>
- /// <param name="skipLength">Maximum number of tokens to skip when constructing an ngram.</param>
- /// <param name="allLengths">Whether to store all ngram lengths up to ngramLength, or only ngramLength.</param>
- /// <param name="hashBits">Number of bits to hash into. Must be between 1 and 31, inclusive.</param>
- /// <param name="seed">Hashing seed.</param>
- /// <param name="ordered">Whether the position of each term should be included in the hash.</param>
- /// <param name="invertHash">During hashing we construct mappings between original values and the produced hash values.
- /// Text representation of original values are stored in the slot names of the metadata for the new column. Hashing, as such, can map many initial values to one.
- /// <paramref name="invertHash"/> specifies the upper bound of the number of distinct input values mapping to a hash that should be retained.
- /// 0 does not retain any input values. -1 retains all input values mapping to each hash.</param>
- /// <param name="rehashUnigrams">Whether to rehash unigrams.</param>
- public ColumnInfo(string name,
- string[] inputColumnNames,
- int ngramLength = NgramHashingEstimator.Defaults.NgramLength,
- int skipLength = NgramHashingEstimator.Defaults.SkipLength,
- bool allLengths = NgramHashingEstimator.Defaults.AllLengths,
- int hashBits = NgramHashingEstimator.Defaults.HashBits,
- uint seed = NgramHashingEstimator.Defaults.Seed,
- bool ordered = NgramHashingEstimator.Defaults.Ordered,
- int invertHash = NgramHashingEstimator.Defaults.InvertHash,
- bool rehashUnigrams = NgramHashingEstimator.Defaults.RehashUnigrams)
- {
- Contracts.CheckValue(name, nameof(name));
- Contracts.CheckValue(inputColumnNames, nameof(inputColumnNames));
- Contracts.CheckParam(!inputColumnNames.Any(r => string.IsNullOrWhiteSpace(r)), nameof(inputColumnNames),
- "Contained some null or empty items");
- if (invertHash < -1)
- throw Contracts.ExceptParam(nameof(invertHash), "Value too small, must be -1 or larger");
- // If the bits is 31 or higher, we can't declare a KeyValues of the appropriate length,
- // this requiring a VBuffer of length 1u << 31 which exceeds int.MaxValue.
- if (invertHash != 0 && hashBits >= 31)
- throw Contracts.ExceptParam(nameof(hashBits), $"Cannot support invertHash for a {0} bit hash. 30 is the maximum possible.", hashBits);
-
- if (NgramLength + SkipLength > NgramBufferBuilder.MaxSkipNgramLength)
- {
- throw Contracts.ExceptUserArg(nameof(skipLength),
- $"The sum of skipLength and ngramLength must be less than or equal to {NgramBufferBuilder.MaxSkipNgramLength}");
- }
- FriendlyNames = null;
- Name = name;
- InputColumnNames = inputColumnNames;
- NgramLength = ngramLength;
- SkipLength = skipLength;
- AllLengths = allLengths;
- HashBits = hashBits;
- Seed = seed;
- Ordered = ordered;
- InvertHash = invertHash;
- RehashUnigrams = rehashUnigrams;
- }
-
- internal ColumnInfo(ModelLoadContext ctx)
- {
- Contracts.AssertValue(ctx);
-
- // *** Binary format ***
- // size of Inputs
- // string[] Inputs;
- // string Output;
- // int: NgramLength
- // int: SkipLength
- // int: HashBits
- // uint: Seed
- // byte: Rehash
- // byte: Ordered
- // byte: AllLengths
- var inputsLength = ctx.Reader.ReadInt32();
- InputColumnNames = new string[inputsLength];
- for (int i = 0; i < InputColumnNames.Length; i++)
- InputColumnNames[i] = ctx.LoadNonEmptyString();
- Name = ctx.LoadNonEmptyString();
- NgramLength = ctx.Reader.ReadInt32();
- Contracts.CheckDecode(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength);
- SkipLength = ctx.Reader.ReadInt32();
- Contracts.CheckDecode(0 <= SkipLength && SkipLength <= NgramBufferBuilder.MaxSkipNgramLength);
- Contracts.CheckDecode(SkipLength <= NgramBufferBuilder.MaxSkipNgramLength - NgramLength);
- HashBits = ctx.Reader.ReadInt32();
- Contracts.CheckDecode(1 <= HashBits && HashBits <= 30);
- Seed = ctx.Reader.ReadUInt32();
- RehashUnigrams = ctx.Reader.ReadBoolByte();
- Ordered = ctx.Reader.ReadBoolByte();
- AllLengths = ctx.Reader.ReadBoolByte();
- }
-
- internal ColumnInfo(ModelLoadContext ctx, string name, string[] inputColumnNames)
- {
- Contracts.AssertValue(ctx);
- Contracts.CheckValue(inputColumnNames, nameof(inputColumnNames));
- Contracts.CheckParam(!inputColumnNames.Any(r => string.IsNullOrWhiteSpace(r)), nameof(inputColumnNames),
- "Contained some null or empty items");
- InputColumnNames = inputColumnNames;
- Name = name;
- // *** Binary format ***
- // string Output;
- // int: NgramLength
- // int: SkipLength
- // int: HashBits
- // uint: Seed
- // byte: Rehash
- // byte: Ordered
- // byte: AllLengths
- NgramLength = ctx.Reader.ReadInt32();
- Contracts.CheckDecode(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength);
- SkipLength = ctx.Reader.ReadInt32();
- Contracts.CheckDecode(0 <= SkipLength && SkipLength <= NgramBufferBuilder.MaxSkipNgramLength);
- Contracts.CheckDecode(SkipLength <= NgramBufferBuilder.MaxSkipNgramLength - NgramLength);
- HashBits = ctx.Reader.ReadInt32();
- Contracts.CheckDecode(1 <= HashBits && HashBits <= 30);
- Seed = ctx.Reader.ReadUInt32();
- RehashUnigrams = ctx.Reader.ReadBoolByte();
- Ordered = ctx.Reader.ReadBoolByte();
- AllLengths = ctx.Reader.ReadBoolByte();
- }
-
- internal void Save(ModelSaveContext ctx)
- {
- Contracts.AssertValue(ctx);
-
- // *** Binary format ***
- // size of Inputs
- // string[] Inputs;
- // string Output;
- // int: NgramLength
- // int: SkipLength
- // int: HashBits
- // uint: Seed
- // byte: Rehash
- // byte: Ordered
- // byte: AllLengths
- Contracts.Assert(InputColumnNames.Length > 0);
- ctx.Writer.Write(InputColumnNames.Length);
- for (int i = 0; i < InputColumnNames.Length; i++)
- ctx.SaveNonEmptyString(InputColumnNames[i]);
- ctx.SaveNonEmptyString(Name);
-
- Contracts.Assert(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength);
- ctx.Writer.Write(NgramLength);
- Contracts.Assert(0 <= SkipLength && SkipLength <= NgramBufferBuilder.MaxSkipNgramLength);
- Contracts.Assert(NgramLength + SkipLength <= NgramBufferBuilder.MaxSkipNgramLength);
- ctx.Writer.Write(SkipLength);
- Contracts.Assert(1 <= HashBits && HashBits <= 30);
- ctx.Writer.Write(HashBits);
- ctx.Writer.Write(Seed);
- ctx.Writer.WriteBoolByte(RehashUnigrams);
- ctx.Writer.WriteBoolByte(Ordered);
- ctx.Writer.WriteBoolByte(AllLengths);
- }
- }
-
- private readonly ImmutableArray<ColumnInfo> _columns;
+ private readonly ImmutableArray<NgramHashingEstimator.ColumnInfo> _columns;
private readonly VBuffer<ReadOnlyMemory<char>>[] _slotNames;
private readonly VectorType[] _slotNamesTypes;
@@ -355,7 +182,7 @@ internal void Save(ModelSaveContext ctx)
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="columns">Description of dataset columns and how to process them.</param>
- public NgramHashingTransformer(IHostEnvironment env, params ColumnInfo[] columns) :
+ internal NgramHashingTransformer(IHostEnvironment env, params NgramHashingEstimator.ColumnInfo[] columns) :
base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NgramHashingTransformer)))
{
_columns = columns.ToImmutableArray();
@@ -366,7 +193,7 @@ public NgramHashingTransformer(IHostEnvironment env, params ColumnInfo[] columns
}
}
- internal NgramHashingTransformer(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) :
+ internal NgramHashingTransformer(IHostEnvironment env, IDataView input, params NgramHashingEstimator.ColumnInfo[] columns) :
base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NgramHashingTransformer)))
{
Contracts.CheckValue(columns, nameof(columns));
@@ -463,14 +290,14 @@ private NgramHashingTransformer(IHostEnvironment env, ModelLoadContext ctx, bool
}
var columnsLength = ctx.Reader.ReadInt32();
Contracts.CheckDecode(columnsLength > 0);
- var columns = new ColumnInfo[columnsLength];
+ var columns = new NgramHashingEstimator.ColumnInfo[columnsLength];
if (!loadLegacy)
{
// *** Binary format ***
// int number of columns
// columns
for (int i = 0; i < columnsLength; i++)
- columns[i] = new ColumnInfo(ctx);
+ columns[i] = new NgramHashingEstimator.ColumnInfo(ctx);
}
else
{
@@ -500,38 +327,38 @@ private NgramHashingTransformer(IHostEnvironment env, ModelLoadContext ctx, bool
// int number of columns
// columns
for (int i = 0; i < columnsLength; i++)
- columns[i] = new ColumnInfo(ctx, outputs[i], inputs[i]);
+ columns[i] = new NgramHashingEstimator.ColumnInfo(ctx, outputs[i], inputs[i]);
}
_columns = columns.ToImmutableArray();
TextModelHelper.LoadAll(Host, ctx, columnsLength, out _slotNames, out _slotNamesTypes);
}
// Factory method for SignatureDataTransform.
- private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
+ private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
- env.CheckValue(args, nameof(args));
+ env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
- env.CheckValue(args.Column, nameof(args.Column));
- var cols = new ColumnInfo[args.Column.Length];
+ env.CheckValue(options.Column, nameof(options.Column));
+ var cols = new NgramHashingEstimator.ColumnInfo[options.Column.Length];
using (var ch = env.Start("ValidateArgs"))
{
for (int i = 0; i < cols.Length; i++)
{
- var item = args.Column[i];
- cols[i] = new ColumnInfo(
+ var item = options.Column[i];
+ cols[i] = new NgramHashingEstimator.ColumnInfo(
item.Name,
item.Source ?? new string[] { item.Name },
- item.NgramLength ?? args.NgramLength,
- item.SkipLength ?? args.SkipLength,
- item.AllLengths ?? args.AllLengths,
- item.HashBits ?? args.HashBits,
- item.Seed ?? args.Seed,
- item.Ordered ?? args.Ordered,
- item.InvertHash ?? args.InvertHash,
- item.RehashUnigrams ?? args.RehashUnigrams
+ item.NgramLength ?? options.NgramLength,
+ item.SkipLength ?? options.SkipLength,
+ item.AllLengths ?? options.AllLengths,
+ item.HashBits ?? options.HashBits,
+ item.Seed ?? options.Seed,
+ item.Ordered ?? options.Ordered,
+ item.InvertHash ?? options.InvertHash,
+ item.RehashUnigrams ?? options.RehashUnigrams
);
};
}
@@ -1044,6 +871,179 @@ public VBuffer<ReadOnlyMemory<char>>[] SlotNamesMetadata(out VectorType[] types)
///
public sealed class NgramHashingEstimator : IEstimator<NgramHashingTransformer>
{
+ /// <summary>
+ /// Describes how the transformer handles one pair: multiple input columns mapped to a single output column.
+ /// </summary>
+ public sealed class ColumnInfo
+ {
+ public readonly string Name;
+ public readonly string[] InputColumnNames;
+ public readonly int NgramLength;
+ public readonly int SkipLength;
+ public readonly bool AllLengths;
+ public readonly int HashBits;
+ public readonly uint Seed;
+ public readonly bool Ordered;
+ public readonly int InvertHash;
+ public readonly bool RehashUnigrams;
+ // For all source columns, use these friendly names for the source
+ // column names instead of the real column names.
+ internal string[] FriendlyNames;
+
+ /// <summary>
+ /// Describes how the transformer handles one column pair.
+ /// </summary>
+ /// <param name="name">Name of the column resulting from the transformation of <paramref name="inputColumnNames"/>.</param>
+ /// <param name="inputColumnNames">Name of the columns to transform.</param>
+ /// <param name="ngramLength">Maximum ngram length.</param>
+ /// <param name="skipLength">Maximum number of tokens to skip when constructing an ngram.</param>
+ /// <param name="allLengths">Whether to store all ngram lengths up to ngramLength, or only ngramLength.</param>
+ /// <param name="hashBits">Number of bits to hash into. Must be between 1 and 31, inclusive.</param>
+ /// <param name="seed">Hashing seed.</param>
+ /// <param name="ordered">Whether the position of each term should be included in the hash.</param>
+ /// <param name="invertHash">During hashing we construct mappings between original values and the produced hash values.
+ /// Text representation of original values are stored in the slot names of the metadata for the new column. Hashing, as such, can map many initial values to one.
+ /// <paramref name="invertHash"/> specifies the upper bound of the number of distinct input values mapping to a hash that should be retained.
+ /// 0 does not retain any input values. -1 retains all input values mapping to each hash.</param>
+ /// <param name="rehashUnigrams">Whether to rehash unigrams.</param>
+ public ColumnInfo(string name,
+ string[] inputColumnNames,
+ int ngramLength = NgramHashingEstimator.Defaults.NgramLength,
+ int skipLength = NgramHashingEstimator.Defaults.SkipLength,
+ bool allLengths = NgramHashingEstimator.Defaults.AllLengths,
+ int hashBits = NgramHashingEstimator.Defaults.HashBits,
+ uint seed = NgramHashingEstimator.Defaults.Seed,
+ bool ordered = NgramHashingEstimator.Defaults.Ordered,
+ int invertHash = NgramHashingEstimator.Defaults.InvertHash,
+ bool rehashUnigrams = NgramHashingEstimator.Defaults.RehashUnigrams)
+ {
+ Contracts.CheckValue(name, nameof(name));
+ Contracts.CheckValue(inputColumnNames, nameof(inputColumnNames));
+ Contracts.CheckParam(!inputColumnNames.Any(r => string.IsNullOrWhiteSpace(r)), nameof(inputColumnNames),
+ "Contained some null or empty items");
+ if (invertHash < -1)
+ throw Contracts.ExceptParam(nameof(invertHash), "Value too small, must be -1 or larger");
+ // If the number of bits is 31 or higher, we can't declare a KeyValues of the appropriate length,
+ // since that would require a VBuffer of length 1u << 31, which exceeds int.MaxValue.
+ if (invertHash != 0 && hashBits >= 31)
+ throw Contracts.ExceptParam(nameof(hashBits), $"Cannot support invertHash for a {hashBits} bit hash. 30 is the maximum possible.");
+
+ // Check the constructor parameters, not the readonly fields, which are not yet assigned at this point.
+ if (ngramLength + skipLength > NgramBufferBuilder.MaxSkipNgramLength)
+ {
+ throw Contracts.ExceptUserArg(nameof(skipLength),
+ $"The sum of skipLength and ngramLength must be less than or equal to {NgramBufferBuilder.MaxSkipNgramLength}");
+ }
+ FriendlyNames = null;
+ Name = name;
+ InputColumnNames = inputColumnNames;
+ NgramLength = ngramLength;
+ SkipLength = skipLength;
+ AllLengths = allLengths;
+ HashBits = hashBits;
+ Seed = seed;
+ Ordered = ordered;
+ InvertHash = invertHash;
+ RehashUnigrams = rehashUnigrams;
+ }
+
+ internal ColumnInfo(ModelLoadContext ctx)
+ {
+ Contracts.AssertValue(ctx);
+
+ // *** Binary format ***
+ // size of Inputs
+ // string[] Inputs;
+ // string Output;
+ // int: NgramLength
+ // int: SkipLength
+ // int: HashBits
+ // uint: Seed
+ // byte: Rehash
+ // byte: Ordered
+ // byte: AllLengths
+ var inputsLength = ctx.Reader.ReadInt32();
+ InputColumnNames = new string[inputsLength];
+ for (int i = 0; i < InputColumnNames.Length; i++)
+ InputColumnNames[i] = ctx.LoadNonEmptyString();
+ Name = ctx.LoadNonEmptyString();
+ NgramLength = ctx.Reader.ReadInt32();
+ Contracts.CheckDecode(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength);
+ SkipLength = ctx.Reader.ReadInt32();
+ Contracts.CheckDecode(0 <= SkipLength && SkipLength <= NgramBufferBuilder.MaxSkipNgramLength);
+ Contracts.CheckDecode(SkipLength <= NgramBufferBuilder.MaxSkipNgramLength - NgramLength);
+ HashBits = ctx.Reader.ReadInt32();
+ Contracts.CheckDecode(1 <= HashBits && HashBits <= 30);
+ Seed = ctx.Reader.ReadUInt32();
+ RehashUnigrams = ctx.Reader.ReadBoolByte();
+ Ordered = ctx.Reader.ReadBoolByte();
+ AllLengths = ctx.Reader.ReadBoolByte();
+ }
+
+ internal ColumnInfo(ModelLoadContext ctx, string name, string[] inputColumnNames)
+ {
+ Contracts.AssertValue(ctx);
+ Contracts.CheckValue(inputColumnNames, nameof(inputColumnNames));
+ Contracts.CheckParam(!inputColumnNames.Any(r => string.IsNullOrWhiteSpace(r)), nameof(inputColumnNames),
+ "Contained some null or empty items");
+ InputColumnNames = inputColumnNames;
+ Name = name;
+ // *** Binary format ***
+ // string Output;
+ // int: NgramLength
+ // int: SkipLength
+ // int: HashBits
+ // uint: Seed
+ // byte: Rehash
+ // byte: Ordered
+ // byte: AllLengths
+ NgramLength = ctx.Reader.ReadInt32();
+ Contracts.CheckDecode(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength);
+ SkipLength = ctx.Reader.ReadInt32();
+ Contracts.CheckDecode(0 <= SkipLength && SkipLength <= NgramBufferBuilder.MaxSkipNgramLength);
+ Contracts.CheckDecode(SkipLength <= NgramBufferBuilder.MaxSkipNgramLength - NgramLength);
+ HashBits = ctx.Reader.ReadInt32();
+ Contracts.CheckDecode(1 <= HashBits && HashBits <= 30);
+ Seed = ctx.Reader.ReadUInt32();
+ RehashUnigrams = ctx.Reader.ReadBoolByte();
+ Ordered = ctx.Reader.ReadBoolByte();
+ AllLengths = ctx.Reader.ReadBoolByte();
+ }
+
+ internal void Save(ModelSaveContext ctx)
+ {
+ Contracts.AssertValue(ctx);
+
+ // *** Binary format ***
+ // size of Inputs
+ // string[] Inputs;
+ // string Output;
+ // int: NgramLength
+ // int: SkipLength
+ // int: HashBits
+ // uint: Seed
+ // byte: Rehash
+ // byte: Ordered
+ // byte: AllLengths
+ Contracts.Assert(InputColumnNames.Length > 0);
+ ctx.Writer.Write(InputColumnNames.Length);
+ for (int i = 0; i < InputColumnNames.Length; i++)
+ ctx.SaveNonEmptyString(InputColumnNames[i]);
+ ctx.SaveNonEmptyString(Name);
+
+ Contracts.Assert(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength);
+ ctx.Writer.Write(NgramLength);
+ Contracts.Assert(0 <= SkipLength && SkipLength <= NgramBufferBuilder.MaxSkipNgramLength);
+ Contracts.Assert(NgramLength + SkipLength <= NgramBufferBuilder.MaxSkipNgramLength);
+ ctx.Writer.Write(SkipLength);
+ Contracts.Assert(1 <= HashBits && HashBits <= 30);
+ ctx.Writer.Write(HashBits);
+ ctx.Writer.Write(Seed);
+ ctx.Writer.WriteBoolByte(RehashUnigrams);
+ ctx.Writer.WriteBoolByte(Ordered);
+ ctx.Writer.WriteBoolByte(AllLengths);
+ }
+ }
+
internal static class Defaults
{
internal const int NgramLength = 2;
@@ -1057,7 +1057,7 @@ internal static class Defaults
}
private readonly IHost _host;
- private readonly NgramHashingTransformer.ColumnInfo[] _columns;
+ private readonly ColumnInfo[] _columns;
/// <summary>
/// Produces a bag of counts of hashed ngrams in
@@ -1079,7 +1079,7 @@ internal static class Defaults
/// Text representation of original values are stored in the slot names of the metadata for the new column. Hashing, as such, can map many initial values to one.
/// <paramref name="invertHash"/> specifies the upper bound of the number of distinct input values mapping to a hash that should be retained.
/// 0 does not retain any input values. -1 retains all input values mapping to each hash.</param>
- public NgramHashingEstimator(IHostEnvironment env,
+ internal NgramHashingEstimator(IHostEnvironment env,
string outputColumnName,
string inputColumnName = null,
int hashBits = 16,
@@ -1113,7 +1113,7 @@ public NgramHashingEstimator(IHostEnvironment env,
/// Text representation of original values are stored in the slot names of the metadata for the new column. Hashing, as such, can map many initial values to one.
/// <paramref name="invertHash"/> specifies the upper bound of the number of distinct input values mapping to a hash that should be retained.
/// 0 does not retain any input values. -1 retains all input values mapping to each hash.</param>
- public NgramHashingEstimator(IHostEnvironment env,
+ internal NgramHashingEstimator(IHostEnvironment env,
string outputColumnName,
string[] inputColumnNames,
int hashBits = 16,
@@ -1146,7 +1146,7 @@ public NgramHashingEstimator(IHostEnvironment env,
/// Text representation of original values are stored in the slot names of the metadata for the new column. Hashing, as such, can map many initial values to one.
/// <paramref name="invertHash"/> specifies the upper bound of the number of distinct input values mapping to a hash that should be retained.
/// 0 does not retain any input values. -1 retains all input values mapping to each hash.</param>
- public NgramHashingEstimator(IHostEnvironment env,
+ internal NgramHashingEstimator(IHostEnvironment env,
(string outputColumnName, string[] inputColumnName)[] columns,
int hashBits = 16,
int ngramLength = 2,
@@ -1155,7 +1155,7 @@ public NgramHashingEstimator(IHostEnvironment env,
uint seed = 314489979,
bool ordered = true,
int invertHash = 0)
- : this(env, columns.Select(x => new NgramHashingTransformer.ColumnInfo(x.outputColumnName, x.inputColumnName, ngramLength, skipLength, allLengths, hashBits, seed, ordered, invertHash)).ToArray())
+ : this(env, columns.Select(x => new ColumnInfo(x.outputColumnName, x.inputColumnName, ngramLength, skipLength, allLengths, hashBits, seed, ordered, invertHash)).ToArray())
{
}
@@ -1169,7 +1169,7 @@ public NgramHashingEstimator(IHostEnvironment env,
/// </summary>
/// <param name="env">The environment.</param>
/// <param name="columns">Array of columns which specifies the behavior of the transformation.</param>
- public NgramHashingEstimator(IHostEnvironment env, params NgramHashingTransformer.ColumnInfo[] columns)
+ internal NgramHashingEstimator(IHostEnvironment env, params ColumnInfo[] columns)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(NgramHashingEstimator));
@@ -1202,6 +1202,10 @@ internal static bool IsSchemaColumnValid(SchemaShape.Column col)
internal const string ExpectedColumnType = "Expected vector of Key type, and Key is convertible to U4";
+ /// <summary>
+ /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
+ /// Used for schema propagation and verification in a pipeline.
+ /// </summary>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
@@ -1222,6 +1226,9 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
return new SchemaShape(result.Values);
}
+ /// <summary>
+ /// Trains and returns a <see cref="NgramHashingTransformer"/>.
+ /// </summary>
public NgramHashingTransformer Fit(IDataView input) => new NgramHashingTransformer(_host, input, _columns);
}
}
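
Reviewer note: since the estimator constructors above are now internal, NgramHashingEstimator.ColumnInfo is the surface callers still configure. A minimal sketch of assembling such columns, with the argument order taken from the constructor call in this hunk (the literal parameter names are assumptions):

    // Order mirrors the call above: name, inputColumnNames, ngramLength,
    // skipLength, allLengths, hashBits, seed, ordered, invertHash.
    var hashedNgramColumns = new[]
    {
        new NgramHashingEstimator.ColumnInfo("HashedNgrams", new[] { "Tokens" },
            2, 0, true, 16, 314489979, true, 0)
    };
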
diff --git a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs
index 3e3371ef44..b32a96aea9 100644
--- a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs
@@ -18,7 +18,7 @@
using Microsoft.ML.Model;
using Microsoft.ML.Transforms.Text;
-[assembly: LoadableClass(NgramExtractingTransformer.Summary, typeof(IDataTransform), typeof(NgramExtractingTransformer), typeof(NgramExtractingTransformer.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(NgramExtractingTransformer.Summary, typeof(IDataTransform), typeof(NgramExtractingTransformer), typeof(NgramExtractingTransformer.Options), typeof(SignatureDataTransform),
"Ngram Transform", "NgramTransform", "Ngram")]
[assembly: LoadableClass(NgramExtractingTransformer.Summary, typeof(IDataTransform), typeof(NgramExtractingTransformer), null, typeof(SignatureLoadDataTransform),
@@ -38,7 +38,7 @@ namespace Microsoft.ML.Transforms.Text
///
public sealed class NgramExtractingTransformer : OneToOneTransformerBase
{
- public sealed class Column : OneToOneColumn
+ internal sealed class Column : OneToOneColumn
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum ngram length", ShortName = "ngram")]
public int? NgramLength;
@@ -77,7 +77,7 @@ internal bool TryUnparse(StringBuilder sb)
}
}
- public sealed class Arguments : TransformInputBase
+ internal sealed class Options : TransformInputBase
{
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 1)]
public Column[] Columns;
@@ -122,81 +122,6 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(NgramExtractingTransformer).Assembly.FullName);
}
- /// <summary>
- /// Describes how the transformer handles one column pair.
- /// </summary>
- public sealed class ColumnInfo
- {
- public readonly string Name;
- public readonly string InputColumnName;
- public readonly int NgramLength;
- public readonly int SkipLength;
- public readonly bool AllLengths;
- public readonly NgramExtractingEstimator.WeightingCriteria Weighting;
- /// <summary>
- /// Contains the maximum number of grams to store in the dictionary, for each level of ngrams,
- /// from 1 (in position 0) up to ngramLength (in position ngramLength-1)
- /// </summary>
- public readonly ImmutableArray<int> Limits;
-
- /// <summary>
- /// Describes how the transformer handles one column pair.
- /// </summary>
- /// <param name="name">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
- /// <param name="inputColumnName">Name of column to transform. If set to <see langword="null"/>, the value of <paramref name="name"/> will be used as source.</param>
- /// <param name="ngramLength">Maximum ngram length.</param>
- /// <param name="skipLength">Maximum number of tokens to skip when constructing an ngram.</param>
- /// <param name="allLengths">Whether to store all ngram lengths up to ngramLength, or only ngramLength.</param>
- /// <param name="weighting">The weighting criteria.</param>
- /// <param name="maxNumTerms">Maximum number of ngrams to store in the dictionary.</param>
- public ColumnInfo(string name, string inputColumnName = null,
- int ngramLength = NgramExtractingEstimator.Defaults.NgramLength,
- int skipLength = NgramExtractingEstimator.Defaults.SkipLength,
- bool allLengths = NgramExtractingEstimator.Defaults.AllLengths,
- NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.Defaults.Weighting,
- int maxNumTerms = NgramExtractingEstimator.Defaults.MaxNumTerms)
- : this(name, ngramLength, skipLength, allLengths, weighting, new int[] { maxNumTerms }, inputColumnName ?? name)
- {
- }
-
- internal ColumnInfo(string name,
- int ngramLength,
- int skipLength,
- bool allLengths,
- NgramExtractingEstimator.WeightingCriteria weighting,
- int[] maxNumTerms,
- string inputColumnName = null)
- {
- Name = name;
- InputColumnName = inputColumnName ?? name;
- NgramLength = ngramLength;
- Contracts.CheckUserArg(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength, nameof(ngramLength));
- SkipLength = skipLength;
- if (NgramLength + SkipLength > NgramBufferBuilder.MaxSkipNgramLength)
- {
- throw Contracts.ExceptUserArg(nameof(skipLength),
- $"The sum of skipLength and ngramLength must be less than or equal to {NgramBufferBuilder.MaxSkipNgramLength}");
- }
- AllLengths = allLengths;
- Weighting = weighting;
- var limits = new int[ngramLength];
- if (!AllLengths)
- {
- Contracts.CheckUserArg(Utils.Size(maxNumTerms) == 0 ||
- Utils.Size(maxNumTerms) == 1 && maxNumTerms[0] > 0, nameof(maxNumTerms));
- limits[ngramLength - 1] = Utils.Size(maxNumTerms) == 0 ? NgramExtractingEstimator.Defaults.MaxNumTerms : maxNumTerms[0];
- }
- else
- {
- Contracts.CheckUserArg(Utils.Size(maxNumTerms) <= ngramLength, nameof(maxNumTerms));
- Contracts.CheckUserArg(Utils.Size(maxNumTerms) == 0 || maxNumTerms.All(i => i >= 0) && maxNumTerms[maxNumTerms.Length - 1] > 0, nameof(maxNumTerms));
- var extend = Utils.Size(maxNumTerms) == 0 ? NgramExtractingEstimator.Defaults.MaxNumTerms : maxNumTerms[maxNumTerms.Length - 1];
- limits = Utils.BuildArray(ngramLength, i => i < Utils.Size(maxNumTerms) ? maxNumTerms[i] : extend);
- }
- Limits = ImmutableArray.Create(limits);
- }
- }
-
private sealed class TransformInfo
{
// Position i, indicates whether the pool contains any (i+1)-grams
@@ -207,7 +132,7 @@ private sealed class TransformInfo
public bool RequireIdf => Weighting == NgramExtractingEstimator.WeightingCriteria.Idf || Weighting == NgramExtractingEstimator.WeightingCriteria.TfIdf;
- public TransformInfo(ColumnInfo info)
+ public TransformInfo(NgramExtractingEstimator.ColumnInfo info)
{
NgramLength = info.NgramLength;
SkipLength = info.SkipLength;
@@ -267,7 +192,7 @@ public void Save(ModelSaveContext ctx)
// Ngram inverse document frequencies
private readonly double[][] _invDocFreqs;
- private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns)
+ private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(NgramExtractingEstimator.ColumnInfo[] columns)
{
Contracts.CheckValue(columns, nameof(columns));
return columns.Select(x => (x.Name, x.InputColumnName)).ToArray();
@@ -280,7 +205,7 @@ protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, NgramExtractingEstimator.ExpectedColumnType, type.ToString());
}
- internal NgramExtractingTransformer(IHostEnvironment env, IDataView input, ColumnInfo[] columns)
+ internal NgramExtractingTransformer(IHostEnvironment env, IDataView input, NgramExtractingEstimator.ColumnInfo[] columns)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NgramExtractingTransformer)), GetColumnPairs(columns))
{
var transformInfos = new TransformInfo[columns.Length];
@@ -294,7 +219,7 @@ internal NgramExtractingTransformer(IHostEnvironment env, IDataView input, Colum
_ngramMaps = Train(Host, columns, _transformInfos, input, out _invDocFreqs);
}
- private static SequencePool[] Train(IHostEnvironment env, ColumnInfo[] columns, ImmutableArray<TransformInfo> transformInfos, IDataView trainingData, out double[][] invDocFreqs)
+ private static SequencePool[] Train(IHostEnvironment env, NgramExtractingEstimator.ColumnInfo[] columns, ImmutableArray<TransformInfo> transformInfos, IDataView trainingData, out double[][] invDocFreqs)
{
var helpers = new NgramBufferBuilder[columns.Length];
var getters = new ValueGetter<VBuffer<uint>>[columns.Length];
@@ -482,27 +407,27 @@ private NgramExtractingTransformer(IHost host, ModelLoadContext ctx) :
}
// Factory method for SignatureDataTransform.
- internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
+ internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
- env.CheckValue(args, nameof(args));
+ env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
- env.CheckValue(args.Columns, nameof(args.Columns));
- var cols = new ColumnInfo[args.Columns.Length];
+ env.CheckValue(options.Columns, nameof(options.Columns));
+ var cols = new NgramExtractingEstimator.ColumnInfo[options.Columns.Length];
using (var ch = env.Start("ValidateArgs"))
{
for (int i = 0; i < cols.Length; i++)
{
- var item = args.Columns[i];
- var maxNumTerms = Utils.Size(item.MaxNumTerms) > 0 ? item.MaxNumTerms : args.MaxNumTerms;
- cols[i] = new ColumnInfo(
+ var item = options.Columns[i];
+ var maxNumTerms = Utils.Size(item.MaxNumTerms) > 0 ? item.MaxNumTerms : options.MaxNumTerms;
+ cols[i] = new NgramExtractingEstimator.ColumnInfo(
item.Name,
- item.NgramLength ?? args.NgramLength,
- item.SkipLength ?? args.SkipLength,
- item.AllLengths ?? args.AllLengths,
- item.Weighting ?? args.Weighting,
+ item.NgramLength ?? options.NgramLength,
+ item.SkipLength ?? options.SkipLength,
+ item.AllLengths ?? options.AllLengths,
+ item.Weighting ?? options.Weighting,
maxNumTerms,
item.Source ?? item.Name);
};
@@ -777,7 +702,7 @@ internal static class Defaults
}
private readonly IHost _host;
- private readonly NgramExtractingTransformer.ColumnInfo[] _columns;
+ private readonly ColumnInfo[] _columns;
/// <summary>
/// Produces a bag of counts of ngrams (sequences of consecutive words) in <paramref name="inputColumnName"/>
@@ -791,7 +716,7 @@ internal static class Defaults
/// <param name="allLengths">Whether to include all ngram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
/// <param name="maxNumTerms">Maximum number of ngrams to store in the dictionary.</param>
/// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
- public NgramExtractingEstimator(IHostEnvironment env,
+ internal NgramExtractingEstimator(IHostEnvironment env,
string outputColumnName, string inputColumnName = null,
int ngramLength = Defaults.NgramLength,
int skipLength = Defaults.SkipLength,
@@ -813,14 +738,14 @@ public NgramExtractingEstimator(IHostEnvironment env,
/// <param name="allLengths">Whether to include all ngram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
/// <param name="maxNumTerms">Maximum number of ngrams to store in the dictionary.</param>
/// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
- public NgramExtractingEstimator(IHostEnvironment env,
+ internal NgramExtractingEstimator(IHostEnvironment env,
(string outputColumnName, string inputColumnName)[] columns,
int ngramLength = Defaults.NgramLength,
int skipLength = Defaults.SkipLength,
bool allLengths = Defaults.AllLengths,
int maxNumTerms = Defaults.MaxNumTerms,
WeightingCriteria weighting = Defaults.Weighting)
- : this(env, columns.Select(x => new NgramExtractingTransformer.ColumnInfo(x.outputColumnName, x.inputColumnName, ngramLength, skipLength, allLengths, weighting, maxNumTerms)).ToArray())
+ : this(env, columns.Select(x => new ColumnInfo(x.outputColumnName, x.inputColumnName, ngramLength, skipLength, allLengths, weighting, maxNumTerms)).ToArray())
{
}
@@ -830,13 +755,16 @@ public NgramExtractingEstimator(IHostEnvironment env,
/// </summary>
/// <param name="env">The environment.</param>
/// <param name="columns">Array of columns with information how to transform data.</param>
- public NgramExtractingEstimator(IHostEnvironment env, params NgramExtractingTransformer.ColumnInfo[] columns)
+ internal NgramExtractingEstimator(IHostEnvironment env, params ColumnInfo[] columns)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(NgramExtractingEstimator));
_columns = columns;
}
+ /// <summary>
+ /// Trains and returns a <see cref="NgramExtractingTransformer"/>.
+ /// </summary>
public NgramExtractingTransformer Fit(IDataView input) => new NgramExtractingTransformer(_host, input, _columns);
internal static bool IsColumnTypeValid(ColumnType type)
@@ -865,6 +793,85 @@ internal static bool IsSchemaColumnValid(SchemaShape.Column col)
internal const string ExpectedColumnType = "Expected vector of Key type, and Key is convertible to U4";
+ /// <summary>
+ /// Describes how the transformer handles one column pair.
+ /// </summary>
+ public sealed class ColumnInfo
+ {
+ public readonly string Name;
+ public readonly string InputColumnName;
+ public readonly int NgramLength;
+ public readonly int SkipLength;
+ public readonly bool AllLengths;
+ public readonly NgramExtractingEstimator.WeightingCriteria Weighting;
+ /// <summary>
+ /// Contains the maximum number of grams to store in the dictionary, for each level of ngrams,
+ /// from 1 (in position 0) up to ngramLength (in position ngramLength-1)
+ /// </summary>
+ public readonly ImmutableArray<int> Limits;
+
+ /// <summary>
+ /// Describes how the transformer handles one column pair.
+ /// </summary>
+ /// <param name="name">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
+ /// <param name="inputColumnName">Name of column to transform. If set to <see langword="null"/>, the value of <paramref name="name"/> will be used as source.</param>
+ /// <param name="ngramLength">Maximum ngram length.</param>
+ /// <param name="skipLength">Maximum number of tokens to skip when constructing an ngram.</param>
+ /// <param name="allLengths">Whether to store all ngram lengths up to ngramLength, or only ngramLength.</param>
+ /// <param name="weighting">The weighting criteria.</param>
+ /// <param name="maxNumTerms">Maximum number of ngrams to store in the dictionary.</param>
+ public ColumnInfo(string name, string inputColumnName = null,
+ int ngramLength = NgramExtractingEstimator.Defaults.NgramLength,
+ int skipLength = NgramExtractingEstimator.Defaults.SkipLength,
+ bool allLengths = NgramExtractingEstimator.Defaults.AllLengths,
+ NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.Defaults.Weighting,
+ int maxNumTerms = NgramExtractingEstimator.Defaults.MaxNumTerms)
+ : this(name, ngramLength, skipLength, allLengths, weighting, new int[] { maxNumTerms }, inputColumnName ?? name)
+ {
+ }
+
+ internal ColumnInfo(string name,
+ int ngramLength,
+ int skipLength,
+ bool allLengths,
+ NgramExtractingEstimator.WeightingCriteria weighting,
+ int[] maxNumTerms,
+ string inputColumnName = null)
+ {
+ Name = name;
+ InputColumnName = inputColumnName ?? name;
+ NgramLength = ngramLength;
+ Contracts.CheckUserArg(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength, nameof(ngramLength));
+ SkipLength = skipLength;
+ if (NgramLength + SkipLength > NgramBufferBuilder.MaxSkipNgramLength)
+ {
+ throw Contracts.ExceptUserArg(nameof(skipLength),
+ $"The sum of skipLength and ngramLength must be less than or equal to {NgramBufferBuilder.MaxSkipNgramLength}");
+ }
+ AllLengths = allLengths;
+ Weighting = weighting;
+ var limits = new int[ngramLength];
+ if (!AllLengths)
+ {
+ Contracts.CheckUserArg(Utils.Size(maxNumTerms) == 0 ||
+ Utils.Size(maxNumTerms) == 1 && maxNumTerms[0] > 0, nameof(maxNumTerms));
+ limits[ngramLength - 1] = Utils.Size(maxNumTerms) == 0 ? NgramExtractingEstimator.Defaults.MaxNumTerms : maxNumTerms[0];
+ }
+ else
+ {
+ Contracts.CheckUserArg(Utils.Size(maxNumTerms) <= ngramLength, nameof(maxNumTerms));
+ Contracts.CheckUserArg(Utils.Size(maxNumTerms) == 0 || maxNumTerms.All(i => i >= 0) && maxNumTerms[maxNumTerms.Length - 1] > 0, nameof(maxNumTerms));
+ var extend = Utils.Size(maxNumTerms) == 0 ? NgramExtractingEstimator.Defaults.MaxNumTerms : maxNumTerms[maxNumTerms.Length - 1];
+ limits = Utils.BuildArray(ngramLength, i => i < Utils.Size(maxNumTerms) ? maxNumTerms[i] : extend);
+ }
+ Limits = ImmutableArray.Create(limits);
+ }
+ }
+
+ /// <summary>
+ /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
+ /// Used for schema propagation and verification in a pipeline.
+ /// </summary>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
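
Reviewer note: with ColumnInfo relocated onto the estimator, pipeline code reaches it through the catalog (see the TextCatalog.cs hunks later in this diff). A hedged sketch of the intended call pattern; the named arguments follow the public ColumnInfo constructor shown above:

    var mlContext = new MLContext();
    var ngrams = mlContext.Transforms.Text.ProduceNgrams(
        new NgramExtractingEstimator.ColumnInfo("Ngrams", "Tokens",
            ngramLength: 2, skipLength: 0, allLengths: true));
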
diff --git a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs
index e4da959d74..84917a4cf6 100644
--- a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs
+++ b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs
@@ -20,7 +20,7 @@
using Microsoft.ML.Model;
using Microsoft.ML.Transforms.Text;
-[assembly: LoadableClass(StopWordsRemovingTransformer.Summary, typeof(IDataTransform), typeof(StopWordsRemovingTransformer), typeof(StopWordsRemovingTransformer.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(StopWordsRemovingTransformer.Summary, typeof(IDataTransform), typeof(StopWordsRemovingTransformer), typeof(StopWordsRemovingTransformer.Options), typeof(SignatureDataTransform),
"Stopwords Remover Transform", "StopWordsRemoverTransform", "StopWordsRemover", "StopWords")]
[assembly: LoadableClass(StopWordsRemovingTransformer.Summary, typeof(IDataTransform), typeof(StopWordsRemovingTransformer), null, typeof(SignatureLoadDataTransform),
@@ -32,7 +32,7 @@
[assembly: LoadableClass(typeof(IRowMapper), typeof(StopWordsRemovingTransformer), null, typeof(SignatureLoadRowMapper),
"Stopwords Remover Transform", StopWordsRemovingTransformer.LoaderSignature)]
-[assembly: LoadableClass(CustomStopWordsRemovingTransformer.Summary, typeof(IDataTransform), typeof(CustomStopWordsRemovingTransformer), typeof(CustomStopWordsRemovingTransformer.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(CustomStopWordsRemovingTransformer.Summary, typeof(IDataTransform), typeof(CustomStopWordsRemovingTransformer), typeof(CustomStopWordsRemovingTransformer.Options), typeof(SignatureDataTransform),
"Custom Stopwords Remover Transform", "CustomStopWordsRemoverTransform", "CustomStopWords")]
[assembly: LoadableClass(CustomStopWordsRemovingTransformer.Summary, typeof(IDataTransform), typeof(CustomStopWordsRemovingTransformer), null, typeof(SignatureLoadDataTransform),
@@ -58,7 +58,7 @@ internal sealed class PredefinedStopWordsRemoverFactory : IStopWordsRemoverFacto
{
public IDataTransform CreateComponent(IHostEnvironment env, IDataView input, OneToOneColumn[] columns)
{
- return new StopWordsRemovingEstimator(env, columns.Select(x => new StopWordsRemovingTransformer.ColumnInfo(x.Name, x.Source)).ToArray()).Fit(input).Transform(input) as IDataTransform;
+ return new StopWordsRemovingEstimator(env, columns.Select(x => new StopWordsRemovingEstimator.ColumnInfo(x.Name, x.Source)).ToArray()).Fit(input).Transform(input) as IDataTransform;
}
}
@@ -99,7 +99,7 @@ internal bool TryUnparse(StringBuilder sb)
}
}
- internal sealed class Arguments
+ internal sealed class Options
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s)", Name = "Column", ShortName = "col", SortOrder = 1)]
public Column[] Columns;
@@ -131,9 +131,9 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(StopWordsRemovingTransformer).Assembly.FullName);
}
- public IReadOnlyCollection<ColumnInfo> Columns => _columns.AsReadOnly();
+ public IReadOnlyCollection<StopWordsRemovingEstimator.ColumnInfo> Columns => _columns.AsReadOnly();
- private readonly ColumnInfo[] _columns;
+ private readonly StopWordsRemovingEstimator.ColumnInfo[] _columns;
private static volatile NormStr.Pool[] _stopWords;
private static volatile Dictionary<ReadOnlyMemory<char>, StopWordsRemovingEstimator.Language> _langsDictionary;
@@ -179,38 +179,7 @@ private static NormStr.Pool[] StopWords
}
}
- /// <summary>
- /// Describes how the transformer handles one column pair.
- /// </summary>
- public sealed class ColumnInfo
- {
- public readonly string Name;
- public readonly string InputColumnName;
- public readonly StopWordsRemovingEstimator.Language Language;
- public readonly string LanguageColumn;
-
- /// <summary>
- /// Describes how the transformer handles one column pair.
- /// </summary>
- /// <param name="name">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
- /// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="name"/> will be used as source.</param>
- /// <param name="language">Language-specific stop words list.</param>
- /// <param name="languageColumn">Optional column to use for languages. This overrides language value.</param>
- public ColumnInfo(string name,
- string inputColumnName = null,
- StopWordsRemovingEstimator.Language language = StopWordsRemovingEstimator.Defaults.DefaultLanguage,
- string languageColumn = null)
- {
- Contracts.CheckNonWhiteSpace(name, nameof(name));
-
- Name = name;
- InputColumnName = inputColumnName ?? name;
- Language = language;
- LanguageColumn = languageColumn;
- }
- }
-
- private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns)
+ private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(StopWordsRemovingEstimator.ColumnInfo[] columns)
{
Contracts.CheckValue(columns, nameof(columns));
return columns.Select(x => (x.Name, x.InputColumnName)).ToArray();
@@ -228,7 +197,7 @@ protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol
/// </summary>
/// <param name="env">The environment.</param>
/// <param name="columns">Pairs of columns to remove stop words from.</param>
- public StopWordsRemovingTransformer(IHostEnvironment env, params ColumnInfo[] columns) :
+ internal StopWordsRemovingTransformer(IHostEnvironment env, params StopWordsRemovingEstimator.ColumnInfo[] columns) :
base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns))
{
_columns = columns;
@@ -262,13 +231,13 @@ private StopWordsRemovingTransformer(IHost host, ModelLoadContext ctx) :
// foreach column:
// int: the stopwords list language
// string: the id of languages column name
- _columns = new ColumnInfo[columnsLength];
+ _columns = new StopWordsRemovingEstimator.ColumnInfo[columnsLength];
for (int i = 0; i < columnsLength; i++)
{
var lang = (StopWordsRemovingEstimator.Language)ctx.Reader.ReadInt32();
Contracts.CheckDecode(Enum.IsDefined(typeof(StopWordsRemovingEstimator.Language), lang));
var langColName = ctx.LoadStringOrNull();
- _columns[i] = new ColumnInfo(ColumnPairs[i].outputColumnName, ColumnPairs[i].inputColumnName, lang, langColName);
+ _columns[i] = new StopWordsRemovingEstimator.ColumnInfo(ColumnPairs[i].outputColumnName, ColumnPairs[i].inputColumnName, lang, langColName);
}
}
@@ -283,22 +252,22 @@ private static StopWordsRemovingTransformer Create(IHostEnvironment env, ModelLo
}
// Factory method for SignatureDataTransform.
- internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
+ internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
- env.CheckValue(args, nameof(args));
+ env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
- env.CheckValue(args.Columns, nameof(args.Columns));
- var cols = new ColumnInfo[args.Columns.Length];
+ env.CheckValue(options.Columns, nameof(options.Columns));
+ var cols = new StopWordsRemovingEstimator.ColumnInfo[options.Columns.Length];
for (int i = 0; i < cols.Length; i++)
{
- var item = args.Columns[i];
- cols[i] = new ColumnInfo(
+ var item = options.Columns[i];
+ cols[i] = new StopWordsRemovingEstimator.ColumnInfo(
item.Name,
item.Source ?? item.Name,
- item.Language ?? args.Language,
- item.LanguagesColumn ?? args.LanguagesColumn);
+ item.Language ?? options.Language,
+ item.LanguagesColumn ?? options.LanguagesColumn);
}
return new StopWordsRemovingTransformer(env, cols).MakeDataTransform(input);
}
@@ -519,6 +488,36 @@ private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> a
/// </summary>
public sealed class StopWordsRemovingEstimator : TrivialEstimator<StopWordsRemovingTransformer>
{
+ /// <summary>
+ /// Describes how the transformer handles one column pair.
+ /// </summary>
+ public sealed class ColumnInfo
+ {
+ public readonly string Name;
+ public readonly string InputColumnName;
+ public readonly StopWordsRemovingEstimator.Language Language;
+ public readonly string LanguageColumn;
+
+ /// <summary>
+ /// Describes how the transformer handles one column pair.
+ /// </summary>
+ /// <param name="name">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
+ /// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="name"/> will be used as source.</param>
+ /// <param name="language">Language-specific stop words list.</param>
+ /// <param name="languageColumn">Optional column to use for languages. This overrides language value.</param>
+ public ColumnInfo(string name,
+ string inputColumnName = null,
+ StopWordsRemovingEstimator.Language language = StopWordsRemovingEstimator.Defaults.DefaultLanguage,
+ string languageColumn = null)
+ {
+ Contracts.CheckNonWhiteSpace(name, nameof(name));
+
+ Name = name;
+ InputColumnName = inputColumnName ?? name;
+ Language = language;
+ LanguageColumn = languageColumn;
+ }
+ }
/// <summary>
/// Stopwords language. This enumeration is serialized.
/// </summary>
@@ -549,7 +548,7 @@ internal static class Defaults
public const Language DefaultLanguage = Language.English;
}
- public static bool IsColumnTypeValid(ColumnType type) =>
+ internal static bool IsColumnTypeValid(ColumnType type) =>
type is VectorType vectorType && vectorType.ItemType is TextType;
internal const string ExpectedColumnType = "vector of Text type";
@@ -562,7 +561,7 @@ public static bool IsColumnTypeValid(ColumnType type) =>
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
/// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
/// <param name="language">Language of the input text column <paramref name="inputColumnName"/>.</param>
- public StopWordsRemovingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, Language language = Language.English)
+ internal StopWordsRemovingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, Language language = Language.English)
: this(env, new[] { (outputColumnName, inputColumnName ?? outputColumnName) }, language)
{
}
@@ -574,16 +573,20 @@ public StopWordsRemovingEstimator(IHostEnvironment env, string outputColumnName,
/// <param name="env">The environment.</param>
/// <param name="columns">Pairs of columns to remove stop words on.</param>
/// <param name="language">Language of the input text columns <paramref name="columns"/>.</param>
- public StopWordsRemovingEstimator(IHostEnvironment env, (string outputColumnName, string inputColumnName)[] columns, Language language = Language.English)
- : this(env, columns.Select(x => new StopWordsRemovingTransformer.ColumnInfo(x.outputColumnName, x.inputColumnName, language)).ToArray())
+ internal StopWordsRemovingEstimator(IHostEnvironment env, (string outputColumnName, string inputColumnName)[] columns, Language language = Language.English)
+ : this(env, columns.Select(x => new ColumnInfo(x.outputColumnName, x.inputColumnName, language)).ToArray())
{
}
- public StopWordsRemovingEstimator(IHostEnvironment env, params StopWordsRemovingTransformer.ColumnInfo[] columns)
+ internal StopWordsRemovingEstimator(IHostEnvironment env, params ColumnInfo[] columns)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(StopWordsRemovingEstimator)), new StopWordsRemovingTransformer(env, columns))
{
}
+ /// <summary>
+ /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
+ /// Used for schema propagation and verification in a pipeline.
+ /// </summary>
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
@@ -642,7 +645,7 @@ internal abstract class ArgumentsBase
public string StopwordsColumn;
}
- internal sealed class Arguments : ArgumentsBase
+ internal sealed class Options : ArgumentsBase
{
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s)", Name = "Column", ShortName = "col", SortOrder = 1)]
public Column[] Columns;
@@ -713,7 +716,7 @@ private IDataLoader GetLoaderForStopwords(IChannel ch, string dataFile,
if (isBinary || isTranspose)
{
ch.Assert(isBinary != isTranspose);
- ch.CheckUserArg(!string.IsNullOrWhiteSpace(stopwordsCol), nameof(Arguments.StopwordsColumn),
+ ch.CheckUserArg(!string.IsNullOrWhiteSpace(stopwordsCol), nameof(Options.StopwordsColumn),
"stopwordsColumn should be specified");
if (isBinary)
dataLoader = new BinaryLoader(Host, new BinaryLoader.Arguments(), fileSource);
@@ -772,7 +775,7 @@ private void LoadStopWords(IChannel ch, ReadOnlyMemory<char> stopwords, string d
warnEmpty = false;
}
}
- ch.CheckUserArg(stopWordsMap.Count > 0, nameof(Arguments.Stopword), "stopwords is empty");
+ ch.CheckUserArg(stopWordsMap.Count > 0, nameof(Options.Stopword), "stopwords is empty");
}
else
{
@@ -780,9 +783,9 @@ private void LoadStopWords(IChannel ch, ReadOnlyMemory<char> stopwords, string d
var loader = GetLoaderForStopwords(ch, dataFile, loaderFactory, ref srcCol);
if (!loader.Schema.TryGetColumnIndex(srcCol, out int colSrcIndex))
- throw ch.ExceptUserArg(nameof(Arguments.StopwordsColumn), "Unknown column '{0}'", srcCol);
+ throw ch.ExceptUserArg(nameof(Options.StopwordsColumn), "Unknown column '{0}'", srcCol);
var typeSrc = loader.Schema[colSrcIndex].Type;
- ch.CheckUserArg(typeSrc is TextType, nameof(Arguments.StopwordsColumn), "Must be a scalar text column");
+ ch.CheckUserArg(typeSrc is TextType, nameof(Options.StopwordsColumn), "Must be a scalar text column");
// Accumulate the stopwords.
using (var cursor = loader.GetRowCursor(loader.Schema[srcCol]))
@@ -805,7 +808,7 @@ private void LoadStopWords(IChannel ch, ReadOnlyMemory<char> stopwords, string d
}
}
}
- ch.CheckUserArg(stopWordsMap.Count > 0, nameof(Arguments.DataFile), "dataFile is empty");
+ ch.CheckUserArg(stopWordsMap.Count > 0, nameof(Options.DataFile), "dataFile is empty");
}
}
@@ -817,7 +820,7 @@ private void LoadStopWords(IChannel ch, ReadOnlyMemory<char> stopwords, string d
/// <param name="env">The environment.</param>
/// <param name="stopwords">Array of words to remove.</param>
/// <param name="columns">Pairs of columns to remove stop words from.</param>
- public CustomStopWordsRemovingTransformer(IHostEnvironment env, string[] stopwords, params (string outputColumnName, string inputColumnName)[] columns) :
+ internal CustomStopWordsRemovingTransformer(IHostEnvironment env, string[] stopwords, params (string outputColumnName, string inputColumnName)[] columns) :
base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), columns)
{
_stopWordsMap = new NormStr.Pool();
@@ -938,24 +941,24 @@ private static CustomStopWordsRemovingTransformer Create(IHostEnvironment env, M
}
// Factory method for SignatureDataTransform.
- internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
+ internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
- env.CheckValue(args, nameof(args));
+ env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
- env.CheckValue(args.Columns, nameof(args.Columns));
- var cols = new (string outputColumnName, string inputColumnName)[args.Columns.Length];
+ env.CheckValue(options.Columns, nameof(options.Columns));
+ var cols = new (string outputColumnName, string inputColumnName)[options.Columns.Length];
for (int i = 0; i < cols.Length; i++)
{
- var item = args.Columns[i];
+ var item = options.Columns[i];
cols[i] = (item.Name, item.Source ?? item.Name);
}
CustomStopWordsRemovingTransformer transform = null;
- if (Utils.Size(args.Stopwords) > 0)
- transform = new CustomStopWordsRemovingTransformer(env, args.Stopwords, cols);
+ if (Utils.Size(options.Stopwords) > 0)
+ transform = new CustomStopWordsRemovingTransformer(env, options.Stopwords, cols);
else
- transform = new CustomStopWordsRemovingTransformer(env, args.Stopword, args.DataFile, args.StopwordsColumn, args.Loader, cols);
+ transform = new CustomStopWordsRemovingTransformer(env, options.Stopword, options.DataFile, options.StopwordsColumn, options.Loader, cols);
return transform.MakeDataTransform(input);
}
@@ -1057,7 +1060,7 @@ public sealed class CustomStopWordsRemovingEstimator : TrivialEstimator<CustomStopWordsRemovingTransformer>
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
/// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
/// <param name="stopwords">Array of words to remove.</param>
- public CustomStopWordsRemovingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, params string[] stopwords)
+ internal CustomStopWordsRemovingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, params string[] stopwords)
: this(env, new[] { (outputColumnName, inputColumnName ?? outputColumnName) }, stopwords)
{
}
@@ -1069,11 +1072,15 @@ public CustomStopWordsRemovingEstimator(IHostEnvironment env, string outputColum
/// <param name="env">The environment.</param>
/// <param name="columns">Pairs of columns to remove stop words on.</param>
/// <param name="stopwords">Array of words to remove.</param>
- public CustomStopWordsRemovingEstimator(IHostEnvironment env, (string outputColumnName, string inputColumnName)[] columns, string[] stopwords) :
+ internal CustomStopWordsRemovingEstimator(IHostEnvironment env, (string outputColumnName, string inputColumnName)[] columns, string[] stopwords) :
base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CustomStopWordsRemovingEstimator)), new CustomStopWordsRemovingTransformer(env, stopwords, columns))
{
}
+ /// <summary>
+ /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
+ /// Used for schema propagation and verification in a pipeline.
+ /// </summary>
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
@@ -1089,4 +1096,4 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
return new SchemaShape(result.Values);
}
}
-}
\ No newline at end of file
+}
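
Reviewer note: the stop-words estimators lose their public constructors here, so the catalog becomes the entry point. A minimal sketch, reusing mlContext from the earlier note and assuming the TextCatalog extension is named RemoveStopWords at this revision:

    // RemoveStopWords is an assumed catalog name; the Language enum is the
    // one declared in this file.
    var noStopWords = mlContext.Transforms.Text.RemoveStopWords(
        "CleanTokens", "Tokens",
        StopWordsRemovingEstimator.Language.English);
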
diff --git a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs
index 176ece41ba..57f56bfa64 100644
--- a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs
+++ b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs
@@ -100,7 +100,7 @@ public static TextNormalizingEstimator NormalizeText(this TransformsCatalog.Text
/// <param name="catalog">The text-related transform's catalog.</param>
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
/// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
- /// <param name="modelKind">The embeddings <see cref="WordEmbeddingsExtractingTransformer.PretrainedModelKind"/> to use.</param>
+ /// <param name="modelKind">The embeddings <see cref="WordEmbeddingsExtractingEstimator.PretrainedModelKind"/> to use.</param>
///
///
/// new WordEmbeddingsExtractingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), outputColumnName, inputColumnName, modelKind);
///
@@ -135,7 +135,7 @@ public static WordEmbeddingsExtractingEstimator ExtractWordEmbeddings(this Trans
/// </summary>
/// <param name="catalog">The text-related transform's catalog.</param>
- /// <param name="modelKind">The embeddings <see cref="WordEmbeddingsExtractingTransformer.PretrainedModelKind"/> to use.</param>
+ /// <param name="modelKind">The embeddings <see cref="WordEmbeddingsExtractingEstimator.PretrainedModelKind"/> to use.</param>
/// <param name="columns">The array columns, and per-column configurations to extract embeddings from.</param>
///
///
@@ -145,8 +145,8 @@ public static WordEmbeddingsExtractingEstimator ExtractWordEmbeddings(this Trans
///
///
public static WordEmbeddingsExtractingEstimator ExtractWordEmbeddings(this TransformsCatalog.TextTransforms catalog,
- WordEmbeddingsExtractingTransformer.PretrainedModelKind modelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe,
- params WordEmbeddingsExtractingTransformer.ColumnInfo[] columns)
+ WordEmbeddingsExtractingEstimator.PretrainedModelKind modelKind = WordEmbeddingsExtractingEstimator.PretrainedModelKind.Sswe,
+ params WordEmbeddingsExtractingEstimator.ColumnInfo[] columns)
=> new WordEmbeddingsExtractingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), modelKind, columns);
///
@@ -180,7 +180,7 @@ public static WordTokenizingEstimator TokenizeWords(this TransformsCatalog.TextT
/// <param name="catalog">The text-related transform's catalog.</param>
/// <param name="columns">Pairs of columns to run the tokenization on.</param>
public static WordTokenizingEstimator TokenizeWords(this TransformsCatalog.TextTransforms catalog,
- params WordTokenizingTransformer.ColumnInfo[] columns)
+ params WordTokenizingEstimator.ColumnInfo[] columns)
=> new WordTokenizingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), columns);
///
@@ -241,7 +241,7 @@ public static NgramExtractingEstimator ProduceNgrams(this TransformsCatalog.Text
/// <param name="catalog">The text-related transform's catalog.</param>
/// <param name="columns">Pairs of columns to run the ngram process on.</param>
public static NgramExtractingEstimator ProduceNgrams(this TransformsCatalog.TextTransforms catalog,
- params NgramExtractingTransformer.ColumnInfo[] columns)
+ params NgramExtractingEstimator.ColumnInfo[] columns)
=> new NgramExtractingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), columns);
///
@@ -622,7 +622,7 @@ public static LatentDirichletAllocationEstimator LatentDirichletAllocation(this
/// <param name="columns">Describes the parameters of LDA for each column pair.</param>
public static LatentDirichletAllocationEstimator LatentDirichletAllocation(
this TransformsCatalog.TextTransforms catalog,
- params LatentDirichletAllocationTransformer.ColumnInfo[] columns)
+ params LatentDirichletAllocationEstimator.ColumnInfo[] columns)
=> new LatentDirichletAllocationEstimator(CatalogUtils.GetEnvironment(catalog), columns);
}
}
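
Reviewer note: every catalog overload above now takes the estimator-nested ColumnInfo types. For example, per-column word tokenization; the two-argument ColumnInfo constructor is assumed from the (name, inputColumnName) pattern used throughout this diff:

    var tokenized = mlContext.Transforms.Text.TokenizeWords(
        new WordTokenizingEstimator.ColumnInfo("Tokens", "Text"));
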
diff --git a/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs b/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs
index 25931c5787..257e77c67a 100644
--- a/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs
+++ b/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs
@@ -328,11 +328,11 @@ public ITransformer Fit(IDataView input)
if (tparams.NeedsWordTokenizationTransform)
{
- var xfCols = new WordTokenizingTransformer.ColumnInfo[textCols.Length];
+ var xfCols = new WordTokenizingEstimator.ColumnInfo[textCols.Length];
wordTokCols = new string[textCols.Length];
for (int i = 0; i < textCols.Length; i++)
{
- var col = new WordTokenizingTransformer.ColumnInfo(GenerateColumnName(view.Schema, textCols[i], "WordTokenizer"), textCols[i]);
+ var col = new WordTokenizingEstimator.ColumnInfo(GenerateColumnName(view.Schema, textCols[i], "WordTokenizer"), textCols[i]);
xfCols[i] = col;
wordTokCols[i] = col.Name;
tempCols.Add(col.Name);
@@ -344,12 +344,12 @@ public ITransformer Fit(IDataView input)
if (tparams.UsePredefinedStopWordRemover)
{
Contracts.Assert(wordTokCols != null, "StopWords transform requires that word tokenization has been applied to the input text.");
- var xfCols = new StopWordsRemovingTransformer.ColumnInfo[wordTokCols.Length];
+ var xfCols = new StopWordsRemovingEstimator.ColumnInfo[wordTokCols.Length];
var dstCols = new string[wordTokCols.Length];
for (int i = 0; i < wordTokCols.Length; i++)
{
var tempName = GenerateColumnName(view.Schema, wordTokCols[i], "StopWordsRemoverTransform");
- var col = new StopWordsRemovingTransformer.ColumnInfo(tempName, wordTokCols[i], tparams.StopwordsLanguage);
+ var col = new StopWordsRemovingEstimator.ColumnInfo(tempName, wordTokCols[i], tparams.StopwordsLanguage);
dstCols[i] = tempName;
tempCols.Add(tempName);
diff --git a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs
index 4e1df85671..75bab0f8b4 100644
--- a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs
+++ b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs
@@ -16,7 +16,7 @@
using Microsoft.ML.Model;
using Microsoft.ML.Transforms.Text;
-[assembly: LoadableClass(TextNormalizingTransformer.Summary, typeof(IDataTransform), typeof(TextNormalizingTransformer), typeof(TextNormalizingTransformer.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(TextNormalizingTransformer.Summary, typeof(IDataTransform), typeof(TextNormalizingTransformer), typeof(TextNormalizingTransformer.Options), typeof(SignatureDataTransform),
"Text Normalizer Transform", "TextNormalizerTransform", "TextNormalizer", "TextNorm")]
[assembly: LoadableClass(TextNormalizingTransformer.Summary, typeof(IDataTransform), typeof(TextNormalizingTransformer), null, typeof(SignatureLoadDataTransform),
@@ -36,7 +36,7 @@ namespace Microsoft.ML.Transforms.Text
///
public sealed class TextNormalizingTransformer : OneToOneTransformerBase
{
- public sealed class Column : OneToOneColumn
+ internal sealed class Column : OneToOneColumn
{
internal static Column Parse(string str)
{
@@ -53,7 +53,7 @@ internal bool TryUnparse(StringBuilder sb)
}
}
- public sealed class Arguments
+ internal sealed class Options
{
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s)", Name = "Column", ShortName = "col", SortOrder = 1)]
public Column[] Columns;
@@ -96,7 +96,7 @@ private static VersionInfo GetVersionInfo()
private readonly bool _keepPunctuations;
private readonly bool _keepNumbers;
- public TextNormalizingTransformer(IHostEnvironment env,
+ internal TextNormalizingTransformer(IHostEnvironment env,
TextNormalizingEstimator.CaseNormalizationMode textCase = TextNormalizingEstimator.Defaults.TextCase,
bool keepDiacritics = TextNormalizingEstimator.Defaults.KeepDiacritics,
bool keepPunctuations = TextNormalizingEstimator.Defaults.KeepPunctuations,
@@ -167,20 +167,20 @@ private TextNormalizingTransformer(IHost host, ModelLoadContext ctx)
}
// Factory method for SignatureDataTransform.
- private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
+ private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
- env.CheckValue(args, nameof(args));
+ env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
- env.CheckValue(args.Columns, nameof(args.Columns));
- var cols = new (string outputColumnName, string inputColumnName)[args.Columns.Length];
+ env.CheckValue(options.Columns, nameof(options.Columns));
+ var cols = new (string outputColumnName, string inputColumnName)[options.Columns.Length];
for (int i = 0; i < cols.Length; i++)
{
- var item = args.Columns[i];
+ var item = options.Columns[i];
cols[i] = (item.Name, item.Source ?? item.Name);
}
- return new TextNormalizingTransformer(env, args.TextCase, args.KeepDiacritics, args.KeepPunctuations, args.KeepNumbers, cols).MakeDataTransform(input);
+ return new TextNormalizingTransformer(env, options.TextCase, options.KeepDiacritics, options.KeepPunctuations, options.KeepNumbers, cols).MakeDataTransform(input);
}
// Factory method for SignatureLoadDataTransform.
@@ -450,7 +450,7 @@ internal static class Defaults
}
- public static bool IsColumnTypeValid(ColumnType type) => (type.GetItemType() is TextType);
+ internal static bool IsColumnTypeValid(ColumnType type) => (type.GetItemType() is TextType);
internal const string ExpectedColumnType = "Text or vector of text.";
@@ -465,7 +465,7 @@ internal static class Defaults
/// <param name="keepDiacritics">Whether to keep diacritical marks or remove them.</param>
/// <param name="keepPunctuations">Whether to keep punctuation marks or remove them.</param>
/// <param name="keepNumbers">Whether to keep numbers or remove them.</param>
- public TextNormalizingEstimator(IHostEnvironment env,
+ internal TextNormalizingEstimator(IHostEnvironment env,
string outputColumnName,
string inputColumnName = null,
CaseNormalizationMode textCase = Defaults.TextCase,
@@ -486,7 +486,7 @@ public TextNormalizingEstimator(IHostEnvironment env,
/// <param name="keepPunctuations">Whether to keep punctuation marks or remove them.</param>
/// <param name="keepNumbers">Whether to keep numbers or remove them.</param>
/// <param name="columns">Pairs of columns to run the text normalization on.</param>
- public TextNormalizingEstimator(IHostEnvironment env,
+ internal TextNormalizingEstimator(IHostEnvironment env,
CaseNormalizationMode textCase = Defaults.TextCase,
bool keepDiacritics = Defaults.KeepDiacritics,
bool keepPunctuations = Defaults.KeepPunctuations,
@@ -497,6 +497,10 @@ public TextNormalizingEstimator(IHostEnvironment env,
{
}
+ /// <summary>
+ /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
+ /// Used for schema propagation and verification in a pipeline.
+ /// </summary>
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
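
Reviewer note: with the TextNormalizingEstimator constructors internal, normalization is configured through TextCatalog.NormalizeText (seen earlier in this diff). A hedged sketch; the parameter order is assumed to mirror the estimator constructor above:

    var normalized = mlContext.Transforms.Text.NormalizeText(
        "NormalizedText", "Text",
        TextNormalizingEstimator.CaseNormalizationMode.Lower,
        keepDiacritics: false, keepPunctuations: true, keepNumbers: true);
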
diff --git a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs
index 9903c6e07c..8ef8109c52 100644
--- a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs
+++ b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs
@@ -17,7 +17,7 @@
using Microsoft.ML.Model;
using Microsoft.ML.Transforms.Text;
-[assembly: LoadableClass(TokenizingByCharactersTransformer.Summary, typeof(IDataTransform), typeof(TokenizingByCharactersTransformer), typeof(TokenizingByCharactersTransformer.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(TokenizingByCharactersTransformer.Summary, typeof(IDataTransform), typeof(TokenizingByCharactersTransformer), typeof(TokenizingByCharactersTransformer.Options), typeof(SignatureDataTransform),
TokenizingByCharactersTransformer.UserName, "CharTokenize", TokenizingByCharactersTransformer.LoaderSignature)]
[assembly: LoadableClass(typeof(IDataTransform), typeof(TokenizingByCharactersTransformer), null, typeof(SignatureLoadDataTransform),
@@ -36,7 +36,7 @@ namespace Microsoft.ML.Transforms.Text
///
public sealed class TokenizingByCharactersTransformer : OneToOneTransformerBase
{
- public sealed class Column : OneToOneColumn
+ internal sealed class Column : OneToOneColumn
{
internal static Column Parse(string str)
{
@@ -53,7 +53,7 @@ internal bool TryUnparse(StringBuilder sb)
}
}
- public sealed class Arguments : TransformInputBase
+ internal sealed class Options : TransformInputBase
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 1)]
public Column[] Columns;
@@ -106,7 +106,7 @@ private static VersionInfo GetVersionInfo()
/// <param name="env">The environment.</param>
/// <param name="useMarkerCharacters">Whether to use marker characters to separate words.</param>
/// <param name="columns">Pairs of columns to run the tokenization on.</param>
- public TokenizingByCharactersTransformer(IHostEnvironment env, bool useMarkerCharacters = TokenizingByCharactersEstimator.Defaults.UseMarkerCharacters,
+ internal TokenizingByCharactersTransformer(IHostEnvironment env, bool useMarkerCharacters = TokenizingByCharactersEstimator.Defaults.UseMarkerCharacters,
params (string outputColumnName, string inputColumnName)[] columns) :
base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), columns)
{
@@ -161,20 +161,20 @@ private static TokenizingByCharactersTransformer Create(IHostEnvironment env, Mo
}
// Factory method for SignatureDataTransform.
- internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
+ internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
- env.CheckValue(args, nameof(args));
+ env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
- env.CheckValue(args.Columns, nameof(args.Columns));
- var cols = new (string outputColumnName, string inputColumnName)[args.Columns.Length];
+ env.CheckValue(options.Columns, nameof(options.Columns));
+ var cols = new (string outputColumnName, string inputColumnName)[options.Columns.Length];
for (int i = 0; i < cols.Length; i++)
{
- var item = args.Columns[i];
+ var item = options.Columns[i];
cols[i] = (item.Name, item.Source ?? item.Name);
}
- return new TokenizingByCharactersTransformer(env, args.UseMarkerChars, cols).MakeDataTransform(input);
+ return new TokenizingByCharactersTransformer(env, options.UseMarkerChars, cols).MakeDataTransform(input);
}
// Factory method for SignatureLoadRowMapper.
@@ -554,7 +554,7 @@ internal static class Defaults
{
public const bool UseMarkerCharacters = true;
}
- public static bool IsColumnTypeValid(ColumnType type) => type.GetItemType() is TextType;
+ internal static bool IsColumnTypeValid(ColumnType type) => type.GetItemType() is TextType;
internal const string ExpectedColumnType = "Text";
@@ -565,7 +565,7 @@ internal static class Defaults
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
/// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
/// <param name="useMarkerCharacters">Whether to use marker characters to separate words.</param>
- public TokenizingByCharactersEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null,
+ internal TokenizingByCharactersEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null,
bool useMarkerCharacters = Defaults.UseMarkerCharacters)
: this(env, useMarkerCharacters, new[] { (outputColumnName, inputColumnName ?? outputColumnName) })
{
@@ -578,12 +578,16 @@ public TokenizingByCharactersEstimator(IHostEnvironment env, string outputColumn
/// <param name="useMarkerCharacters">Whether to use marker characters to separate words.</param>
/// <param name="columns">Pairs of columns to run the tokenization on.</param>
- public TokenizingByCharactersEstimator(IHostEnvironment env, bool useMarkerCharacters = Defaults.UseMarkerCharacters,
+ internal TokenizingByCharactersEstimator(IHostEnvironment env, bool useMarkerCharacters = Defaults.UseMarkerCharacters,
params (string outputColumnName, string inputColumnName)[] columns)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TokenizingByCharactersEstimator)), new TokenizingByCharactersTransformer(env, useMarkerCharacters, columns))
{
}
+ /// <summary>
+ /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
+ /// Used for schema propagation and verification in a pipeline.
+ /// </summary>
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
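
Reviewer note: character tokenization follows the same pattern; TokenizeCharacters is the assumed catalog counterpart of the now-internal TokenizingByCharactersEstimator constructors:

    var charTokens = mlContext.Transforms.Text.TokenizeCharacters(
        "CharTokens", "Text", useMarkerCharacters: true);
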
diff --git a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
index 0fe45acc26..2af913b1c1 100644
--- a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
@@ -14,7 +14,7 @@
using Microsoft.ML.Transforms.Conversions;
using Microsoft.ML.Transforms.Text;
-[assembly: LoadableClass(WordBagBuildingTransformer.Summary, typeof(IDataTransform), typeof(WordBagBuildingTransformer), typeof(WordBagBuildingTransformer.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(WordBagBuildingTransformer.Summary, typeof(IDataTransform), typeof(WordBagBuildingTransformer), typeof(WordBagBuildingTransformer.Options), typeof(SignatureDataTransform),
"Word Bag Transform", "WordBagTransform", "WordBag")]
[assembly: LoadableClass(NgramExtractorTransform.Summary, typeof(INgramExtractorFactory), typeof(NgramExtractorTransform), typeof(NgramExtractorTransform.NgramExtractorArguments),
@@ -42,7 +42,7 @@ internal sealed class ExtractorColumn : ManyToOneColumn
internal static class WordBagBuildingTransformer
{
- public sealed class Column : ManyToOneColumn
+ internal sealed class Column : ManyToOneColumn
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Ngram length", ShortName = "ngram")]
public int? NgramLength;
@@ -85,16 +85,7 @@ internal bool TryUnparse(StringBuilder sb)
}
}
- /// <summary>
- /// A vanilla implementation of OneToOneColumn that is used to represent the input of any tokenize
- /// transform (a transform that implements ITokenizeTransform interface).
- /// Note: Since WordBagTransform is a many-to-one column transform, for each WordBagTransform.Column
- /// with multiple sources, ConcatTransform is applied first. The output of ConcatTransform is a
- /// one-to-one column which is in turn the input to a tokenize transform.
- /// </summary>
- public sealed class TokenizeColumn : OneToOneColumn { }
-
- public sealed class Arguments : NgramExtractorTransform.ArgumentsBase
+ internal sealed class Options : NgramExtractorTransform.ArgumentsBase
{
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:srcs)", Name = "Column", ShortName = "col", SortOrder = 1)]
public Column[] Columns;
@@ -105,13 +96,13 @@ public sealed class Arguments : NgramExtractorTransform.ArgumentsBase
internal const string Summary = "Produces a bag of counts of ngrams (sequences of consecutive words of length 1-n) in a given text. It does so by building "
+ "a dictionary of ngrams and using the id in the dictionary as the index in the bag.";
- internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
+ internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(RegistrationName);
- h.CheckValue(args, nameof(args));
+ h.CheckValue(options, nameof(options));
h.CheckValue(input, nameof(input));
- h.CheckUserArg(Utils.Size(args.Columns) > 0, nameof(args.Columns), "Columns must be specified");
+ h.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified");
// Compose the WordBagTransform from a tokenize transform,
// followed by a NgramExtractionTransform.
@@ -124,27 +115,27 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat
// REVIEW: In order to make it possible to output separate bags for different columns
// using the same dictionary, we need to find a way to make ConcatTransform remember the boundaries.
- var tokenizeColumns = new WordTokenizingTransformer.ColumnInfo[args.Columns.Length];
+ var tokenizeColumns = new WordTokenizingEstimator.ColumnInfo[options.Columns.Length];
var extractorArgs =
- new NgramExtractorTransform.Arguments()
+ new NgramExtractorTransform.Options()
{
- MaxNumTerms = args.MaxNumTerms,
- NgramLength = args.NgramLength,
- SkipLength = args.SkipLength,
- AllLengths = args.AllLengths,
- Weighting = args.Weighting,
- Columns = new NgramExtractorTransform.Column[args.Columns.Length]
+ MaxNumTerms = options.MaxNumTerms,
+ NgramLength = options.NgramLength,
+ SkipLength = options.SkipLength,
+ AllLengths = options.AllLengths,
+ Weighting = options.Weighting,
+ Columns = new NgramExtractorTransform.Column[options.Columns.Length]
};
- for (int iinfo = 0; iinfo < args.Columns.Length; iinfo++)
+ for (int iinfo = 0; iinfo < options.Columns.Length; iinfo++)
{
- var column = args.Columns[iinfo];
+ var column = options.Columns[iinfo];
h.CheckUserArg(!string.IsNullOrWhiteSpace(column.Name), nameof(column.Name));
h.CheckUserArg(Utils.Size(column.Source) > 0, nameof(column.Source));
h.CheckUserArg(column.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(column.Source));
- tokenizeColumns[iinfo] = new WordTokenizingTransformer.ColumnInfo(column.Name, column.Source.Length > 1 ? column.Name : column.Source[0]);
+ tokenizeColumns[iinfo] = new WordTokenizingEstimator.ColumnInfo(column.Name, column.Source.Length > 1 ? column.Name : column.Source[0]);
extractorArgs.Columns[iinfo] =
new NgramExtractorTransform.Column()
@@ -160,7 +151,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat
}
IDataView view = input;
- view = NgramExtractionUtils.ApplyConcatOnSources(h, args.Columns, view);
+ view = NgramExtractionUtils.ApplyConcatOnSources(h, options.Columns, view);
view = new WordTokenizingEstimator(env, tokenizeColumns).Fit(view).Transform(view);
return NgramExtractorTransform.Create(h, extractorArgs, view);
}
@@ -173,7 +164,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat
///
internal static class NgramExtractorTransform
{
- public sealed class Column : OneToOneColumn
+ internal sealed class Column : OneToOneColumn
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Ngram length (stores all lengths up to the specified Ngram length)", ShortName = "ngram")]
public int? NgramLength;
@@ -220,9 +211,9 @@ internal bool TryUnparse(StringBuilder sb)
/// <summary>
/// This class is a merger of <see cref="ValueToKeyMappingTransformer.Options"/> and
- /// <see cref="NgramExtractingTransformer.Arguments"/>, with the allLength option removed.
+ /// <see cref="NgramExtractingTransformer.Options"/>, with the allLength option removed.
///
- public abstract class ArgumentsBase
+ internal abstract class ArgumentsBase
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Ngram length", ShortName = "ngram")]
public int NgramLength = 1;
@@ -254,7 +245,7 @@ public INgramExtractorFactory CreateComponent(IHostEnvironment env, TermLoaderAr
}
}
- public sealed class Arguments : ArgumentsBase
+ internal sealed class Options : ArgumentsBase
{
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 1)]
public Column[] Columns;
@@ -265,22 +256,22 @@ public sealed class Arguments : ArgumentsBase
internal const string LoaderSignature = "NgramExtractor";
- public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input,
+ internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input,
TermLoaderArguments termLoaderArgs = null)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(LoaderSignature);
- h.CheckValue(args, nameof(args));
+ h.CheckValue(options, nameof(options));
h.CheckValue(input, nameof(input));
- h.CheckUserArg(Utils.Size(args.Columns) > 0, nameof(args.Columns), "Columns must be specified");
+ h.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified");
IDataView view = input;
var termCols = new List<Column>();
- var isTermCol = new bool[args.Columns.Length];
+ var isTermCol = new bool[options.Columns.Length];
- for (int i = 0; i < args.Columns.Length; i++)
+ for (int i = 0; i < options.Columns.Length; i++)
{
- var col = args.Columns[i];
+ var col = options.Columns[i];
h.CheckNonWhiteSpace(col.Name, nameof(col.Name));
h.CheckNonWhiteSpace(col.Source, nameof(col.Source));
@@ -324,7 +315,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
termArgs =
new ValueToKeyMappingTransformer.Options()
{
- MaxNumTerms = Utils.Size(args.MaxNumTerms) > 0 ? args.MaxNumTerms[0] : NgramExtractingEstimator.Defaults.MaxNumTerms,
+ MaxNumTerms = Utils.Size(options.MaxNumTerms) > 0 ? options.MaxNumTerms[0] : NgramExtractingEstimator.Defaults.MaxNumTerms,
Columns = new ValueToKeyMappingTransformer.Column[termCols.Count]
};
}
@@ -349,16 +340,16 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
view = new MissingValueDroppingTransformer(h, missingDropColumns.Select(x => (x, x)).ToArray()).Transform(view);
}
- var ngramColumns = new NgramExtractingTransformer.ColumnInfo[args.Columns.Length];
- for (int iinfo = 0; iinfo < args.Columns.Length; iinfo++)
+ var ngramColumns = new NgramExtractingEstimator.ColumnInfo[options.Columns.Length];
+ for (int iinfo = 0; iinfo < options.Columns.Length; iinfo++)
{
- var column = args.Columns[iinfo];
- ngramColumns[iinfo] = new NgramExtractingTransformer.ColumnInfo(column.Name,
- column.NgramLength ?? args.NgramLength,
- column.SkipLength ?? args.SkipLength,
- column.AllLengths ?? args.AllLengths,
- column.Weighting ?? args.Weighting,
- column.MaxNumTerms ?? args.MaxNumTerms,
+ var column = options.Columns[iinfo];
+ ngramColumns[iinfo] = new NgramExtractingEstimator.ColumnInfo(column.Name,
+ column.NgramLength ?? options.NgramLength,
+ column.SkipLength ?? options.SkipLength,
+ column.AllLengths ?? options.AllLengths,
+ column.Weighting ?? options.Weighting,
+ column.MaxNumTerms ?? options.MaxNumTerms,
isTermCol[iinfo] ? column.Name : column.Source
);
}
@@ -366,7 +357,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
return new NgramExtractingEstimator(env, ngramColumns).Fit(view).Transform(view) as IDataTransform;
}
- public static IDataTransform Create(IHostEnvironment env, NgramExtractorArguments extractorArgs, IDataView input,
+ internal static IDataTransform Create(IHostEnvironment env, NgramExtractorArguments extractorArgs, IDataView input,
ExtractorColumn[] cols, TermLoaderArguments termLoaderArgs = null)
{
Contracts.CheckValue(env, nameof(env));
@@ -374,7 +365,7 @@ public static IDataTransform Create(IHostEnvironment env, NgramExtractorArgument
h.CheckValue(extractorArgs, nameof(extractorArgs));
h.CheckValue(input, nameof(input));
h.CheckUserArg(extractorArgs.SkipLength < extractorArgs.NgramLength, nameof(extractorArgs.SkipLength), "Should be less than " + nameof(extractorArgs.NgramLength));
- h.CheckUserArg(Utils.Size(cols) > 0, nameof(Arguments.Columns), "Must be specified");
+ h.CheckUserArg(Utils.Size(cols) > 0, nameof(Options.Columns), "Must be specified");
h.CheckValueOrNull(termLoaderArgs);
var extractorCols = new Column[cols.Length];
@@ -384,7 +375,7 @@ public static IDataTransform Create(IHostEnvironment env, NgramExtractorArgument
extractorCols[i] = new Column { Name = cols[i].Name, Source = cols[i].Source[0] };
}
- var args = new Arguments
+ var options = new Options
{
Columns = extractorCols,
NgramLength = extractorArgs.NgramLength,
@@ -394,10 +385,10 @@ public static IDataTransform Create(IHostEnvironment env, NgramExtractorArgument
Weighting = extractorArgs.Weighting
};
- return Create(h, args, input, termLoaderArgs);
+ return Create(h, options, input, termLoaderArgs);
}
- public static INgramExtractorFactory Create(IHostEnvironment env, NgramExtractorArguments extractorArgs,
+ internal static INgramExtractorFactory Create(IHostEnvironment env, NgramExtractorArguments extractorArgs,
TermLoaderArguments termLoaderArgs)
{
Contracts.CheckValue(env, nameof(env));
@@ -517,7 +508,7 @@ public static IDataView ApplyConcatOnSources(IHostEnvironment env, ManyToOneColu
var concatColumns = new List<ColumnConcatenatingTransformer.ColumnInfo>();
foreach (var col in columns)
{
- env.CheckUserArg(col != null, nameof(WordBagBuildingTransformer.Arguments.Columns));
+ env.CheckUserArg(col != null, nameof(WordBagBuildingTransformer.Options.Columns));
env.CheckUserArg(!string.IsNullOrWhiteSpace(col.Name), nameof(col.Name));
env.CheckUserArg(Utils.Size(col.Source) > 0, nameof(col.Source));
env.CheckUserArg(col.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(col.Source));
@@ -545,7 +536,7 @@ public static string[][] GenerateUniqueSourceNames(IHostEnvironment env, ManyToO
for (int iinfo = 0; iinfo < columns.Length; iinfo++)
{
var col = columns[iinfo];
- env.CheckUserArg(col != null, nameof(WordHashBagProducingTransformer.Arguments.Columns));
+ env.CheckUserArg(col != null, nameof(WordHashBagProducingTransformer.Options.Columns));
env.CheckUserArg(!string.IsNullOrWhiteSpace(col.Name), nameof(col.Name));
env.CheckUserArg(Utils.Size(col.Source) > 0 &&
col.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(col.Source));
@@ -570,4 +561,4 @@ public static string[][] GenerateUniqueSourceNames(IHostEnvironment env, ManyToO
return uniqueNames;
}
}
-}
\ No newline at end of file
+}
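[Reviewer note] NgramLength and SkipLength recur throughout this file and are easy to misread. A self-contained sketch of the semantics as documented (SkipLength adds ngrams whose tokens are up to that many positions apart); illustrative only, not ML.NET's implementation.

using System;
using System.Collections.Generic;

class SkipGramDemo
{
    // Enumerate bigrams over tokens, allowing up to skipLength skipped tokens.
    static IEnumerable<string> Bigrams(string[] tokens, int skipLength)
    {
        for (int i = 0; i < tokens.Length; i++)
            for (int gap = 1; gap <= skipLength + 1 && i + gap < tokens.Length; gap++)
                yield return tokens[i] + " " + tokens[i + gap];
    }

    static void Main()
    {
        var tokens = new[] { "the", "quick", "brown", "fox" };
        // skipLength = 0 -> "the quick", "quick brown", "brown fox"
        // skipLength = 1 -> additionally "the brown", "quick fox"
        foreach (var g in Bigrams(tokens, 1))
            Console.WriteLine(g);
    }
}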
diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs
index 07b2fb23f4..8a6c8108b5 100644
--- a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs
+++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs
@@ -21,7 +21,7 @@
using Microsoft.ML.Model.Onnx;
using Microsoft.ML.Transforms.Text;
-[assembly: LoadableClass(WordEmbeddingsExtractingTransformer.Summary, typeof(IDataTransform), typeof(WordEmbeddingsExtractingTransformer), typeof(WordEmbeddingsExtractingTransformer.Arguments),
+[assembly: LoadableClass(WordEmbeddingsExtractingTransformer.Summary, typeof(IDataTransform), typeof(WordEmbeddingsExtractingTransformer), typeof(WordEmbeddingsExtractingTransformer.Options),
typeof(SignatureDataTransform), WordEmbeddingsExtractingTransformer.UserName, "WordEmbeddingsTransform", WordEmbeddingsExtractingTransformer.ShortName, DocName = "transform/WordEmbeddingsTransform.md")]
[assembly: LoadableClass(WordEmbeddingsExtractingTransformer.Summary, typeof(IDataTransform), typeof(WordEmbeddingsExtractingTransformer), null, typeof(SignatureLoadDataTransform),
@@ -38,7 +38,7 @@ namespace Microsoft.ML.Transforms.Text
///
public sealed class WordEmbeddingsExtractingTransformer : OneToOneTransformerBase
{
- public sealed class Column : OneToOneColumn
+ internal sealed class Column : OneToOneColumn
{
internal static Column Parse(string str)
{
@@ -57,13 +57,13 @@ internal bool TryUnparse(StringBuilder sb)
}
}
- public sealed class Arguments : TransformInputBase
+ internal sealed class Options : TransformInputBase
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 0)]
public Column[] Columns;
[Argument(ArgumentType.AtMostOnce, HelpText = "Pre-trained model used to create the vocabulary", ShortName = "model", SortOrder = 1)]
- public PretrainedModelKind? ModelKind = PretrainedModelKind.Sswe;
+ public WordEmbeddingsExtractingEstimator.PretrainedModelKind? ModelKind = WordEmbeddingsExtractingEstimator.PretrainedModelKind.Sswe;
[Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "Filename for custom word embedding model",
ShortName = "dataFile", SortOrder = 2)]
@@ -76,7 +76,7 @@ public sealed class Arguments : TransformInputBase
internal const string ShortName = "WordEmbeddings";
internal const string LoaderSignature = "WordEmbeddingsTransform";
- public static VersionInfo GetVersionInfo()
+ internal static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "W2VTRANS",
@@ -87,7 +87,7 @@ public static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(WordEmbeddingsExtractingTransformer).Assembly.FullName);
}
- private readonly PretrainedModelKind? _modelKind;
+ private readonly WordEmbeddingsExtractingEstimator.PretrainedModelKind? _modelKind;
private readonly string _modelFileNameWithPath;
private static object _embeddingsLock = new object();
private readonly bool _customLookup;
@@ -148,23 +148,6 @@ public List GetWordLabels()
}
- ///
- /// Information for each column pair.
- ///
- public sealed class ColumnInfo
- {
- public readonly string Name;
- public readonly string InputColumnName;
-
- public ColumnInfo(string name, string inputColumnName = null)
- {
- Contracts.CheckNonEmpty(name, nameof(name));
-
- Name = name;
- InputColumnName = inputColumnName ?? name;
- }
- }
-
private const string RegistrationName = "WordEmbeddings";
private const int Timeout = 10 * 60 * 1000;
@@ -176,9 +159,9 @@ public ColumnInfo(string name, string inputColumnName = null)
/// Name of the column resulting from the transformation of .
/// Name of the column to transform. If set to , the value of the will be used as source.
/// The pretrained word embedding model.
- public WordEmbeddingsExtractingTransformer(IHostEnvironment env, string outputColumnName, string inputColumnName = null,
- PretrainedModelKind modelKind = PretrainedModelKind.Sswe)
- : this(env, modelKind, new ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName))
+ internal WordEmbeddingsExtractingTransformer(IHostEnvironment env, string outputColumnName, string inputColumnName = null,
+ WordEmbeddingsExtractingEstimator.PretrainedModelKind modelKind = WordEmbeddingsExtractingEstimator.PretrainedModelKind.Sswe)
+ : this(env, modelKind, new WordEmbeddingsExtractingEstimator.ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName))
{
}
@@ -189,8 +172,8 @@ public WordEmbeddingsExtractingTransformer(IHostEnvironment env, string outputCo
/// Name of the column resulting from the transformation of .
/// Filename for custom word embedding model.
/// Name of the column to transform. If set to , the value of the will be used as source.
- public WordEmbeddingsExtractingTransformer(IHostEnvironment env, string outputColumnName, string customModelFile, string inputColumnName = null)
- : this(env, customModelFile, new ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName))
+ internal WordEmbeddingsExtractingTransformer(IHostEnvironment env, string outputColumnName, string customModelFile, string inputColumnName = null)
+ : this(env, customModelFile, new WordEmbeddingsExtractingEstimator.ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName))
{
}
@@ -200,13 +183,13 @@ public WordEmbeddingsExtractingTransformer(IHostEnvironment env, string outputCo
/// Host Environment.
/// The pretrained word embedding model.
/// Input/Output columns.
- public WordEmbeddingsExtractingTransformer(IHostEnvironment env, PretrainedModelKind modelKind, params ColumnInfo[] columns)
+ internal WordEmbeddingsExtractingTransformer(IHostEnvironment env, WordEmbeddingsExtractingEstimator.PretrainedModelKind modelKind, params WordEmbeddingsExtractingEstimator.ColumnInfo[] columns)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns))
{
- env.CheckUserArg(Enum.IsDefined(typeof(PretrainedModelKind), modelKind), nameof(modelKind));
+ env.CheckUserArg(Enum.IsDefined(typeof(WordEmbeddingsExtractingEstimator.PretrainedModelKind), modelKind), nameof(modelKind));
_modelKind = modelKind;
- _modelFileNameWithPath = EnsureModelFile(env, out _linesToSkip, (PretrainedModelKind)_modelKind);
+ _modelFileNameWithPath = EnsureModelFile(env, out _linesToSkip, (WordEmbeddingsExtractingEstimator.PretrainedModelKind)_modelKind);
_currentVocab = GetVocabularyDictionary(env);
}
@@ -216,7 +199,7 @@ public WordEmbeddingsExtractingTransformer(IHostEnvironment env, PretrainedModel
/// Host Environment.
/// Filename for custom word embedding model.
/// Input/Output columns.
- public WordEmbeddingsExtractingTransformer(IHostEnvironment env, string customModelFile, params ColumnInfo[] columns)
+ internal WordEmbeddingsExtractingTransformer(IHostEnvironment env, string customModelFile, params WordEmbeddingsExtractingEstimator.ColumnInfo[] columns)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns))
{
env.CheckValue(customModelFile, nameof(customModelFile));
@@ -228,39 +211,39 @@ public WordEmbeddingsExtractingTransformer(IHostEnvironment env, string customMo
_currentVocab = GetVocabularyDictionary(env);
}
- private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns)
+ private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(WordEmbeddingsExtractingEstimator.ColumnInfo[] columns)
{
Contracts.CheckValue(columns, nameof(columns));
return columns.Select(x => (x.Name, x.InputColumnName)).ToArray();
}
// Factory method for SignatureDataTransform.
- internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
+ internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
- env.CheckValue(args, nameof(args));
+ env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
- if (args.ModelKind == null)
- args.ModelKind = PretrainedModelKind.Sswe;
- env.CheckUserArg(!args.ModelKind.HasValue || Enum.IsDefined(typeof(PretrainedModelKind), args.ModelKind), nameof(args.ModelKind));
+ if (options.ModelKind == null)
+ options.ModelKind = WordEmbeddingsExtractingEstimator.PretrainedModelKind.Sswe;
+ env.CheckUserArg(!options.ModelKind.HasValue || Enum.IsDefined(typeof(WordEmbeddingsExtractingEstimator.PretrainedModelKind), options.ModelKind), nameof(options.ModelKind));
- env.CheckValue(args.Columns, nameof(args.Columns));
+ env.CheckValue(options.Columns, nameof(options.Columns));
- var cols = new ColumnInfo[args.Columns.Length];
+ var cols = new WordEmbeddingsExtractingEstimator.ColumnInfo[options.Columns.Length];
for (int i = 0; i < cols.Length; i++)
{
- var item = args.Columns[i];
- cols[i] = new ColumnInfo(
+ var item = options.Columns[i];
+ cols[i] = new WordEmbeddingsExtractingEstimator.ColumnInfo(
item.Name,
item.Source ?? item.Name);
}
- bool customLookup = !string.IsNullOrWhiteSpace(args.CustomLookupTable);
+ bool customLookup = !string.IsNullOrWhiteSpace(options.CustomLookupTable);
if (customLookup)
- return new WordEmbeddingsExtractingTransformer(env, args.CustomLookupTable, cols).MakeDataTransform(input);
+ return new WordEmbeddingsExtractingTransformer(env, options.CustomLookupTable, cols).MakeDataTransform(input);
else
- return new WordEmbeddingsExtractingTransformer(env, args.ModelKind.Value, cols).MakeDataTransform(input);
+ return new WordEmbeddingsExtractingTransformer(env, options.ModelKind.Value, cols).MakeDataTransform(input);
}
private WordEmbeddingsExtractingTransformer(IHost host, ModelLoadContext ctx)
@@ -276,15 +259,15 @@ private WordEmbeddingsExtractingTransformer(IHost host, ModelLoadContext ctx)
}
else
{
- _modelKind = (PretrainedModelKind)ctx.Reader.ReadUInt32();
- _modelFileNameWithPath = EnsureModelFile(Host, out _linesToSkip, (PretrainedModelKind)_modelKind);
+ _modelKind = (WordEmbeddingsExtractingEstimator.PretrainedModelKind)ctx.Reader.ReadUInt32();
+ _modelFileNameWithPath = EnsureModelFile(Host, out _linesToSkip, (WordEmbeddingsExtractingEstimator.PretrainedModelKind)_modelKind);
}
Host.CheckNonWhiteSpace(_modelFileNameWithPath, nameof(_modelFileNameWithPath));
_currentVocab = GetVocabularyDictionary(host);
}
- public static WordEmbeddingsExtractingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
+ internal static WordEmbeddingsExtractingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
IHost h = env.Register(RegistrationName);
@@ -618,57 +601,24 @@ private ValueGetter> GetGetterVec(Row input, int iinfo)
}
}
- public enum PretrainedModelKind
- {
- [TGUI(Label = "GloVe 50D")]
- GloVe50D = 0,
-
- [TGUI(Label = "GloVe 100D")]
- GloVe100D = 1,
-
- [TGUI(Label = "GloVe 200D")]
- GloVe200D = 2,
-
- [TGUI(Label = "GloVe 300D")]
- GloVe300D = 3,
-
- [TGUI(Label = "GloVe Twitter 25D")]
- GloVeTwitter25D = 4,
-
- [TGUI(Label = "GloVe Twitter 50D")]
- GloVeTwitter50D = 5,
-
- [TGUI(Label = "GloVe Twitter 100D")]
- GloVeTwitter100D = 6,
-
- [TGUI(Label = "GloVe Twitter 200D")]
- GloVeTwitter200D = 7,
-
- [TGUI(Label = "fastText Wikipedia 300D")]
- FastTextWikipedia300D = 8,
-
- [TGUI(Label = "Sentiment-Specific Word Embedding")]
- Sswe = 9
- }
-
- private static Dictionary<PretrainedModelKind, string> _modelsMetaData = new Dictionary<PretrainedModelKind, string>()
+ private static Dictionary<WordEmbeddingsExtractingEstimator.PretrainedModelKind, string> _modelsMetaData = new Dictionary<WordEmbeddingsExtractingEstimator.PretrainedModelKind, string>()
{
- { PretrainedModelKind.GloVe50D, "glove.6B.50d.txt" },
- { PretrainedModelKind.GloVe100D, "glove.6B.100d.txt" },
- { PretrainedModelKind.GloVe200D, "glove.6B.200d.txt" },
- { PretrainedModelKind.GloVe300D, "glove.6B.300d.txt" },
- { PretrainedModelKind.GloVeTwitter25D, "glove.twitter.27B.25d.txt" },
- { PretrainedModelKind.GloVeTwitter50D, "glove.twitter.27B.50d.txt" },
- { PretrainedModelKind.GloVeTwitter100D, "glove.twitter.27B.100d.txt" },
- { PretrainedModelKind.GloVeTwitter200D, "glove.twitter.27B.200d.txt" },
- { PretrainedModelKind.FastTextWikipedia300D, "wiki.en.vec" },
- { PretrainedModelKind.Sswe, "sentiment.emd" }
+ { WordEmbeddingsExtractingEstimator.PretrainedModelKind.GloVe50D, "glove.6B.50d.txt" },
+ { WordEmbeddingsExtractingEstimator.PretrainedModelKind.GloVe100D, "glove.6B.100d.txt" },
+ { WordEmbeddingsExtractingEstimator.PretrainedModelKind.GloVe200D, "glove.6B.200d.txt" },
+ { WordEmbeddingsExtractingEstimator.PretrainedModelKind.GloVe300D, "glove.6B.300d.txt" },
+ { WordEmbeddingsExtractingEstimator.PretrainedModelKind.GloVeTwitter25D, "glove.twitter.27B.25d.txt" },
+ { WordEmbeddingsExtractingEstimator.PretrainedModelKind.GloVeTwitter50D, "glove.twitter.27B.50d.txt" },
+ { WordEmbeddingsExtractingEstimator.PretrainedModelKind.GloVeTwitter100D, "glove.twitter.27B.100d.txt" },
+ { WordEmbeddingsExtractingEstimator.PretrainedModelKind.GloVeTwitter200D, "glove.twitter.27B.200d.txt" },
+ { WordEmbeddingsExtractingEstimator.PretrainedModelKind.FastTextWikipedia300D, "wiki.en.vec" },
+ { WordEmbeddingsExtractingEstimator.PretrainedModelKind.Sswe, "sentiment.emd" }
};
- private static Dictionary<PretrainedModelKind, int> _linesToSkipInModels = new Dictionary<PretrainedModelKind, int>()
- { { PretrainedModelKind.FastTextWikipedia300D, 1 } };
+ private static Dictionary<WordEmbeddingsExtractingEstimator.PretrainedModelKind, int> _linesToSkipInModels = new Dictionary<WordEmbeddingsExtractingEstimator.PretrainedModelKind, int>()
+ { { WordEmbeddingsExtractingEstimator.PretrainedModelKind.FastTextWikipedia300D, 1 } };
- private string EnsureModelFile(IHostEnvironment env, out int linesToSkip, PretrainedModelKind kind)
+ private string EnsureModelFile(IHostEnvironment env, out int linesToSkip, WordEmbeddingsExtractingEstimator.PretrainedModelKind kind)
{
linesToSkip = 0;
if (_modelsMetaData.ContainsKey(kind))
@@ -678,7 +628,7 @@ private string EnsureModelFile(IHostEnvironment env, out int linesToSkip, Pretra
linesToSkip = _linesToSkipInModels[kind];
using (var ch = Host.Start("Ensuring resources"))
{
- string dir = kind == PretrainedModelKind.Sswe ? Path.Combine("Text", "Sswe") : "WordVectors";
+ string dir = kind == WordEmbeddingsExtractingEstimator.PretrainedModelKind.Sswe ? Path.Combine("Text", "Sswe") : "WordVectors";
var url = $"{dir}/{modelFileName}";
var ensureModel = ResourceManagerUtils.Instance.EnsureResource(Host, ch, url, modelFileName, dir, Timeout);
ensureModel.Wait();
@@ -786,8 +736,8 @@ private static ParallelOptions GetParallelOptions(IHostEnvironment hostEnvironme
public sealed class WordEmbeddingsExtractingEstimator : IEstimator<WordEmbeddingsExtractingTransformer>
{
private readonly IHost _host;
- private readonly WordEmbeddingsExtractingTransformer.ColumnInfo[] _columns;
- private readonly WordEmbeddingsExtractingTransformer.PretrainedModelKind? _modelKind;
+ private readonly ColumnInfo[] _columns;
+ private readonly PretrainedModelKind? _modelKind;
private readonly string _customLookupTable;
///
@@ -799,10 +749,10 @@ public sealed class WordEmbeddingsExtractingEstimator : IEstimator
/// The local instance of
/// Name of the column resulting from the transformation of .
/// Name of the column to transform. If set to , the value of the will be used as source.
- /// The embeddings to use.
+ /// The embeddings to use.
internal WordEmbeddingsExtractingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null,
- WordEmbeddingsExtractingTransformer.PretrainedModelKind modelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe)
- : this(env, modelKind, new WordEmbeddingsExtractingTransformer.ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName))
+ PretrainedModelKind modelKind = PretrainedModelKind.Sswe)
+ : this(env, modelKind, new ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName))
{
}
@@ -817,7 +767,7 @@ internal WordEmbeddingsExtractingEstimator(IHostEnvironment env, string outputCo
/// The path of the pre-trained embeddings model to use.
/// Name of the column to transform.
internal WordEmbeddingsExtractingEstimator(IHostEnvironment env, string outputColumnName, string customModelFile, string inputColumnName = null)
- : this(env, customModelFile, new WordEmbeddingsExtractingTransformer.ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName))
+ : this(env, customModelFile, new ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName))
{
}
@@ -828,11 +778,11 @@ internal WordEmbeddingsExtractingEstimator(IHostEnvironment env, string outputCo
/// and the third one represents the maximum encountered values (for each dimension).
///
/// The local instance of
- /// The embeddings to use.
+ /// The embeddings to use.
/// The array of columns, and per-column configurations, to extract embeddings from.
internal WordEmbeddingsExtractingEstimator(IHostEnvironment env,
- WordEmbeddingsExtractingTransformer.PretrainedModelKind modelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe,
- params WordEmbeddingsExtractingTransformer.ColumnInfo[] columns)
+ PretrainedModelKind modelKind = PretrainedModelKind.Sswe,
+ params ColumnInfo[] columns)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(WordEmbeddingsExtractingEstimator));
@@ -841,7 +791,7 @@ internal WordEmbeddingsExtractingEstimator(IHostEnvironment env,
_columns = columns;
}
- internal WordEmbeddingsExtractingEstimator(IHostEnvironment env, string customModelFile, params WordEmbeddingsExtractingTransformer.ColumnInfo[] columns)
+ internal WordEmbeddingsExtractingEstimator(IHostEnvironment env, string customModelFile, params ColumnInfo[] columns)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(WordEmbeddingsExtractingEstimator));
@@ -850,6 +800,71 @@ internal WordEmbeddingsExtractingEstimator(IHostEnvironment env, string customMo
_columns = columns;
}
+ public enum PretrainedModelKind
+ {
+ [TGUI(Label = "GloVe 50D")]
+ GloVe50D = 0,
+
+ [TGUI(Label = "GloVe 100D")]
+ GloVe100D = 1,
+
+ [TGUI(Label = "GloVe 200D")]
+ GloVe200D = 2,
+
+ [TGUI(Label = "GloVe 300D")]
+ GloVe300D = 3,
+
+ [TGUI(Label = "GloVe Twitter 25D")]
+ GloVeTwitter25D = 4,
+
+ [TGUI(Label = "GloVe Twitter 50D")]
+ GloVeTwitter50D = 5,
+
+ [TGUI(Label = "GloVe Twitter 100D")]
+ GloVeTwitter100D = 6,
+
+ [TGUI(Label = "GloVe Twitter 200D")]
+ GloVeTwitter200D = 7,
+
+ [TGUI(Label = "fastText Wikipedia 300D")]
+ FastTextWikipedia300D = 8,
+
+ [TGUI(Label = "Sentiment-Specific Word Embedding")]
+ Sswe = 9
+ }
+ ///
+ /// Information for each column pair.
+ ///
+ public sealed class ColumnInfo
+ {
+ ///
+ /// Name of the column resulting from the transformation of .
+ ///
+ public readonly string Name;
+ ///
+ /// Name of column to transform.
+ ///
+ public readonly string InputColumnName;
+
+ ///
+ /// Describes how the transformer handles one column pair.
+ ///
+ /// Name of the column resulting from the transformation of .
+ /// Name of column to transform. If set to , the value of the will be used as source.
+ public ColumnInfo(string name, string inputColumnName = null)
+ {
+ Contracts.CheckNonEmpty(name, nameof(name));
+
+ Name = name;
+ InputColumnName = inputColumnName ?? name;
+ }
+ }
+
+ ///
+ /// Returns the of the schema which will be produced by the transformer.
+ /// Used for schema propagation and verification in a pipeline.
+ ///
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
@@ -867,6 +882,9 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
return new SchemaShape(result.Values);
}
+ ///
+ /// Trains and returns a .
+ ///
public WordEmbeddingsExtractingTransformer Fit(IDataView input)
{
bool customLookup = !string.IsNullOrWhiteSpace(_customLookupTable);
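[Reviewer note] With PretrainedModelKind and ColumnInfo now nested under WordEmbeddingsExtractingEstimator, call sites qualify them with the estimator type, as the updated samples in this PR do. A minimal sketch, assuming a tokenized input column named "Tokens".

using Microsoft.ML;
using Microsoft.ML.Transforms.Text;

class EmbeddingSketch
{
    static void Main()
    {
        var mlContext = new MLContext();
        // The enum is referenced on the estimator, not the transformer.
        var embedding = mlContext.Transforms.Text.ExtractWordEmbeddings(
            "Embeddings", "Tokens",
            WordEmbeddingsExtractingEstimator.PretrainedModelKind.GloVeTwitter25D);
    }
}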
diff --git a/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs b/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs
index 77b7a8444c..049379e18d 100644
--- a/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs
@@ -14,7 +14,7 @@
using Microsoft.ML.Transforms.Conversions;
using Microsoft.ML.Transforms.Text;
-[assembly: LoadableClass(WordHashBagProducingTransformer.Summary, typeof(IDataTransform), typeof(WordHashBagProducingTransformer), typeof(WordHashBagProducingTransformer.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(WordHashBagProducingTransformer.Summary, typeof(IDataTransform), typeof(WordHashBagProducingTransformer), typeof(WordHashBagProducingTransformer.Options), typeof(SignatureDataTransform),
"Word Hash Bag Transform", "WordHashBagTransform", "WordHashBag")]
[assembly: LoadableClass(NgramHashExtractingTransformer.Summary, typeof(INgramExtractorFactory), typeof(NgramHashExtractingTransformer), typeof(NgramHashExtractingTransformer.NgramHashExtractorArguments),
@@ -26,7 +26,7 @@ namespace Microsoft.ML.Transforms.Text
{
internal static class WordHashBagProducingTransformer
{
- public sealed class Column : NgramHashExtractingTransformer.ColumnBase
+ internal sealed class Column : NgramHashExtractingTransformer.ColumnBase
{
internal static Column Parse(string str)
{
@@ -73,7 +73,7 @@ internal bool TryUnparse(StringBuilder sb)
}
}
- public sealed class Arguments : NgramHashExtractingTransformer.ArgumentsBase
+ internal sealed class Options : NgramHashExtractingTransformer.ArgumentsBase
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:hashBits:srcs)",
Name = "Column", ShortName = "col", SortOrder = 1)]
@@ -84,13 +84,13 @@ public sealed class Arguments : NgramHashExtractingTransformer.ArgumentsBase
internal const string Summary = "Produces a bag of counts of ngrams (sequences of consecutive words of length 1-n) in a given text. "
+ "It does so by hashing each ngram and using the hash value as the index in the bag.";
- public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
+ internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(RegistrationName);
- h.CheckValue(args, nameof(args));
+ h.CheckValue(options, nameof(options));
h.CheckValue(input, nameof(input));
- h.CheckUserArg(Utils.Size(args.Columns) > 0, nameof(args.Columns), "Columns must be specified");
+ h.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified");
// To each input column to the WordHashBagTransform, a tokenize transform is applied,
// followed by applying WordHashVectorizeTransform.
@@ -100,21 +100,21 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
// The intermediate columns are dropped at the end of using a DropColumnsTransform.
IDataView view = input;
- var uniqueSourceNames = NgramExtractionUtils.GenerateUniqueSourceNames(h, args.Columns, view.Schema);
- Contracts.Assert(uniqueSourceNames.Length == args.Columns.Length);
+ var uniqueSourceNames = NgramExtractionUtils.GenerateUniqueSourceNames(h, options.Columns, view.Schema);
+ Contracts.Assert(uniqueSourceNames.Length == options.Columns.Length);
- var tokenizeColumns = new List<WordTokenizingTransformer.ColumnInfo>();
- var extractorCols = new NgramHashExtractingTransformer.Column[args.Columns.Length];
- var colCount = args.Columns.Length;
+ var tokenizeColumns = new List<WordTokenizingEstimator.ColumnInfo>();
+ var extractorCols = new NgramHashExtractingTransformer.Column[options.Columns.Length];
+ var colCount = options.Columns.Length;
List<string> tmpColNames = new List<string>();
for (int iinfo = 0; iinfo < colCount; iinfo++)
{
- var column = args.Columns[iinfo];
+ var column = options.Columns[iinfo];
int srcCount = column.Source.Length;
var curTmpNames = new string[srcCount];
- Contracts.Assert(uniqueSourceNames[iinfo].Length == args.Columns[iinfo].Source.Length);
+ Contracts.Assert(uniqueSourceNames[iinfo].Length == options.Columns[iinfo].Source.Length);
for (int isrc = 0; isrc < srcCount; isrc++)
- tokenizeColumns.Add(new WordTokenizingTransformer.ColumnInfo(curTmpNames[isrc] = uniqueSourceNames[iinfo][isrc], args.Columns[iinfo].Source[isrc]));
+ tokenizeColumns.Add(new WordTokenizingEstimator.ColumnInfo(curTmpNames[isrc] = uniqueSourceNames[iinfo][isrc], options.Columns[iinfo].Source[isrc]));
tmpColNames.AddRange(curTmpNames);
extractorCols[iinfo] =
@@ -128,7 +128,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
SkipLength = column.SkipLength,
Ordered = column.Ordered,
InvertHash = column.InvertHash,
- FriendlyNames = args.Columns[iinfo].Source,
+ FriendlyNames = options.Columns[iinfo].Source,
AllLengths = column.AllLengths
};
}
@@ -136,16 +136,16 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
view = new WordTokenizingEstimator(env, tokenizeColumns.ToArray()).Fit(view).Transform(view);
var featurizeArgs =
- new NgramHashExtractingTransformer.Arguments
+ new NgramHashExtractingTransformer.Options
{
- AllLengths = args.AllLengths,
- HashBits = args.HashBits,
- NgramLength = args.NgramLength,
- SkipLength = args.SkipLength,
- Ordered = args.Ordered,
- Seed = args.Seed,
+ AllLengths = options.AllLengths,
+ HashBits = options.HashBits,
+ NgramLength = options.NgramLength,
+ SkipLength = options.SkipLength,
+ Ordered = options.Ordered,
+ Seed = options.Seed,
Columns = extractorCols.ToArray(),
- InvertHash = args.InvertHash
+ InvertHash = options.InvertHash
};
view = NgramHashExtractingTransformer.Create(h, featurizeArgs, view);
@@ -193,7 +193,7 @@ public abstract class ColumnBase : ManyToOneColumn
public bool? AllLengths;
}
- public sealed class Column : ColumnBase
+ internal sealed class Column : ColumnBase
{
// For all source columns, use these friendly names for the source
// column names instead of the real column names.
@@ -246,10 +246,10 @@ internal bool TryUnparse(StringBuilder sb)
///
/// This class is a merger of and
- /// , with the ordered option,
+ /// , with the ordered option,
/// the rehashUnigrams option and the allLength option removed.
///
- public abstract class ArgumentsBase
+ internal abstract class ArgumentsBase
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Ngram length", ShortName = "ngram", SortOrder = 3)]
public int NgramLength = 1;
@@ -296,7 +296,7 @@ internal static class DefaultArguments
[TlcModule.Component(Name = "NGramHash", FriendlyName = "NGram Hash Extractor Transform", Alias = "NGramHashExtractorTransform,NGramHashExtractor",
Desc = "Extracts NGrams from text and convert them to vector using hashing trick.")]
- public sealed class NgramHashExtractorArguments : ArgumentsBase, INgramExtractorFactoryFactory
+ internal sealed class NgramHashExtractorArguments : ArgumentsBase, INgramExtractorFactoryFactory
{
public INgramExtractorFactory CreateComponent(IHostEnvironment env, TermLoaderArguments loaderArgs)
{
@@ -304,7 +304,7 @@ public INgramExtractorFactory CreateComponent(IHostEnvironment env, TermLoaderAr
}
}
- public sealed class Arguments : ArgumentsBase
+ internal sealed class Options : ArgumentsBase
{
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:srcs)", Name = "Column", ShortName = "col", SortOrder = 1)]
public Column[] Columns;
@@ -314,14 +314,14 @@ public sealed class Arguments : ArgumentsBase
internal const string LoaderSignature = "NgramHashExtractor";
- public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input,
+ internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input,
TermLoaderArguments termLoaderArgs = null)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(LoaderSignature);
- h.CheckValue(args, nameof(args));
+ h.CheckValue(options, nameof(options));
h.CheckValue(input, nameof(input));
- h.CheckUserArg(Utils.Size(args.Columns) > 0, nameof(args.Columns), "Columns must be specified");
+ h.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified");
// To each input column to the NgramHashExtractorArguments, a HashTransform using 30
// bits (to minimize collisions) is applied first, followed by an NgramHashTransform.
@@ -330,16 +330,17 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
List termCols = null;
if (termLoaderArgs != null)
termCols = new List();
- var hashColumns = new List<HashingEstimator.ColumnInfo>();
- var ngramHashColumns = new NgramHashingTransformer.ColumnInfo[args.Columns.Length];
- var colCount = args.Columns.Length;
+ var hashColumns = new List<HashingEstimator.ColumnInfo>();
+ var ngramHashColumns = new NgramHashingEstimator.ColumnInfo[options.Columns.Length];
+
+ var colCount = options.Columns.Length;
// The NGramHashExtractor has a ManyToOne column type. To avoid stepping over the source
// column name when a 'name' destination column name was specified, we use temporary column names.
string[][] tmpColNames = new string[colCount][];
for (int iinfo = 0; iinfo < colCount; iinfo++)
{
- var column = args.Columns[iinfo];
+ var column = options.Columns[iinfo];
h.CheckUserArg(!string.IsNullOrWhiteSpace(column.Name), nameof(column.Name));
h.CheckUserArg(Utils.Size(column.Source) > 0 &&
column.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(column.Source));
@@ -361,18 +362,18 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
}
hashColumns.Add(new HashingEstimator.ColumnInfo(tmpName, termLoaderArgs == null ? column.Source[isrc] : tmpName,
- 30, column.Seed ?? args.Seed, false, column.InvertHash ?? args.InvertHash));
+ 30, column.Seed ?? options.Seed, false, column.InvertHash ?? options.InvertHash));
}
ngramHashColumns[iinfo] =
- new NgramHashingTransformer.ColumnInfo(column.Name, tmpColNames[iinfo],
- column.NgramLength ?? args.NgramLength,
- column.SkipLength ?? args.SkipLength,
- column.AllLengths ?? args.AllLengths,
- column.HashBits ?? args.HashBits,
- column.Seed ?? args.Seed,
- column.Ordered ?? args.Ordered,
- column.InvertHash ?? args.InvertHash);
+ new NgramHashingEstimator.ColumnInfo(column.Name, tmpColNames[iinfo],
+ column.NgramLength ?? options.NgramLength,
+ column.SkipLength ?? options.SkipLength,
+ column.AllLengths ?? options.AllLengths,
+ column.HashBits ?? options.HashBits,
+ column.Seed ?? options.Seed,
+ column.Ordered ?? options.Ordered,
+ column.InvertHash ?? options.InvertHash);
ngramHashColumns[iinfo].FriendlyNames = column.FriendlyNames;
}
@@ -406,7 +407,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
return ColumnSelectingTransformer.CreateDrop(h, view, tmpColNames.SelectMany(cols => cols).ToArray());
}
- public static IDataTransform Create(NgramHashExtractorArguments extractorArgs, IHostEnvironment env, IDataView input,
+ internal static IDataTransform Create(NgramHashExtractorArguments extractorArgs, IHostEnvironment env, IDataView input,
ExtractorColumn[] cols, TermLoaderArguments termLoaderArgs = null)
{
Contracts.CheckValue(env, nameof(env));
@@ -414,7 +415,7 @@ public static IDataTransform Create(NgramHashExtractorArguments extractorArgs, I
h.CheckValue(extractorArgs, nameof(extractorArgs));
h.CheckValue(input, nameof(input));
h.CheckUserArg(extractorArgs.SkipLength < extractorArgs.NgramLength, nameof(extractorArgs.SkipLength), "Should be less than " + nameof(extractorArgs.NgramLength));
- h.CheckUserArg(Utils.Size(cols) > 0, nameof(Arguments.Columns), "Must be specified");
+ h.CheckUserArg(Utils.Size(cols) > 0, nameof(Options.Columns), "Must be specified");
h.AssertValueOrNull(termLoaderArgs);
var extractorCols = new Column[cols.Length];
@@ -429,7 +430,7 @@ public static IDataTransform Create(NgramHashExtractorArguments extractorArgs, I
};
}
- var args = new Arguments
+ var options = new Options
{
Columns = extractorCols,
NgramLength = extractorArgs.NgramLength,
@@ -441,10 +442,10 @@ public static IDataTransform Create(NgramHashExtractorArguments extractorArgs, I
AllLengths = extractorArgs.AllLengths
};
- return Create(h, args, input, termLoaderArgs);
+ return Create(h, options, input, termLoaderArgs);
}
- public static INgramExtractorFactory Create(IHostEnvironment env, NgramHashExtractorArguments extractorArgs,
+ internal static INgramExtractorFactory Create(IHostEnvironment env, NgramHashExtractorArguments extractorArgs,
TermLoaderArguments termLoaderArgs)
{
Contracts.CheckValue(env, nameof(env));
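[Reviewer note] The hash-bag path above buckets every ngram into one of 2^hashBits slots, which is why distinct ngrams can collide and why invertHash exists to recover readable slot names. A tiny illustration of the arithmetic; the hash function here is a stand-in, not ML.NET's internal hash.

using System;

class HashBagSketch
{
    static void Main()
    {
        int hashBits = 16;             // default on WordHashBagEstimator
        int slotCount = 1 << hashBits; // 65,536 ngram slots
        int mask = slotCount - 1;
        // invertHash bounds how many colliding originals are kept per slot
        // (0 = none, -1 = all, k = at most k).
        foreach (var ngram in new[] { "good movie", "bad movie", "great plot" })
            Console.WriteLine($"\"{ngram}\" -> slot {ngram.GetHashCode() & mask}");
    }
}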
diff --git a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs
index c1c7820afa..1d46ed420e 100644
--- a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs
+++ b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs
@@ -19,7 +19,7 @@
using Microsoft.ML.Transforms.Text;
using Newtonsoft.Json.Linq;
-[assembly: LoadableClass(WordTokenizingTransformer.Summary, typeof(IDataTransform), typeof(WordTokenizingTransformer), typeof(WordTokenizingTransformer.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(WordTokenizingTransformer.Summary, typeof(IDataTransform), typeof(WordTokenizingTransformer), typeof(WordTokenizingTransformer.Options), typeof(SignatureDataTransform),
"Word Tokenizer Transform", "WordTokenizeTransform", "DelimitedTokenizeTransform", "WordToken", "DelimitedTokenize", "Token")]
[assembly: LoadableClass(WordTokenizingTransformer.Summary, typeof(IDataTransform), typeof(WordTokenizingTransformer), null, typeof(SignatureLoadDataTransform),
@@ -40,7 +40,7 @@ namespace Microsoft.ML.Transforms.Text
///
public sealed class WordTokenizingTransformer : OneToOneTransformerBase
{
- public class Column : OneToOneColumn
+ internal class Column : OneToOneColumn
{
[Argument(ArgumentType.AtMostOnce,
HelpText = "Comma separated set of term separator(s). Commonly: 'space', 'comma', 'semicolon' or other single character.",
@@ -64,7 +64,7 @@ internal bool TryUnparse(StringBuilder sb)
}
}
- public abstract class ArgumentsBase : TransformInputBase
+ internal abstract class ArgumentsBase : TransformInputBase
{
// REVIEW: Think about adding a user specified separator string, that is added as an extra token between
// the tokens of each column
@@ -81,16 +81,12 @@ public abstract class ArgumentsBase : TransformInputBase
public char[] CharArrayTermSeparators;
}
- public sealed class Arguments : ArgumentsBase
+ internal sealed class Options : ArgumentsBase
{
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s)", Name = "Column", ShortName = "col", SortOrder = 1)]
public Column[] Columns;
}
- public sealed class TokenizeArguments : ArgumentsBase
- {
- }
-
internal const string Summary = "The input to this transform is text, and the output is a vector of text containing the words (tokens) in the original text. "
+ "The separator is space, but can be specified as any other character (or multiple characters) if needed.";
@@ -111,35 +107,16 @@ private static VersionInfo GetVersionInfo()
private const string RegistrationName = "DelimitedTokenize";
- public sealed class ColumnInfo
- {
- public readonly string Name;
- public readonly string InputColumnName;
- public readonly char[] Separators;
-
- ///
- /// Describes how the transformer handles one column pair.
- ///
- /// Name of the column resulting from the transformation of .
- /// Name of column to transform. If set to , the value of the will be used as source.
- /// Casing text using the rules of the invariant culture. If not specified, space will be used as separator.
- public ColumnInfo(string name, string inputColumnName = null, char[] separators = null)
- {
- Name = name;
- InputColumnName = inputColumnName ?? name;
- Separators = separators ?? new[] { ' ' };
- }
- }
- public IReadOnlyCollection<ColumnInfo> Columns => _columns.AsReadOnly();
- private readonly ColumnInfo[] _columns;
+ public IReadOnlyCollection<WordTokenizingEstimator.ColumnInfo> Columns => _columns.AsReadOnly();
+ private readonly WordTokenizingEstimator.ColumnInfo[] _columns;
- private static (string name, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns)
+ private static (string name, string inputColumnName)[] GetColumnPairs(WordTokenizingEstimator.ColumnInfo[] columns)
{
Contracts.CheckNonEmpty(columns, nameof(columns));
return columns.Select(x => (x.Name, x.InputColumnName)).ToArray();
}
- public WordTokenizingTransformer(IHostEnvironment env, params ColumnInfo[] columns) :
+ internal WordTokenizingTransformer(IHostEnvironment env, params WordTokenizingEstimator.ColumnInfo[] columns) :
base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns))
{
_columns = columns.ToArray();
@@ -156,7 +133,7 @@ private WordTokenizingTransformer(IHost host, ModelLoadContext ctx) :
base(host, ctx)
{
var columnsLength = ColumnPairs.Length;
- _columns = new ColumnInfo[columnsLength];
+ _columns = new WordTokenizingEstimator.ColumnInfo[columnsLength];
// *** Binary format ***
//
// for each added column
@@ -165,7 +142,7 @@ private WordTokenizingTransformer(IHost host, ModelLoadContext ctx) :
{
var separators = ctx.Reader.ReadCharArray();
Contracts.CheckDecode(Utils.Size(separators) > 0);
- _columns[i] = new ColumnInfo(ColumnPairs[i].outputColumnName, ColumnPairs[i].inputColumnName, separators);
+ _columns[i] = new WordTokenizingEstimator.ColumnInfo(ColumnPairs[i].outputColumnName, ColumnPairs[i].inputColumnName, separators);
}
}
@@ -199,19 +176,19 @@ private static WordTokenizingTransformer Create(IHostEnvironment env, ModelLoadC
}
// Factory method for SignatureDataTransform.
- internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
+ internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
- env.CheckValue(args, nameof(args));
+ env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
- env.CheckValue(args.Columns, nameof(args.Columns));
- var cols = new ColumnInfo[args.Columns.Length];
+ env.CheckValue(options.Columns, nameof(options.Columns));
+ var cols = new WordTokenizingEstimator.ColumnInfo[options.Columns.Length];
for (int i = 0; i < cols.Length; i++)
{
- var item = args.Columns[i];
- var separators = args.CharArrayTermSeparators ?? PredictionUtil.SeparatorFromString(item.TermSeparators ?? args.TermSeparators);
- cols[i] = new ColumnInfo(item.Name, item.Source ?? item.Name, separators);
+ var item = options.Columns[i];
+ var separators = options.CharArrayTermSeparators ?? PredictionUtil.SeparatorFromString(item.TermSeparators ?? options.TermSeparators);
+ cols[i] = new WordTokenizingEstimator.ColumnInfo(item.Name, item.Source ?? item.Name, separators);
}
return new WordTokenizingTransformer(env, cols).MakeDataTransform(input);
@@ -428,7 +405,7 @@ private JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, JToken srcToken)
///
public sealed class WordTokenizingEstimator : TrivialEstimator<WordTokenizingTransformer>
{
- public static bool IsColumnTypeValid(ColumnType type) => type.GetItemType() is TextType;
+ internal static bool IsColumnTypeValid(ColumnType type) => type.GetItemType() is TextType;
internal const string ExpectedColumnType = "Text";
@@ -439,7 +416,7 @@ public sealed class WordTokenizingEstimator : TrivialEstimator
/// Name of the column resulting from the transformation of .
/// Name of the column to transform. If set to , the value of the will be used as source.
/// The separators to use (uses space character by default).
- public WordTokenizingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, char[] separators = null)
+ internal WordTokenizingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, char[] separators = null)
: this(env, new[] { (outputColumnName, inputColumnName ?? outputColumnName) }, separators)
{
}
@@ -450,8 +427,8 @@ public WordTokenizingEstimator(IHostEnvironment env, string outputColumnName, st
/// The environment.
/// Pairs of columns to run the tokenization on.
/// The separators to use (uses space character by default).
- public WordTokenizingEstimator(IHostEnvironment env, (string outputColumnName, string inputColumnName)[] columns, char[] separators = null)
- : this(env, columns.Select(x => new WordTokenizingTransformer.ColumnInfo(x.outputColumnName, x.inputColumnName, separators)).ToArray())
+ internal WordTokenizingEstimator(IHostEnvironment env, (string outputColumnName, string inputColumnName)[] columns, char[] separators = null)
+ : this(env, columns.Select(x => new ColumnInfo(x.outputColumnName, x.inputColumnName, separators)).ToArray())
{
}
@@ -460,11 +437,34 @@ public WordTokenizingEstimator(IHostEnvironment env, (string outputColumnName, s
///
/// The environment.
/// Pairs of columns to run the tokenization on.
- public WordTokenizingEstimator(IHostEnvironment env, params WordTokenizingTransformer.ColumnInfo[] columns)
+ internal WordTokenizingEstimator(IHostEnvironment env, params ColumnInfo[] columns)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(WordTokenizingEstimator)), new WordTokenizingTransformer(env, columns))
{
}
+ public sealed class ColumnInfo
+ {
+ public readonly string Name;
+ public readonly string InputColumnName;
+ public readonly char[] Separators;
+ ///
+ /// Describes how the transformer handles one column pair.
+ ///
+ /// Name of the column resulting from the transformation of .
+ /// Name of column to transform. If set to , the value of the will be used as source.
+ /// The separators to use. If not specified, space will be used as separator.
+ public ColumnInfo(string name, string inputColumnName = null, char[] separators = null)
+ {
+ Name = name;
+ InputColumnName = inputColumnName ?? name;
+ Separators = separators ?? new[] { ' ' };
+ }
+ }
+
+ ///
+ /// Returns the of the schema which will be produced by the transformer.
+ /// Used for schema propagation and verification in a pipeline.
+ ///
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
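[Reviewer note] The relocated ColumnInfo above defaults Separators to a single space. A short sketch of overriding that through the catalog; the separators parameter is assumed to mirror the estimator constructor shown above.

using Microsoft.ML;

class TokenizeSketch
{
    static void Main()
    {
        var mlContext = new MLContext();
        // Split on commas and semicolons instead of the default space.
        var tokenize = mlContext.Transforms.Text.TokenizeWords(
            "Tokens", "Review", separators: new[] { ',', ';' });
    }
}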
diff --git a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs
index 1bda6b77f7..e57a6c7b8b 100644
--- a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs
+++ b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs
@@ -35,7 +35,7 @@ public sealed class WordBagEstimator : TrainedWrapperEstimatorBase
/// Whether to include all ngram lengths up to or only .
/// Maximum number of ngrams to store in the dictionary.
/// Statistical measure used to evaluate how important a word is to a document in a corpus.
- public WordBagEstimator(IHostEnvironment env,
+ internal WordBagEstimator(IHostEnvironment env,
string outputColumnName,
string inputColumnName = null,
int ngramLength = 1,
@@ -59,7 +59,7 @@ public WordBagEstimator(IHostEnvironment env,
/// Whether to include all ngram lengths up to or only .
/// Maximum number of ngrams to store in the dictionary.
/// Statistical measure used to evaluate how important a word is to a document in a corpus.
- public WordBagEstimator(IHostEnvironment env,
+ internal WordBagEstimator(IHostEnvironment env,
string outputColumnName,
string[] inputColumnNames,
int ngramLength = 1,
@@ -82,7 +82,7 @@ public WordBagEstimator(IHostEnvironment env,
/// Whether to include all ngram lengths up to or only .
/// Maximum number of ngrams to store in the dictionary.
/// Statistical measure used to evaluate how important a word is to a document in a corpus.
- public WordBagEstimator(IHostEnvironment env,
+ internal WordBagEstimator(IHostEnvironment env,
(string outputColumnName, string[] inputColumnNames)[] columns,
int ngramLength = 1,
int skipLength = 0,
@@ -108,7 +108,7 @@ public WordBagEstimator(IHostEnvironment env,
public override TransformWrapper Fit(IDataView input)
{
// Create arguments.
- var args = new WordBagBuildingTransformer.Arguments
+ var options = new WordBagBuildingTransformer.Options
{
Columns = _columns.Select(x => new WordBagBuildingTransformer.Column { Name = x.outputColumnName, Source = x.sourceColumnsNames }).ToArray(),
NgramLength = _ngramLength,
@@ -118,7 +118,7 @@ public override TransformWrapper Fit(IDataView input)
Weighting = _weighting
};
- return new TransformWrapper(Host, WordBagBuildingTransformer.Create(Host, args, input), true);
+ return new TransformWrapper(Host, WordBagBuildingTransformer.Create(Host, options, input), true);
}
}
@@ -154,7 +154,7 @@ public sealed class WordHashBagEstimator : TrainedWrapperEstimatorBase
/// Text representations of the original values are stored in the slot names of the metadata for the new column. Hashing, as such, can map many initial values to one.
/// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained.
/// 0 does not retain any input values. -1 retains all input values mapping to each hash.
- public WordHashBagEstimator(IHostEnvironment env,
+ internal WordHashBagEstimator(IHostEnvironment env,
string outputColumnName,
string inputColumnName = null,
int hashBits = 16,
@@ -185,7 +185,7 @@ public WordHashBagEstimator(IHostEnvironment env,
/// Text representations of the original values are stored in the slot names of the metadata for the new column. Hashing, as such, can map many initial values to one.
/// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained.
/// 0 does not retain any input values. -1 retains all input values mapping to each hash.
- public WordHashBagEstimator(IHostEnvironment env,
+ internal WordHashBagEstimator(IHostEnvironment env,
string outputColumnName,
string[] inputColumnNames,
int hashBits = 16,
@@ -215,7 +215,7 @@ public WordHashBagEstimator(IHostEnvironment env,
/// Text representations of the original values are stored in the slot names of the metadata for the new column. Hashing, as such, can map many initial values to one.
/// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained.
/// 0 does not retain any input values. -1 retains all input values mapping to each hash.
- public WordHashBagEstimator(IHostEnvironment env,
+ internal WordHashBagEstimator(IHostEnvironment env,
(string outputColumnName, string[] inputColumnNames)[] columns,
int hashBits = 16,
int ngramLength = 1,
@@ -245,7 +245,7 @@ public WordHashBagEstimator(IHostEnvironment env,
public override TransformWrapper Fit(IDataView input)
{
// Create arguments.
- var args = new WordHashBagProducingTransformer.Arguments
+ var options = new WordHashBagProducingTransformer.Options
{
Columns = _columns.Select(x => new WordHashBagProducingTransformer.Column { Name = x.outputColumnName, Source = x.inputColumnNames }).ToArray(),
HashBits = _hashBits,
@@ -257,7 +257,7 @@ public override TransformWrapper Fit(IDataView input)
InvertHash = _invertHash
};
- return new TransformWrapper(Host, WordHashBagProducingTransformer.Create(Host, args, input), true);
+ return new TransformWrapper(Host, WordHashBagProducingTransformer.Create(Host, options, input), true);
}
}
}
\ No newline at end of file
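[Reviewer note] The Weighting parameter documented on the WordBagEstimator constructors above selects between term-frequency style weightings. For reviewers unfamiliar with it, the textbook TF-IDF formula those weightings derive from; ML.NET's exact smoothing is not visible in this diff.

using System;

class TfIdfSketch
{
    static void Main()
    {
        double tf = 3;        // occurrences of an ngram in one document
        double docs = 1000;   // documents in the corpus
        double docFreq = 10;  // documents containing the ngram
        Console.WriteLine(tf * Math.Log(docs / docFreq)); // ≈ 13.82
    }
}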
diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv
index 1c8470a58d..7472f4b35a 100644
--- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv
+++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv
@@ -75,7 +75,7 @@ Transforms.BinaryPredictionScoreColumnsRenamer For binary prediction, it renames
Transforms.BinNormalizer The values are assigned into equidensity bins and a value is mapped to its bin_number/number_of_bins. Microsoft.ML.Data.Normalize Bin Microsoft.ML.Transforms.Normalizers.NormalizeTransform+BinArguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.CategoricalHashOneHotVectorizer Converts the categorical value into an indicator array by hashing the value and using the hash as an index in the bag. If the input column is a vector, a single indicator bag is returned for it. Microsoft.ML.Transforms.Categorical.Categorical CatTransformHash Microsoft.ML.Transforms.Categorical.OneHotHashEncodingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.CategoricalOneHotVectorizer Converts the categorical value into an indicator array by building a dictionary of categories based on the data and using the id in the dictionary as the index in the array. Microsoft.ML.Transforms.Categorical.Categorical CatTransformDict Microsoft.ML.Transforms.Categorical.OneHotEncodingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
-Transforms.CharacterTokenizer Character-oriented tokenizer where text is considered a sequence of characters. Microsoft.ML.Transforms.Text.TextAnalytics CharTokenize Microsoft.ML.Transforms.Text.TokenizingByCharactersTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
+Transforms.CharacterTokenizer Character-oriented tokenizer where text is considered a sequence of characters. Microsoft.ML.Transforms.Text.TextAnalytics CharTokenize Microsoft.ML.Transforms.Text.TokenizingByCharactersTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.ColumnConcatenator Concatenates one or more columns of the same item type. Microsoft.ML.EntryPoints.SchemaManipulation ConcatColumns Microsoft.ML.Data.ColumnConcatenatingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.ColumnCopier Duplicates columns from the dataset Microsoft.ML.EntryPoints.SchemaManipulation CopyColumns Microsoft.ML.Transforms.ColumnCopyingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.ColumnSelector Selects a set of columns, dropping all others Microsoft.ML.EntryPoints.SchemaManipulation SelectColumns Microsoft.ML.Transforms.ColumnSelectingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
@@ -100,7 +100,7 @@ Transforms.KeyToTextConverter KeyToValueTransform utilizes KeyValues metadata to
Transforms.LabelColumnKeyBooleanConverter Transforms the label to either key or bool (if needed) to make it suitable for classification. Microsoft.ML.EntryPoints.FeatureCombiner PrepareClassificationLabel Microsoft.ML.EntryPoints.FeatureCombiner+ClassificationLabelInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.LabelIndicator Label remapper used by OVA Microsoft.ML.Transforms.LabelIndicatorTransform LabelIndicator Microsoft.ML.Transforms.LabelIndicatorTransform+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.LabelToFloatConverter Transforms the label to float to make it suitable for regression. Microsoft.ML.EntryPoints.FeatureCombiner PrepareRegressionLabel Microsoft.ML.EntryPoints.FeatureCombiner+RegressionLabelInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
-Transforms.LightLda The LDA transform implements LightLDA, a state-of-the-art implementation of Latent Dirichlet Allocation. Microsoft.ML.Transforms.Text.TextAnalytics LightLda Microsoft.ML.Transforms.Text.LatentDirichletAllocationTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
+Transforms.LightLda The LDA transform implements LightLDA, a state-of-the-art implementation of Latent Dirichlet Allocation. Microsoft.ML.Transforms.Text.TextAnalytics LightLda Microsoft.ML.Transforms.Text.LatentDirichletAllocationTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.LogMeanVarianceNormalizer Normalizes the data based on the computed mean and variance of the logarithm of the data. Microsoft.ML.Data.Normalize LogMeanVar Microsoft.ML.Transforms.Normalizers.NormalizeTransform+LogMeanVarArguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.LpNormalizer Normalize vectors (rows) individually by rescaling them to unit norm (L2, L1 or LInf). Performs the following operation on a vector X: Y = (X - M) / D, where M is mean and D is either L2 norm, L1 norm or LInf norm. Microsoft.ML.Transforms.Projections.LpNormalization Normalize Microsoft.ML.Transforms.Projections.LpNormalizingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.ManyHeterogeneousModelCombiner Combines a sequence of TransformModels and a PredictorModel into a single PredictorModel. Microsoft.ML.EntryPoints.ModelOperations CombineModels Microsoft.ML.EntryPoints.ModelOperations+PredictorModelInput Microsoft.ML.EntryPoints.ModelOperations+PredictorModelOutput
@@ -112,7 +112,7 @@ Transforms.MissingValuesDropper Removes NAs from vector columns. Microsoft.ML.Tr
Transforms.MissingValuesRowDropper Filters out rows that contain missing values. Microsoft.ML.Transforms.NAHandling Filter Microsoft.ML.Transforms.NAFilter+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.MissingValueSubstitutor Creates an output column of the same type and size as the input column, where missing values are replaced with either the default value or the mean/min/max value (for non-text columns only). Microsoft.ML.Transforms.NAHandling Replace Microsoft.ML.Transforms.MissingValueReplacingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.ModelCombiner Combines a sequence of TransformModels into a single model Microsoft.ML.EntryPoints.ModelOperations CombineTransformModels Microsoft.ML.EntryPoints.ModelOperations+CombineTransformModelsInput Microsoft.ML.EntryPoints.ModelOperations+CombineTransformModelsOutput
-Transforms.NGramTranslator Produces a bag of counts of ngrams (sequences of consecutive values of length 1-n) in a given vector of keys. It does so by building a dictionary of ngrams and using the id in the dictionary as the index in the bag. Microsoft.ML.Transforms.Text.TextAnalytics NGramTransform Microsoft.ML.Transforms.Text.NgramExtractingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
+Transforms.NGramTranslator Produces a bag of counts of ngrams (sequences of consecutive values of length 1-n) in a given vector of keys. It does so by building a dictionary of ngrams and using the id in the dictionary as the index in the bag. Microsoft.ML.Transforms.Text.TextAnalytics NGramTransform Microsoft.ML.Transforms.Text.NgramExtractingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.NoOperation Does nothing. Microsoft.ML.Data.NopTransform Nop Microsoft.ML.Data.NopTransform+NopInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.OptionalColumnCreator If the source column does not exist after deserialization, create a column with the right type and default values. Microsoft.ML.Transforms.OptionalColumnTransform MakeOptional Microsoft.ML.Transforms.OptionalColumnTransform+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.PcaCalculator PCA is a dimensionality-reduction transform which computes the projection of a numeric vector onto a low-rank subspace. Microsoft.ML.Transforms.Projections.PcaTransformer Calculate Microsoft.ML.Transforms.Projections.PcaTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
@@ -133,5 +133,5 @@ Transforms.TrainTestDatasetSplitter Split the dataset into train and test sets M
Transforms.TreeLeafFeaturizer Trains a tree ensemble, or loads it from a file, then maps a numeric feature vector to three outputs: 1. A vector containing the individual tree outputs of the tree ensemble. 2. A vector indicating the leaves that the feature vector falls on in the tree ensemble. 3. A vector indicating the paths that the feature vector falls on in the tree ensemble. If both a model file and a trainer are specified, the model file will be used. If neither is specified, a default FastTree model will be trained. This can handle key labels by training a regression model towards their optionally permuted indices. Microsoft.ML.Data.TreeFeaturize Featurizer Microsoft.ML.Data.TreeEnsembleFeaturizerTransform+ArgumentsForEntryPoint Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.TwoHeterogeneousModelCombiner Combines a TransformModel and a PredictorModel into a single PredictorModel. Microsoft.ML.EntryPoints.ModelOperations CombineTwoModels Microsoft.ML.EntryPoints.ModelOperations+SimplePredictorModelInput Microsoft.ML.EntryPoints.ModelOperations+PredictorModelOutput
Transforms.VectorToImage Converts vector array into image type. Microsoft.ML.ImageAnalytics.EntryPoints.ImageAnalytics VectorToImage Microsoft.ML.ImageAnalytics.VectorToImageTransform+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
-Transforms.WordEmbeddings Word Embeddings transform is a text featurizer which converts vectors of text tokens into sentence vectors using a pre-trained model Microsoft.ML.Transforms.Text.TextAnalytics WordEmbeddings Microsoft.ML.Transforms.Text.WordEmbeddingsExtractingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
-Transforms.WordTokenizer The input to this transform is text, and the output is a vector of text containing the words (tokens) in the original text. The separator is space, but can be specified as any other character (or multiple characters) if needed. Microsoft.ML.Transforms.Text.TextAnalytics DelimitedTokenizeTransform Microsoft.ML.Transforms.Text.WordTokenizingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
+Transforms.WordEmbeddings Word Embeddings transform is a text featurizer which converts vectors of text tokens into sentence vectors using a pre-trained model Microsoft.ML.Transforms.Text.TextAnalytics WordEmbeddings Microsoft.ML.Transforms.Text.WordEmbeddingsExtractingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
+Transforms.WordTokenizer The input to this transform is text, and the output is a vector of text containing the words (tokens) in the original text. The separator is space, but can be specified as any other character (or multiple characters) if needed. Microsoft.ML.Transforms.Text.TextAnalytics DelimitedTokenizeTransform Microsoft.ML.Transforms.Text.WordTokenizingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
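Each row above pairs an entry point with its C# input type; the only change is the nested-type rename from `+Arguments` to `+Options` (plus the pretrained-model enum moving to the estimator). As a minimal sketch of what a caller of the Transforms.WordEmbeddings entry point looks like after the rename (mirroring the test updated further down; `dataView` and the column names are placeholders):

    // Entry-point invocation: settings now travel in an Options object.
    var embedding = Transforms.Text.TextAnalytics.WordEmbeddings(Env, new WordEmbeddingsExtractingTransformer.Options()
    {
        Data = dataView,   // IDataView holding the text column to featurize
        Columns = new[] { new WordEmbeddingsExtractingTransformer.Column { Name = "Features", Source = "Text" } },
        // The pretrained-model enum now hangs off the estimator type.
        ModelKind = WordEmbeddingsExtractingEstimator.PretrainedModelKind.Sswe
    });
    var featurized = embedding.OutputData;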
diff --git a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs
index f8a27276a5..613dda6609 100644
--- a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs
+++ b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs
@@ -110,8 +110,8 @@ public void TrainSentiment()
args.UseCharExtractor = false;
args.UseWordExtractor = false;
}).Fit(loader).Transform(loader);
- var trans = _env.Transforms.Text.ExtractWordEmbeddings("Features", "WordEmbeddings_TransformedText",
- WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe).Fit(text).Transform(text);
+ var trans = _env.Transforms.Text.ExtractWordEmbeddings("Features", "WordEmbeddings_TransformedText",
+ WordEmbeddingsExtractingEstimator.PretrainedModelKind.Sswe).Fit(text).Transform(text);
// Train
var trainer = _env.MulticlassClassification.Trainers.StochasticDualCoordinateAscent();
var predicted = trainer.Fit(trans);
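The benchmark materializes each intermediate IDataView; the same work could be expressed as a single estimator chain, which keeps the Fit/Transform bookkeeping in one place. A sketch under the same `_env`, `text`, and column names as above:

    // Chain the embedding featurizer and the trainer into one pipeline.
    var pipeline = _env.Transforms.Text.ExtractWordEmbeddings("Features", "WordEmbeddings_TransformedText",
            WordEmbeddingsExtractingEstimator.PretrainedModelKind.Sswe)
        .Append(_env.MulticlassClassification.Trainers.StochasticDualCoordinateAscent());
    // A single Fit yields a transformer chain covering both steps.
    var model = pipeline.Fit(text);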
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
index b9c43e2653..bd6c65b3d3 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
@@ -1007,7 +1007,7 @@ public void EntryPointPipelineEnsembleText()
else
{
data = WordHashBagProducingTransformer.Create(Env,
- new WordHashBagProducingTransformer.Arguments()
+ new WordHashBagProducingTransformer.Options()
{
Columns =
new[] { new WordHashBagProducingTransformer.Column() { Name = "Features", Source = new[] { "Text" } }, }
@@ -2491,7 +2491,7 @@ public void TestInputBuilderComponentFactories()
Assert.True(success);
var inputBuilder = new InputBuilder(Env, info.InputType, catalog);
- var args = new SdcaBinaryTrainer.Options()
+ var options = new SdcaBinaryTrainer.Options()
{
NormalizeFeatures = NormalizeOption.Yes,
CheckFrequency = 42
@@ -2504,7 +2504,7 @@ public void TestInputBuilderComponentFactories()
inputBindingMap.Add("TrainingData", new List() { parameterBinding });
inputMap.Add(parameterBinding, new SimpleVariableBinding("data"));
- var result = inputBuilder.GetJsonObject(args, inputBindingMap, inputMap);
+ var result = inputBuilder.GetJsonObject(options, inputBindingMap, inputMap);
var json = FixWhitespace(result.ToString(Formatting.Indented));
var expected =
@@ -2516,8 +2516,8 @@ public void TestInputBuilderComponentFactories()
expected = FixWhitespace(expected);
Assert.Equal(expected, json);
- args.LossFunction = new HingeLoss.Arguments();
- result = inputBuilder.GetJsonObject(args, inputBindingMap, inputMap);
+ options.LossFunction = new HingeLoss.Arguments();
+ result = inputBuilder.GetJsonObject(options, inputBindingMap, inputMap);
json = FixWhitespace(result.ToString(Formatting.Indented));
expected =
@@ -2532,8 +2532,8 @@ public void TestInputBuilderComponentFactories()
expected = FixWhitespace(expected);
Assert.Equal(expected, json);
- args.LossFunction = new HingeLoss.Arguments() { Margin = 2 };
- result = inputBuilder.GetJsonObject(args, inputBindingMap, inputMap);
+ options.LossFunction = new HingeLoss.Arguments() { Margin = 2 };
+ result = inputBuilder.GetJsonObject(options, inputBindingMap, inputMap);
json = FixWhitespace(result.ToString(Formatting.Indented));
expected =
@@ -3629,11 +3629,11 @@ public void EntryPointWordEmbeddings()
},
InputFile = inputFile,
}).Data;
- var embedding = Transforms.Text.TextAnalytics.WordEmbeddings(Env, new WordEmbeddingsExtractingTransformer.Arguments()
+ var embedding = Transforms.Text.TextAnalytics.WordEmbeddings(Env, new WordEmbeddingsExtractingTransformer.Options()
{
Data = dataView,
Columns = new[] { new WordEmbeddingsExtractingTransformer.Column { Name = "Features", Source = "Text" } },
- ModelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe
+ ModelKind = WordEmbeddingsExtractingEstimator.PretrainedModelKind.Sswe
});
var result = embedding.OutputData;
using (var cursor = result.GetRowCursorForAllColumns())
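For reference, the options object whose JSON serialization the test checks is built up like this (a condensed sketch of only the fields the test touches; note that component factories such as HingeLoss still use an `Arguments` nested type at this stage of the rename):

    // Trainer settings now live on an Options nested type.
    var options = new SdcaBinaryTrainer.Options()
    {
        NormalizeFeatures = NormalizeOption.Yes,
        CheckFrequency = 42
    };
    // Loss functions are component factories and keep their Arguments naming for now.
    options.LossFunction = new HingeLoss.Arguments() { Margin = 2 };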
diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs
index d089153764..7541959768 100644
--- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs
+++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs
@@ -1324,7 +1324,7 @@ public void TestLDATransform()
builder.AddColumn("F1V", NumberType.Float, data);
var srcView = builder.GetDataView();
- var est = new LatentDirichletAllocationEstimator(Env, "F1V", numTopic: 3, numSummaryTermPerTopic: 3, alphaSum: 3, numThreads: 1, resetRandomGenerator: true);
+ var est = ML.Transforms.Text.LatentDirichletAllocation("F1V", numTopic: 3, numSummaryTermPerTopic: 3, alphaSum: 3, numThreads: 1, resetRandomGenerator: true);
var ldaTransformer = est.Fit(srcView);
var transformedData = ldaTransformer.Transform(srcView);
@@ -1362,6 +1362,7 @@ public void TestLDATransform()
public void TestLdaTransformerEmptyDocumentException()
{
var builder = new ArrayDataViewBuilder(Env);
+ string colName = "Zeros";
var data = new[]
{
new[] { (Float)0.0, (Float)0.0, (Float)0.0 },
@@ -1369,25 +1370,16 @@ public void TestLdaTransformerEmptyDocumentException()
new[] { (Float)0.0, (Float)0.0, (Float)0.0 },
};
- builder.AddColumn("Zeros", NumberType.Float, data);
+ builder.AddColumn(colName, NumberType.Float, data);
var srcView = builder.GetDataView();
- var col = new LatentDirichletAllocationTransformer.Column()
- {
- Source = "Zeros",
- };
- var args = new LatentDirichletAllocationTransformer.Arguments()
- {
- Columns = new[] { col }
- };
-
try
{
- var lda = new LatentDirichletAllocationEstimator(Env, "Zeros").Fit(srcView).Transform(srcView);
+ var lda = ML.Transforms.Text.LatentDirichletAllocation("Zeros").Fit(srcView).Transform(srcView);
}
catch (InvalidOperationException ex)
{
- Assert.Equal(ex.Message, string.Format("The specified documents are all empty in column '{0}'.", col.Source));
+ Assert.Equal(ex.Message, string.Format("The specified documents are all empty in column '{0}'.", colName));
return;
}
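Both LDA tests now go through the catalog extension rather than constructing the estimator (or its column/argument objects) directly; condensed, the pattern is (assuming `srcView` exposes the numeric vector column being modeled):

    // Catalog extension replaces direct LatentDirichletAllocationEstimator construction.
    var lda = ML.Transforms.Text.LatentDirichletAllocation("F1V", numTopic: 3, numThreads: 1, resetRandomGenerator: true);
    // Fitting on all-zero (all-empty) documents throws InvalidOperationException:
    // "The specified documents are all empty in column 'F1V'."
    var transformed = lda.Fit(srcView).Transform(srcView);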
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs
index 0960e32da4..01501677be 100644
--- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs
@@ -473,7 +473,7 @@ private void TextFeaturizationOn(string dataPath)
BagOfTrichar: r.Message.TokenizeIntoCharacters().ToNgrams(ngramLength: 3, weighting: NgramExtractingEstimator.WeightingCriteria.TfIdf),
// NLP pipeline 4: word embeddings.
- Embeddings: r.Message.NormalizeText().TokenizeText().WordEmbeddings(WordEmbeddingsExtractingTransformer.PretrainedModelKind.GloVeTwitter25D)
+ Embeddings: r.Message.NormalizeText().TokenizeText().WordEmbeddings(WordEmbeddingsExtractingEstimator.PretrainedModelKind.GloVeTwitter25D)
));
// Let's train our pipeline, and then apply it to the same data.
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs
index 2cc1803347..39957cc031 100644
--- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs
@@ -316,7 +316,7 @@ private void TextFeaturizationOn(string dataPath)
// NLP pipeline 4: word embeddings.
.Append(mlContext.Transforms.Text.TokenizeWords("TokenizedMessage", "NormalizedMessage"))
.Append(mlContext.Transforms.Text.ExtractWordEmbeddings("Embeddings", "TokenizedMessage",
- WordEmbeddingsExtractingTransformer.PretrainedModelKind.GloVeTwitter25D));
+ WordEmbeddingsExtractingEstimator.PretrainedModelKind.GloVeTwitter25D));
// Let's train our pipeline, and then apply it to the same data.
// Note that even on a small dataset of 70KB the pipeline above can take up to a minute to completely train.
diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs
index e4b4410844..7c8ef48d22 100644
--- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs
@@ -183,7 +183,7 @@ public void StopWordsRemoverFromFactory()
var tokenized = new WordTokenizingTransformer(ML, new[]
{
- new WordTokenizingTransformer.ColumnInfo("Text", "Text")
+ new WordTokenizingEstimator.ColumnInfo("Text", "Text")
}).Transform(data);
var xf = factory.CreateComponent(ML, tokenized,
diff --git a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs
index 2744fcfd48..38a56d2fa6 100644
--- a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs
@@ -88,7 +88,7 @@ public void ValueMapInputIsVectorTest()
var values = new List<int>() { 1, 2, 3, 4 };
var estimator = new WordTokenizingEstimator(Env, new[]{
- new WordTokenizingTransformer.ColumnInfo("TokenizeA", "A")
+ new WordTokenizingEstimator.ColumnInfo("TokenizeA", "A")
}).Append(new ValueMappingEstimator<ReadOnlyMemory<char>, int>(Env, keys, values, new[] { ("VecD", "TokenizeA"), ("E", "B"), ("F", "C") }));
var schema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema));
Assert.True(schema.TryFindColumn("VecD", out var originalColumn));
@@ -124,7 +124,7 @@ public void ValueMapInputIsVectorAndValueAsStringKeyTypeTest()
var values = new List>() { "a".AsMemory(), "b".AsMemory(), "c".AsMemory(), "d".AsMemory() };
var estimator = new WordTokenizingEstimator(Env, new[]{
- new WordTokenizingTransformer.ColumnInfo("TokenizeA", "A")
+ new WordTokenizingEstimator.ColumnInfo("TokenizeA", "A")
}).Append(new ValueMappingEstimator<ReadOnlyMemory<char>, ReadOnlyMemory<char>>(Env, keys, values, true, new[] { ("VecD", "TokenizeA"), ("E", "B"), ("F", "C") }));
var t = estimator.Fit(dataView);
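The generic arguments carry the key and value item types: `ValueMappingEstimator<TKey, TValue>` looks up keys of type TKey and emits values of type TValue, and when the source column is a tokenized vector the lookup applies per token. A minimal sketch in the same shape as the tests above (`"MappedA"` and the token values are hypothetical):

    // Keys are the tokens to look up; values are what each key maps to.
    var keys = new List<ReadOnlyMemory<char>>() { "foo".AsMemory(), "bar".AsMemory() };
    var values = new List<int>() { 1, 2 };

    // Tokenize column "A", then map every token to its integer value.
    var pipeline = new WordTokenizingEstimator(Env, new[] {
            new WordTokenizingEstimator.ColumnInfo("TokenizeA", "A")
        })
        .Append(new ValueMappingEstimator<ReadOnlyMemory<char>, int>(Env, keys, values,
            new[] { ("MappedA", "TokenizeA") }));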
diff --git a/test/Microsoft.ML.Tests/Transformers/WordEmbeddingsTests.cs b/test/Microsoft.ML.Tests/Transformers/WordEmbeddingsTests.cs
index 37dfe4cd9b..539265ca41 100644
--- a/test/Microsoft.ML.Tests/Transformers/WordEmbeddingsTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/WordEmbeddingsTests.cs
@@ -42,7 +42,7 @@ public void TestWordEmbeddings()
.Append(ML.Transforms.Text.RemoveDefaultStopWords("CleanWords", "Words"));
var words = est.Fit(data).Transform(data);
- var pipe = ML.Transforms.Text.ExtractWordEmbeddings("WordEmbeddings", "CleanWords", modelKind: WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe);
+ var pipe = ML.Transforms.Text.ExtractWordEmbeddings("WordEmbeddings", "CleanWords", modelKind: WordEmbeddingsExtractingEstimator.PretrainedModelKind.Sswe);
TestEstimatorCore(pipe, words, invalidInput: data);
diff --git a/test/Microsoft.ML.Tests/Transformers/WordTokenizeTests.cs b/test/Microsoft.ML.Tests/Transformers/WordTokenizeTests.cs
index 5f6723526d..397dbf50ae 100644
--- a/test/Microsoft.ML.Tests/Transformers/WordTokenizeTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/WordTokenizeTests.cs
@@ -57,8 +57,8 @@ public void WordTokenizeWorkout()
var invalidData = new[] { new TestWrong() { A =1, B = new float[2] { 2,3 } } };
var invalidDataView = ML.Data.ReadFromEnumerable(invalidData);
var pipe = new WordTokenizingEstimator(Env, new[]{
- new WordTokenizingTransformer.ColumnInfo("TokenizeA", "A"),
- new WordTokenizingTransformer.ColumnInfo("TokenizeB", "B"),
+ new WordTokenizingEstimator.ColumnInfo("TokenizeA", "A"),
+ new WordTokenizingEstimator.ColumnInfo("TokenizeB", "B"),
});
TestEstimatorCore(pipe, dataView, invalidInput: invalidDataView);
@@ -99,8 +99,8 @@ public void TestOldSavingAndLoading()
var dataView = ML.Data.ReadFromEnumerable(data);
var pipe = new WordTokenizingEstimator(Env, new[]{
- new WordTokenizingTransformer.ColumnInfo("TokenizeA", "A"),
- new WordTokenizingTransformer.ColumnInfo("TokenizeB", "B"),
+ new WordTokenizingEstimator.ColumnInfo("TokenizeA", "A"),
+ new WordTokenizingEstimator.ColumnInfo("TokenizeB", "B"),
});
var result = pipe.Fit(dataView).Transform(dataView);
var resultRoles = new RoleMappedData(result);
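The recurring shape across these test diffs: per-column configuration (`ColumnInfo`) and per-transform settings (`Options`) now hang off the estimator, and the transformer class only appears as the result of `Fit`. A sketch using the same columns as the test above:

    // Configure tokenization on the estimator...
    var pipe = new WordTokenizingEstimator(Env, new[] {
        new WordTokenizingEstimator.ColumnInfo("TokenizeA", "A"),
        new WordTokenizingEstimator.ColumnInfo("TokenizeB", "B"),
    });
    // ...and obtain the transformer only by fitting.
    WordTokenizingTransformer tokenizer = pipe.Fit(dataView);
    var transformed = tokenizer.Transform(dataView);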