# Patch summary (commentary only; patch tools skip text before the first "diff --git" header).
#
# Layout note: this patch had its line breaks and indentation collapsed. The records below
# are re-split one per line; leading indentation and blank context lines were reconstructed
# so that every hunk matches the counts in its "@@" header.
# NOTE(review): the exact original whitespace is not recoverable from the collapsed text -
# verify against the target files (or apply with whitespace tolerance) before relying on it.
#
# 1) src/Microsoft.ML.Transforms/Text/WordEmbeddingsTransform.cs, GetVocabularyDictionary():
#    - The first line's word vector is now parsed and added only when
#      model.Dimension == dimension (dimension = wordsInFirstLine.Length - 1).
#    - The "Parsing error while reading model file ... line number 1" warning is removed, so a
#      first line that is skipped or fails to parse no longer produces a diagnostic.
#      NOTE(review): presumably intended for model files whose first line is a header - confirm.
#    - NOTE(review): float.TryParse(x, out temp) is called without an explicit culture, so the
#      current culture is used; confirm whether model files should be parsed invariantly.
#
# 2) test/Microsoft.ML.Benchmarks/BigramAndTrigramBenchMark.cs is renamed to
#    test/Microsoft.ML.Benchmarks/Text/MultiClassClassification.cs (similarity index 69%):
#    - Class BigramAndTrigramBenchmark becomes MultiClassClassification.
#    - The [GlobalSetup] Targets list gains the two new benchmark method names, so
#      SetupTrainingSpeedTests() is registered as their setup as well.
#    - Two [Benchmark] methods are added. Each builds a "CV ... k=5" command line that includes
#      xf=WordEmbeddingsTransform{... model=FastTextWikipedia300D} and passes it to
#      Maml.MainCore inside a TlcEnvironment whose output writer is EmptyWriter.Instance:
#        * ..._WordEmbeddings_OVAAveragedPerceptron: tr=OVA{p=AveragedPerceptron{iter=10}},
#          concatenates FeaturesText, FeaturesWordEmbedding, logged_in and ns into Features.
#        * ..._WordEmbeddings_SDCAMC: tr=SDCAMC, with wordExtractor={} and charExtractor={},
#          concatenates FeaturesWordEmbedding, logged_in and ns into Features.
diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsTransform.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsTransform.cs
index bf85ddf42f..83e420afff 100644
--- a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsTransform.cs
@@ -426,13 +426,14 @@ private Model GetVocabularyDictionary()
                             dimension = wordsInFirstLine.Length - 1;
                             if (model == null)
                                 model = new Model(dimension);
-                            float temp;
-                            string firstKey = wordsInFirstLine[0];
-                            float[] firstValue = wordsInFirstLine.Skip(1).Select(x => float.TryParse(x, out temp) ? temp : Single.NaN).ToArray();
-                            if (!firstValue.Contains(Single.NaN))
-                                model.AddWordVector(ch, firstKey, firstValue);
-                            else
-                                ch.Warning($"Parsing error while reading model file: '{_modelFileNameWithPath}', line number 1");
+                            if (model.Dimension == dimension)
+                            {
+                                float temp;
+                                string firstKey = wordsInFirstLine[0];
+                                float[] firstValue = wordsInFirstLine.Skip(1).Select(x => float.TryParse(x, out temp) ? temp : Single.NaN).ToArray();
+                                if (!firstValue.Contains(Single.NaN))
+                                    model.AddWordVector(ch, firstKey, firstValue);
+                            }
                             pch.Checkpoint(lineNumber);
                         }
                     }
diff --git a/test/Microsoft.ML.Benchmarks/BigramAndTrigramBenchMark.cs b/test/Microsoft.ML.Benchmarks/Text/MultiClassClassification.cs
similarity index 69%
rename from test/Microsoft.ML.Benchmarks/BigramAndTrigramBenchMark.cs
rename to test/Microsoft.ML.Benchmarks/Text/MultiClassClassification.cs
index f4c947abda..85c1cb0091 100644
--- a/test/Microsoft.ML.Benchmarks/BigramAndTrigramBenchMark.cs
+++ b/test/Microsoft.ML.Benchmarks/Text/MultiClassClassification.cs
@@ -20,14 +20,16 @@ internal class EmptyWriter : TextWriter
         public override Encoding Encoding => null;
     }
 
