diff --git a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs index 543b997e6d..3d29bce613 100644 --- a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs +++ b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs @@ -137,5 +137,25 @@ public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LdaTr OutputData = view }; } + + [TlcModule.EntryPoint(Name = "Transforms.WordEmbeddings", + Desc = WordEmbeddingsTransform.Summary, + UserName = WordEmbeddingsTransform.UserName, + ShortName = WordEmbeddingsTransform.ShortName, + XmlInclude = new[] { @"", + @"" })] + public static CommonOutputs.TransformOutput WordEmbeddings(IHostEnvironment env, WordEmbeddingsTransform.Arguments input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(input, nameof(input)); + + var h = EntryPointUtils.CheckArgsAndCreateHost(env, "WordEmbeddings", input); + var view = new WordEmbeddingsTransform(h, input, input.Data); + return new CommonOutputs.TransformOutput() + { + Model = new TransformModel(h, view, input.Data), + OutputData = view + }; + } } } diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsTransform.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsTransform.cs new file mode 100644 index 0000000000..bf85ddf42f --- /dev/null +++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsTransform.cs @@ -0,0 +1,444 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Runtime.Model; + +[assembly: LoadableClass(WordEmbeddingsTransform.Summary, typeof(IDataTransform), typeof(WordEmbeddingsTransform), typeof(WordEmbeddingsTransform.Arguments), + typeof(SignatureDataTransform), WordEmbeddingsTransform.UserName, "WordEmbeddingsTransform", WordEmbeddingsTransform.ShortName, DocName = "transform/WordEmbeddingsTransform.md")] + +[assembly: LoadableClass(typeof(WordEmbeddingsTransform), null, typeof(SignatureLoadDataTransform), + WordEmbeddingsTransform.UserName, WordEmbeddingsTransform.LoaderSignature)] + +namespace Microsoft.ML.Runtime.Data +{ + /// + public sealed class WordEmbeddingsTransform : OneToOneTransformBase + { + public sealed class Column : OneToOneColumn + { + public static Column Parse(string str) + { + Contracts.AssertNonEmpty(str); + + var res = new Column(); + if (res.TryParse(str)) + return res; + return null; + } + + public bool TryUnparse(StringBuilder sb) + { + Contracts.AssertValue(sb); + return TryUnparseCore(sb); + } + } + + public sealed class Arguments : TransformInputBase + { + [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 0)] + public Column[] Column; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Pre-trained model used to create the vocabulary", ShortName = "model", SortOrder = 1)] + public PretrainedModelKind? ModelKind = PretrainedModelKind.Sswe; + + [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "Filename for custom word embedding model", + ShortName = "dataFile", SortOrder = 2)] + public string CustomLookupTable; + } + + internal const string Summary = "Word Embeddings transform is a text featurizer which converts vectors of text tokens into sentence " + + "vectors using a pre-trained model"; + internal const string UserName = "Word Embeddings Transform"; + internal const string ShortName = "WordEmbeddings"; + public const string LoaderSignature = "WordEmbeddingsTransform"; + + public static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "W2VTRANS", + verWrittenCur: 0x00010001, //Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature); + } + + private readonly PretrainedModelKind? _modelKind; + private readonly string _modelFileNameWithPath; + private readonly Model _currentVocab; + private static object _embeddingsLock = new object(); + private readonly VectorType _outputType; + private readonly bool _customLookup; + private readonly int _linesToSkip; + private static Dictionary> _vocab = new Dictionary>(); + + private sealed class Model + { + private readonly BigArray _wordVectors; + private readonly NormStr.Pool _pool; + public readonly int Dimension; + + public Model(int dimension) + { + Dimension = dimension; + _wordVectors = new BigArray(); + _pool = new NormStr.Pool(); + } + + public void AddWordVector(IChannel ch, string word, float[] wordVector) + { + ch.Assert(wordVector.Length == Dimension); + if (_pool.Get(word) == null) + { + _pool.Add(word); + _wordVectors.AddRange(wordVector, Dimension); + } + } + + public bool GetWordVector(ref DvText word, float[] wordVector) + { + if (word.IsNA) + return false; + string rawWord = word.GetRawUnderlyingBufferInfo(out int ichMin, out int ichLim); + NormStr str = _pool.Get(rawWord, ichMin, ichLim); + if (str != null) + { + _wordVectors.CopyTo(str.Id * Dimension, wordVector, Dimension); + return true; + } + return false; + } + } + + private const string RegistrationName = "WordEmbeddings"; + + private const int Timeout = 10 * 60 * 1000; + + /// + /// Public constructor corresponding to . + /// + public WordEmbeddingsTransform(IHostEnvironment env, Arguments args, IDataView input) + : base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column, + input, TestIsTextVector) + { + if (args.ModelKind == null) + args.ModelKind = PretrainedModelKind.Sswe; + Host.CheckUserArg(!args.ModelKind.HasValue || Enum.IsDefined(typeof(PretrainedModelKind), args.ModelKind), nameof(args.ModelKind)); + Host.AssertNonEmpty(Infos); + Host.Assert(Infos.Length == Utils.Size(args.Column)); + + _customLookup = !string.IsNullOrWhiteSpace(args.CustomLookupTable); + if (_customLookup) + { + _modelKind = null; + _modelFileNameWithPath = args.CustomLookupTable; + } + else + { + _modelKind = args.ModelKind; + _modelFileNameWithPath = EnsureModelFile(env, out _linesToSkip, (PretrainedModelKind)_modelKind); + } + + Host.CheckNonWhiteSpace(_modelFileNameWithPath, nameof(_modelFileNameWithPath)); + _currentVocab = GetVocabularyDictionary(); + _outputType = new VectorType(NumberType.R4, 3 * _currentVocab.Dimension); + Metadata.Seal(); + } + + private WordEmbeddingsTransform(IHost host, ModelLoadContext ctx, IDataView input) + : base(host, ctx, input, TestIsTextVector) + { + Host.AssertValue(ctx); + Host.AssertNonEmpty(Infos); + _customLookup = ctx.Reader.ReadBoolByte(); + + if (_customLookup) + { + _modelFileNameWithPath = ctx.LoadNonEmptyString(); + _modelKind = null; + } + else + { + _modelKind = (PretrainedModelKind)ctx.Reader.ReadUInt32(); + _modelFileNameWithPath = EnsureModelFile(Host, out _linesToSkip, (PretrainedModelKind)_modelKind); + } + + Host.CheckNonWhiteSpace(_modelFileNameWithPath, nameof(_modelFileNameWithPath)); + _currentVocab = GetVocabularyDictionary(); + _outputType = new VectorType(NumberType.R4, 3 * _currentVocab.Dimension); + Metadata.Seal(); + } + + public static WordEmbeddingsTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + IHost h = env.Register(RegistrationName); + h.CheckValue(ctx, nameof(ctx)); + h.CheckValue(input, nameof(input)); + return h.Apply("Loading Model", + ch => new WordEmbeddingsTransform(h, ctx, input)); + } + + public override void Save(ModelSaveContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + + SaveBase(ctx); + ctx.Writer.WriteBoolByte(_customLookup); + if (_customLookup) + ctx.SaveString(_modelFileNameWithPath); + else + ctx.Writer.Write((uint)_modelKind); + } + + protected override ColumnType GetColumnTypeCore(int iinfo) + { + Host.Assert(0 <= iinfo && iinfo < Infos.Length); + return _outputType; + } + + protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + { + Host.AssertValue(ch); + ch.AssertValue(input); + ch.Assert(0 <= iinfo && iinfo < Infos.Length); + disposer = null; + + var info = Infos[iinfo]; + if (!info.TypeSrc.IsVector) + { + throw Host.ExceptParam(nameof(input), + "Text input given, expects a text vector"); + } + return GetGetterVec(ch, input, iinfo); + } + + private ValueGetter> GetGetterVec(IChannel ch, IRow input, int iinfo) + { + Host.AssertValue(ch); + ch.AssertValue(input); + ch.Assert(0 <= iinfo && iinfo < Infos.Length); + + var info = Infos[iinfo]; + ch.Assert(info.TypeSrc.IsVector); + ch.Assert(info.TypeSrc.ItemType.IsText); + + var srcGetter = input.GetGetter>(info.Source); + var src = default(VBuffer); + int dimension = _currentVocab.Dimension; + float[] wordVector = new float[_currentVocab.Dimension]; + + return + (ref VBuffer dst) => + { + int deno = 0; + srcGetter(ref src); + var values = dst.Values; + if (Utils.Size(values) != 3 * dimension) + values = new float[3 * dimension]; + int offset = 2 * dimension; + for (int i = 0; i < dimension; i++) + { + values[i] = float.MaxValue; + values[i + dimension] = 0; + values[i + offset] = float.MinValue; + } + for (int word = 0; word < src.Count; word++) + { + if (_currentVocab.GetWordVector(ref src.Values[word], wordVector)) + { + deno++; + for (int i = 0; i < dimension; i++) + { + float currentTerm = wordVector[i]; + if (values[i] > currentTerm) + values[i] = currentTerm; + values[dimension + i] += currentTerm; + if (values[offset + i] < currentTerm) + values[offset + i] = currentTerm; + } + } + } + + if (deno != 0) + for (int index = 0; index < dimension; index++) + values[index + dimension] /= deno; + + dst = new VBuffer(values.Length, values, dst.Indices); + }; + } + + public enum PretrainedModelKind + { + [TGUI(Label = "GloVe 50D")] + GloVe50D = 0, + + [TGUI(Label = "GloVe 100D")] + GloVe100D = 1, + + [TGUI(Label = "GloVe 200D")] + GloVe200D = 2, + + [TGUI(Label = "GloVe 300D")] + GloVe300D = 3, + + [TGUI(Label = "GloVe Twitter 25D")] + GloVeTwitter25D = 4, + + [TGUI(Label = "GloVe Twitter 50D")] + GloVeTwitter50D = 5, + + [TGUI(Label = "GloVe Twitter 100D")] + GloVeTwitter100D = 6, + + [TGUI(Label = "GloVe Twitter 200D")] + GloVeTwitter200D = 7, + + [TGUI(Label = "fastText Wikipedia 300D")] + FastTextWikipedia300D = 8, + + [TGUI(Label = "Sentiment-Specific Word Embedding")] + Sswe = 9 + } + + private static Dictionary _modelsMetaData = new Dictionary() + { + { PretrainedModelKind.GloVe50D, "glove.6B.50d.txt" }, + { PretrainedModelKind.GloVe100D, "glove.6B.100d.txt" }, + { PretrainedModelKind.GloVe200D, "glove.6B.200d.txt" }, + { PretrainedModelKind.GloVe300D, "glove.6B.300d.txt" }, + { PretrainedModelKind.GloVeTwitter25D, "glove.twitter.27B.25d.txt" }, + { PretrainedModelKind.GloVeTwitter50D, "glove.twitter.27B.50d.txt" }, + { PretrainedModelKind.GloVeTwitter100D, "glove.twitter.27B.100d.txt" }, + { PretrainedModelKind.GloVeTwitter200D, "glove.twitter.27B.200d.txt" }, + { PretrainedModelKind.FastTextWikipedia300D, "wiki.en.vec" }, + { PretrainedModelKind.Sswe, "sentiment.emd" } + }; + + private static Dictionary _linesToSkipInModels = new Dictionary() + { { PretrainedModelKind.FastTextWikipedia300D, 1 } }; + + private string EnsureModelFile(IHostEnvironment env, out int linesToSkip, PretrainedModelKind kind) + { + linesToSkip = 0; + if (_modelsMetaData.ContainsKey(kind)) + { + var modelFileName = _modelsMetaData[kind]; + if (_linesToSkipInModels.ContainsKey(kind)) + linesToSkip = _linesToSkipInModels[kind]; + using (var ch = Host.Start("Ensuring resources")) + { + string dir = kind == PretrainedModelKind.Sswe ? Path.Combine("Text", "Sswe") : "WordVectors"; + var url = $"{dir}/{modelFileName}"; + var ensureModel = ResourceManagerUtils.Instance.EnsureResource(Host, ch, url, modelFileName, dir, Timeout); + ensureModel.Wait(); + var errorResult = ResourceManagerUtils.GetErrorMessage(out var errorMessage, ensureModel.Result); + if (errorResult != null) + { + var directory = Path.GetDirectoryName(errorResult.FileName); + var name = Path.GetFileName(errorResult.FileName); + throw ch.Except($"{errorMessage}\nModel file for Word Embedding transform could not be found! " + + $@"Please copy the model file '{name}' from '{url}' to '{directory}'."); + } + return ensureModel.Result.FileName; + } + } + throw Host.Except($"Can't map model kind = {kind} to specific file, please refer to https://aka.ms/MLNetIssue for assistance"); + } + + private Model GetVocabularyDictionary() + { + int dimension = 0; + if (!File.Exists(_modelFileNameWithPath)) + throw Host.Except("Custom word embedding model file '{0}' could not be found for Word Embeddings transform.", _modelFileNameWithPath); + + if (_vocab.ContainsKey(_modelFileNameWithPath) && _vocab[_modelFileNameWithPath] != null) + { + if (_vocab[_modelFileNameWithPath].TryGetTarget(out Model model)) + { + dimension = model.Dimension; + return model; + } + } + + lock (_embeddingsLock) + { + if (_vocab.ContainsKey(_modelFileNameWithPath) && _vocab[_modelFileNameWithPath] != null) + { + if (_vocab[_modelFileNameWithPath].TryGetTarget(out Model modelObject)) + { + dimension = modelObject.Dimension; + return modelObject; + } + } + + Model model = null; + using (StreamReader sr = File.OpenText(_modelFileNameWithPath)) + { + string line; + int lineNumber = 1; + char[] delimiters = { ' ', '\t' }; + using (var ch = Host.Start(LoaderSignature)) + using (var pch = Host.StartProgressChannel("Building Vocabulary from Model File for Word Embeddings Transform")) + { + var header = new ProgressHeader(new[] { "lines" }); + pch.SetHeader(header, e => e.SetProgress(0, lineNumber)); + string firstLine = sr.ReadLine(); + while ((line = sr.ReadLine()) != null) + { + if (lineNumber >= _linesToSkip) + { + string[] words = line.TrimEnd().Split(delimiters); + dimension = words.Length - 1; + if (model == null) + model = new Model(dimension); + if (model.Dimension != dimension) + ch.Warning($"Dimension mismatch while reading model file: '{_modelFileNameWithPath}', line number {lineNumber + 1}, expected dimension = {model.Dimension}, received dimension = {dimension}"); + else + { + float tmp; + string key = words[0]; + float[] value = words.Skip(1).Select(x => float.TryParse(x, out tmp) ? tmp : Single.NaN).ToArray(); + if (!value.Contains(Single.NaN)) + model.AddWordVector(ch, key, value); + else + ch.Warning($"Parsing error while reading model file: '{_modelFileNameWithPath}', line number {lineNumber + 1}"); + } + } + lineNumber++; + } + + // Handle first line of the embedding file separately since some embedding files including fastText have a single-line header + string[] wordsInFirstLine = firstLine.TrimEnd().Split(delimiters); + dimension = wordsInFirstLine.Length - 1; + if (model == null) + model = new Model(dimension); + float temp; + string firstKey = wordsInFirstLine[0]; + float[] firstValue = wordsInFirstLine.Skip(1).Select(x => float.TryParse(x, out temp) ? temp : Single.NaN).ToArray(); + if (!firstValue.Contains(Single.NaN)) + model.AddWordVector(ch, firstKey, firstValue); + else + ch.Warning($"Parsing error while reading model file: '{_modelFileNameWithPath}', line number 1"); + pch.Checkpoint(lineNumber); + } + } + _vocab[_modelFileNameWithPath] = new WeakReference(model, false); + return model; + } + } + } +} diff --git a/src/Microsoft.ML.Transforms/Text/doc.xml b/src/Microsoft.ML.Transforms/Text/doc.xml index 5f734e1cfd..2d077dc3ed 100644 --- a/src/Microsoft.ML.Transforms/Text/doc.xml +++ b/src/Microsoft.ML.Transforms/Text/doc.xml @@ -179,7 +179,49 @@ - pipeline.Add(new LightLda(("InTextCol" , "OutTextCol"))); + pipeline.Add(new LightLda(("InTextCol" , "OutTextCol"))); + + + + + + + Word Embeddings transform is a text featurizer which converts vectors of text tokens into sentence vectors using a pre-trained model. + + + WordEmbeddings wrap different embedding models, such as GloVe. Users can specify which embedding to use. + The available options are various versions of GloVe Models, fastText, and SSWE. + + Note: As WordEmbedding requires a column with text vector, e.g. %3C%27this%27, %27is%27, %27good%27%3E, users need to create an input column by + using the output_tokens=True for TextTransform to convert a column with sentences like "This is good" into %3C%27this%27, %27is%27, %27good%27 %3E. + The suffix of %27_TransformedText%27 is added to the original column name to create the output token column. For instance if the input column is %27body%27, + the output tokens column is named %27body_TransformedText%27. + + + License attributes for pretrained models: + + + + "fastText Wikipedia 300D" by Facebook, Inc. is licensed under CC-BY-SA 3.0 based on: + P. Bojanowski*, E. Grave*, A. Joulin, T. Mikolov,Enriching Word Vectors with Subword Information + %40article%7Bbojanowski2016enriching%2C%0A%20%20title%3D%7BEnriching%20Word%20Vectors%20with%20Subword%20Information%7D%2C%0A%20%20author%3D%7BBojanowski%2C%20Piotr%20and%20Grave%2C%20Edouard%20and%20Joulin%2C%20Armand%20and%20Mikolov%2C%20Tomas%7D%2C%0A%20%20journal%3D%7BarXiv%20preprint%20arXiv%3A1607.04606%7D%2C%0A%20%20year%3D%7B2016%7D%0A%7D + More information can be found here. + + + + + GloVe models by Stanford University, or (Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. GloVe: Global Vectors for Word Representation) is licensed under PDDL. + More information can be found here. Repository can be found here. + + + + + + + + + + pipeline.Add(new WordEmbeddings(("InVectorTextCol" , "OutTextCol"))); diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index b2181fb256..0940a11775 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -1522,6 +1522,18 @@ public void Add(Microsoft.ML.Transforms.TwoHeterogeneousModelCombiner input, Mic _jsonNodes.Add(Serialize("Transforms.TwoHeterogeneousModelCombiner", input, output)); } + public Microsoft.ML.Transforms.WordEmbeddings.Output Add(Microsoft.ML.Transforms.WordEmbeddings input) + { + var output = new Microsoft.ML.Transforms.WordEmbeddings.Output(); + Add(input, output); + return output; + } + + public void Add(Microsoft.ML.Transforms.WordEmbeddings input, Microsoft.ML.Transforms.WordEmbeddings.Output output) + { + _jsonNodes.Add(Serialize("Transforms.WordEmbeddings", input, output)); + } + public Microsoft.ML.Transforms.WordTokenizer.Output Add(Microsoft.ML.Transforms.WordTokenizer input) { var output = new Microsoft.ML.Transforms.WordTokenizer.Output(); @@ -15431,6 +15443,148 @@ public sealed class Output } } + namespace Transforms + { + public enum WordEmbeddingsTransformPretrainedModelKind + { + GloVe50D = 0, + GloVe100D = 1, + GloVe200D = 2, + GloVe300D = 3, + GloVeTwitter25D = 4, + GloVeTwitter50D = 5, + GloVeTwitter100D = 6, + GloVeTwitter200D = 7, + FastTextWikipedia300D = 8, + Sswe = 9 + } + + + public sealed partial class WordEmbeddingsTransformColumn : OneToOneColumn, IOneToOneColumn + { + /// + /// Name of the new column + /// + public string Name { get; set; } + + /// + /// Name of the source column + /// + public string Source { get; set; } + + } + + /// + /// + public sealed partial class WordEmbeddings : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem + { + + public WordEmbeddings() + { + } + + public WordEmbeddings(params string[] inputColumns) + { + if (inputColumns != null) + { + foreach (string input in inputColumns) + { + AddColumn(input); + } + } + } + + public WordEmbeddings(params (string inputColumn, string outputColumn)[] inputOutputColumns) + { + if (inputOutputColumns != null) + { + foreach (var inputOutput in inputOutputColumns) + { + AddColumn(inputOutput.outputColumn, inputOutput.inputColumn); + } + } + } + + public void AddColumn(string inputColumn) + { + var list = Column == null ? new List() : new List(Column); + list.Add(OneToOneColumn.Create(inputColumn)); + Column = list.ToArray(); + } + + public void AddColumn(string outputColumn, string inputColumn) + { + var list = Column == null ? new List() : new List(Column); + list.Add(OneToOneColumn.Create(outputColumn, inputColumn)); + Column = list.ToArray(); + } + + + /// + /// New column definition(s) (optional form: name:src) + /// + public WordEmbeddingsTransformColumn[] Column { get; set; } + + /// + /// Pre-trained model used to create the vocabulary + /// + public WordEmbeddingsTransformPretrainedModelKind? ModelKind { get; set; } = WordEmbeddingsTransformPretrainedModelKind.Sswe; + + /// + /// Filename for custom word embedding model + /// + public string CustomLookupTable { get; set; } + + /// + /// Input dataset + /// + public Var Data { get; set; } = new Var(); + + + public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + { + /// + /// Transformed dataset + /// + public Var OutputData { get; set; } = new Var(); + + /// + /// Transform model + /// + public Var Model { get; set; } = new Var(); + + } + public Var GetInputData() => Data; + + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) + { + if (previousStep != null) + { + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(WordEmbeddings)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } + + Data = dataStep.Data; + } + Output output = experiment.Add(this); + return new WordEmbeddingsPipelineStep(output); + } + + private class WordEmbeddingsPipelineStep : ILearningPipelineDataStep + { + public WordEmbeddingsPipelineStep(Output output) + { + Data = output.OutputData; + Model = output.Model; + } + + public Var Data { get; } + public Var Model { get; } + } + } + } + namespace Transforms { diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index fe46ca26e1..d424953f70 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -123,4 +123,5 @@ Transforms.TextToKeyConverter Converts input values (words, numbers, etc.) to in Transforms.TrainTestDatasetSplitter Split the dataset into train and test sets Microsoft.ML.Runtime.EntryPoints.TrainTestSplit Split Microsoft.ML.Runtime.EntryPoints.TrainTestSplit+Input Microsoft.ML.Runtime.EntryPoints.TrainTestSplit+Output Transforms.TreeLeafFeaturizer Trains a tree ensemble, or loads it from a file, then maps a numeric feature vector to three outputs: 1. A vector containing the individual tree outputs of the tree ensemble. 2. A vector indicating the leaves that the feature vector falls on in the tree ensemble. 3. A vector indicating the paths that the feature vector falls on in the tree ensemble. If a both a model file and a trainer are specified - will use the model file. If neither are specified, will train a default FastTree model. This can handle key labels by training a regression model towards their optionally permuted indices. Microsoft.ML.Runtime.Data.TreeFeaturize Featurizer Microsoft.ML.Runtime.Data.TreeEnsembleFeaturizerTransform+ArgumentsForEntryPoint Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.TwoHeterogeneousModelCombiner Combines a TransformModel and a PredictorModel into a single PredictorModel. Microsoft.ML.Runtime.EntryPoints.ModelOperations CombineTwoModels Microsoft.ML.Runtime.EntryPoints.ModelOperations+SimplePredictorModelInput Microsoft.ML.Runtime.EntryPoints.ModelOperations+PredictorModelOutput +Transforms.WordEmbeddings Word Embeddings transform is a text featurizer which converts vectors of text tokens into sentence vectors using a pre-trained model Microsoft.ML.Runtime.Transforms.TextAnalytics WordEmbeddings Microsoft.ML.Runtime.Data.WordEmbeddingsTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.WordTokenizer The input to this transform is text, and the output is a vector of text containing the words (tokens) in the original text. The separator is space, but can be specified as any other character (or multiple characters) if needed. Microsoft.ML.Runtime.Transforms.TextAnalytics DelimitedTokenizeTransform Microsoft.ML.Runtime.Data.DelimitedTokenizeTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 36bdd0f62a..afacd8a71f 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -21450,6 +21450,120 @@ } ] }, + { + "Name": "Transforms.WordEmbeddings", + "Desc": "Word Embeddings transform is a text featurizer which converts vectors of text tokens into sentence vectors using a pre-trained model", + "FriendlyName": "Word Embeddings Transform", + "ShortName": "WordEmbeddings", + "Inputs": [ + { + "Name": "Column", + "Type": { + "Kind": "Array", + "ItemType": { + "Kind": "Struct", + "Fields": [ + { + "Name": "Name", + "Type": "String", + "Desc": "Name of the new column", + "Aliases": [ + "name" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "Source", + "Type": "String", + "Desc": "Name of the source column", + "Aliases": [ + "src" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": null + } + ] + } + }, + "Desc": "New column definition(s) (optional form: name:src)", + "Aliases": [ + "col" + ], + "Required": true, + "SortOrder": 0.0, + "IsNullable": false + }, + { + "Name": "ModelKind", + "Type": { + "Kind": "Enum", + "Values": [ + "GloVe50D", + "GloVe100D", + "GloVe200D", + "GloVe300D", + "GloVeTwitter25D", + "GloVeTwitter50D", + "GloVeTwitter100D", + "GloVeTwitter200D", + "FastTextWikipedia300D", + "Sswe" + ] + }, + "Desc": "Pre-trained model used to create the vocabulary", + "Aliases": [ + "model" + ], + "Required": false, + "SortOrder": 1.0, + "IsNullable": true, + "Default": "Sswe" + }, + { + "Name": "Data", + "Type": "DataView", + "Desc": "Input dataset", + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "CustomLookupTable", + "Type": "String", + "Desc": "Filename for custom word embedding model", + "Aliases": [ + "dataFile" + ], + "Required": false, + "SortOrder": 2.0, + "IsNullable": false, + "Default": null + } + ], + "Outputs": [ + { + "Name": "OutputData", + "Type": "DataView", + "Desc": "Transformed dataset" + }, + { + "Name": "Model", + "Type": "TransformModel", + "Desc": "Transform model" + } + ], + "InputKind": [ + "ITransformInput" + ], + "OutputKind": [ + "ITransformOutput" + ] + }, { "Name": "Transforms.WordTokenizer", "Desc": "The input to this transform is text, and the output is a vector of text containing the words (tokens) in the original text. The separator is space, but can be specified as any other character (or multiple characters) if needed.", diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index b0bc269164..d9db7c432b 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -3710,5 +3710,52 @@ public void EntryPointTreeLeafFeaturizer() } } } + + [Fact] + public void EntryPointWordEmbeddings() + { + string dataFile = DeleteOutputPath("SavePipe", "SavePipeTextWordEmbeddings-SampleText.txt"); + File.WriteAllLines(dataFile, new[] { + "The quick brown fox jumps over the lazy dog.", + "The five boxing wizards jump quickly." + }); + var inputFile = new SimpleFileHandle(Env, dataFile, false, false); + var dataView = ImportTextData.TextLoader(Env, new ImportTextData.LoaderInput() + { + Arguments = + { + SeparatorChars = new []{' '}, + Column = new[] + { + new TextLoader.Column() + { + Name = "Text", + Source = new [] { new TextLoader.Range() { Min = 0, VariableEnd=true, ForceVector=true} }, + Type = DataKind.Text + } + } + }, + InputFile = inputFile, + }).Data; + var embedding = Transforms.TextAnalytics.WordEmbeddings(Env, new WordEmbeddingsTransform.Arguments() + { + Data = dataView, + Column = new[] { new WordEmbeddingsTransform.Column { Name = "Features", Source = "Text" } }, + ModelKind = WordEmbeddingsTransform.PretrainedModelKind.Sswe + }); + var result = embedding.OutputData; + using (var cursor = result.GetRowCursor((x => true))) + { + Assert.True(result.Schema.TryGetColumnIndex("Features", out int featColumn)); + var featGetter = cursor.GetGetter>(featColumn); + VBuffer feat = default; + while (cursor.MoveNext()) + { + featGetter(ref feat); + Assert.True(feat.Count == 150); + Assert.True(feat.Values[0] != 0); + } + } + } } } \ No newline at end of file diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs index bea81175d8..7dd7a0145c 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs @@ -100,6 +100,99 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTest() } } + [Fact] + public void TrainAndPredictSentimentModelWithDirectionInstantiationTestWithWordEmbedding() + { + var dataPath = GetDataPath(SentimentDataPath); + var testDataPath = GetDataPath(SentimentTestPath); + + using (var env = new TlcEnvironment(seed: 1, conc: 1)) + { + // Pipeline + var loader = new TextLoader(env, + new TextLoader.Arguments() + { + Separator = "tab", + HasHeader = true, + Column = new[] + { + new TextLoader.Column() + { + Name = "Label", + Source = new [] { new TextLoader.Range() { Min=0, Max=0} }, + Type = DataKind.Num + }, + + new TextLoader.Column() + { + Name = "SentimentText", + Source = new [] { new TextLoader.Range() { Min=1, Max=1} }, + Type = DataKind.Text + } + } + }, new MultiFileSource(dataPath)); + + var text = TextTransform.Create(env, new TextTransform.Arguments() + { + Column = new TextTransform.Column + { + Name = "WordEmbeddings", + Source = new[] { "SentimentText" } + }, + KeepDiacritics = false, + KeepPunctuations = false, + TextCase = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower, + OutputTokens = true, + StopWordsRemover = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(), + VectorNormalizer = TextTransform.TextNormKind.None, + CharFeatureExtractor = null, + WordFeatureExtractor = null, + }, + loader); + + var trans = new WordEmbeddingsTransform(env, new WordEmbeddingsTransform.Arguments() + { + Column = new WordEmbeddingsTransform.Column[1] + { + new WordEmbeddingsTransform.Column + { + Name = "Features", + Source = "WordEmbeddings_TransformedText" + } + }, + ModelKind = WordEmbeddingsTransform.PretrainedModelKind.Sswe, + }, text); + // Train + var trainer = new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments() + { + NumLeaves = 5, + NumTrees = 5, + MinDocumentsInLeafs = 2 + }); + + var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features"); + var pred = trainer.Train(trainRoles); + // Get scorer and evaluate the predictions from test data + IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath); + var metrics = EvaluateBinary(env, testDataScorer); + + // SSWE is a simple word embedding model + we train on a really small dataset, so metrics are not great. + Assert.Equal(.6667, metrics.Accuracy, 4); + Assert.Equal(.71, metrics.Auc, 1); + Assert.Equal(.58, metrics.Auprc, 2); + // Create prediction engine and test predictions + var model = env.CreateBatchPredictionEngine(testDataScorer); + var sentiments = GetTestData(); + var predictions = model.Predict(sentiments, false); + Assert.Equal(2, predictions.Count()); + Assert.True(predictions.ElementAt(0).Sentiment.IsTrue); + Assert.True(predictions.ElementAt(1).Sentiment.IsTrue); + + // Get feature importance based on feature gain during training + var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema); + Assert.Equal(1.0, (double)summary[0].Value, 1); + } + } private BinaryClassificationMetrics EvaluateBinary(IHostEnvironment env, IDataView scoredData) { var dataEval = new RoleMappedData(scoredData, label: "Label", feature: "Features", opt: true);