-    public class BigramAndTrigramBenchmark
+    public class MultiClassClassification
     {
         private string _dataPath_Wiki;
         private string _modelPath_Wiki;
 
         [GlobalSetup(Targets = new string[] {
             nameof(CV_Multiclass_WikiDetox_BigramsAndTrichar_OVAAveragedPerceptron),
-            nameof(CV_Multiclass_WikiDetox_BigramsAndTrichar_LightGBMMulticlass) })]
+            nameof(CV_Multiclass_WikiDetox_BigramsAndTrichar_LightGBMMulticlass),
+            nameof(CV_Multiclass_WikiDetox_WordEmbeddings_OVAAveragedPerceptron),
+            nameof(CV_Multiclass_WikiDetox_WordEmbeddings_SDCAMC)})]
         public void SetupTrainingSpeedTests()
         {
             _dataPath_Wiki = Path.GetFullPath(TestDatasets.WikiDetox.trainFilename);
@@ -81,5 +83,25 @@ public void Test_Multiclass_WikiDetox_BigramsAndTrichar_OVAAveragedPerceptron()
                 Maml.MainCore(tlc, cmd, alwaysPrintStacktrace: false);
             }
         }
+
+        [Benchmark]
+        public void CV_Multiclass_WikiDetox_WordEmbeddings_OVAAveragedPerceptron()
+        {
+            string cmd = @"CV tr=OVA{p=AveragedPerceptron{iter=10}} k=5 loader=TextLoader{quote=- sparse=- col=Label:R4:0 col=rev_id:TX:1 col=comment:TX:2 col=logged_in:BL:4 col=ns:TX:5 col=sample:TX:6 col=split:TX:7 col=year:R4:3 header=+} data=" + _dataPath_Wiki + " xf=Convert{col=logged_in type=R4} xf=CategoricalTransform{col=ns} xf=TextTransform{col=FeaturesText:comment tokens=+ wordExtractor=NGramExtractorTransform{ngram=2}} xf=WordEmbeddingsTransform{col=FeaturesWordEmbedding:FeaturesText_TransformedText model=FastTextWikipedia300D} xf=Concat{col=Features:FeaturesText,FeaturesWordEmbedding,logged_in,ns}";
+            using (var tlc = new TlcEnvironment(verbose: false, sensitivity: MessageSensitivity.None, outWriter: EmptyWriter.Instance))
+            {
+                Maml.MainCore(tlc, cmd, alwaysPrintStacktrace: false);
+            }
+        }
+
+        [Benchmark]
+        public void CV_Multiclass_WikiDetox_WordEmbeddings_SDCAMC()
+        {
+            string cmd = @"CV tr=SDCAMC k=5 loader=TextLoader{quote=- sparse=- col=Label:R4:0 col=rev_id:TX:1 col=comment:TX:2 col=logged_in:BL:4 col=ns:TX:5 col=sample:TX:6 col=split:TX:7 col=year:R4:3 header=+} data=" + _dataPath_Wiki + " xf=Convert{col=logged_in type=R4} xf=CategoricalTransform{col=ns} xf=TextTransform{col=FeaturesText:comment tokens=+ wordExtractor={} charExtractor={}} xf=WordEmbeddingsTransform{col=FeaturesWordEmbedding:FeaturesText_TransformedText model=FastTextWikipedia300D} xf=Concat{col=Features:FeaturesWordEmbedding,logged_in,ns}";
+            using (var tlc = new TlcEnvironment(verbose: false, sensitivity: MessageSensitivity.None, outWriter: EmptyWriter.Instance))
+            {
+                Maml.MainCore(tlc, cmd, alwaysPrintStacktrace: false);
+            }
+        }
     }
 }