Skip to content

Commit feddc72

Browse files
authored
Remove parsing perf bottleneck in WordEmbeddingsTransform (#1599) fixes #1608
* update benchmarking docs: mention required git submodule * update BDN to latest version with ETW profiler * use CopyLocalLockFileAssemblies to force MSBuild to copy all dependencies to output folder, even if they are not used - to allow for dynamic assembly loading for EtwProfiler when used from console app * use new .AsDefault() bdn method to tell it that it's a default job which can be overwritten by TrainConfig * add benchmark which isolates the performance bottleneck * write unit tests! * move the parsing logic to a separate method before applying any code changes * apply the optimizations * read the file in parallel for even x3 speedup! * remove the temporary benchmark * revert breaking benchmark config changes * apply a workaround for BenchmarkDotNet bug * code review fixes * cleanup the comment * update BDN to 0.11.3, remove all workarounds * code review fixes * missing license header
1 parent b4eebc5 commit feddc72

File tree

8 files changed

+165
-68
lines changed

8 files changed

+165
-68
lines changed

build/Dependencies.props

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
<!-- Test-only Dependencies -->
4141
<PropertyGroup>
42-
<BenchmarkDotNetVersion>0.11.1</BenchmarkDotNetVersion>
42+
<BenchmarkDotNetVersion>0.11.3</BenchmarkDotNetVersion>
4343
<MicrosoftMLTestModelsPackageVersion>0.0.3-test</MicrosoftMLTestModelsPackageVersion>
4444
</PropertyGroup>
4545

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Runtime.CompilerServices;
7+
8+
namespace Microsoft.ML.Runtime.Internal.Utilities
9+
{
10+
[BestFriend]
internal static class LineParser
{
    /// <summary>
    /// Parses a line of the form "key value1 value2 ... valueN", where the key and the values
    /// are separated by single space or tab characters. Trailing whitespace is ignored.
    /// </summary>
    /// <param name="line">The text line to parse.</param>
    /// <returns>
    /// <c>(true, key, values)</c> when the key and every value were parsed;
    /// <c>(false, null, null)</c> when the line is null or whitespace, when it has no separator
    /// after the first word (so there are no values), or when any value is rejected by
    /// <c>DoubleParser.TryParse</c>.
    /// </returns>
    public static (bool isSuccess, string key, float[] values) ParseKeyThenNumbers(string line)
    {
        if (string.IsNullOrWhiteSpace(line))
            return (false, null, null);

        ReadOnlySpan<char> trimmedLine = line.AsSpan().TrimEnd(); // TrimEnd creates a Span, no allocations

        int firstSeparatorIndex = trimmedLine.IndexOfAny(' ', '\t'); // the first word is the key, we just skip it

        // No separator means the line holds a single word and no values. Without this guard the whole
        // line would be parsed as a value and, if that word happens to be a number, we would reach
        // Slice(0, -1) below, which throws instead of reporting a parsing failure.
        if (firstSeparatorIndex < 0)
            return (false, null, null);

        ReadOnlySpan<char> valuesToParse = trimmedLine.Slice(start: firstSeparatorIndex + 1);

        float[] values = AllocateFixedSizeArrayToStoreParsedValues(valuesToParse);

        int toParseStartIndex = 0;
        int valueIndex = 0;
        // i runs up to and including Length so that the last value (which has no separator after it) is flushed too
        for (int i = 0; i <= valuesToParse.Length; i++)
        {
            if (i == valuesToParse.Length || valuesToParse[i] == ' ' || valuesToParse[i] == '\t')
            {
                if (DoubleParser.TryParse(valuesToParse.Slice(toParseStartIndex, i - toParseStartIndex), out float parsed))
                    values[valueIndex++] = parsed;
                else
                    return (false, null, null);

                toParseStartIndex = i + 1;
            }
        }

        return (true, trimmedLine.Slice(0, firstSeparatorIndex).ToString(), values);
    }

    /// <summary>
    /// We count the number of values first to allocate a single array of the proper size.
    /// </summary>
    /// <param name="valuesToParse">The part of the trimmed line that follows the key and its separator.</param>
    /// <returns>A new array with one slot per separator found in <paramref name="valuesToParse"/>, plus one.</returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static float[] AllocateFixedSizeArrayToStoreParsedValues(ReadOnlySpan<char> valuesToParse)
    {
        int valuesCount = 0;

        for (int i = 0; i < valuesToParse.Length; i++)
            if (valuesToParse[i] == ' ' || valuesToParse[i] == '\t')
                valuesCount++;

        return new float[valuesCount + 1]; // + 1 because the line is trimmed and there is no whitespace at the end
    }
}
58+
}

src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs

+55-51
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
using Microsoft.ML.StaticPipe.Runtime;
1616
using Microsoft.ML.Transforms.Text;
1717
using System;
18+
using System.Collections.Concurrent;
1819
using System.Collections.Generic;
1920
using System.IO;
2021
using System.Linq;
2122
using System.Text;
23+
using System.Threading.Tasks;
2224

2325
[assembly: LoadableClass(WordEmbeddingsExtractingTransformer.Summary, typeof(IDataTransform), typeof(WordEmbeddingsExtractingTransformer), typeof(WordEmbeddingsExtractingTransformer.Arguments),
2426
typeof(SignatureDataTransform), WordEmbeddingsExtractingTransformer.UserName, "WordEmbeddingsTransform", WordEmbeddingsExtractingTransformer.ShortName, DocName = "transform/WordEmbeddingsTransform.md")]
@@ -207,7 +209,7 @@ public WordEmbeddingsExtractingTransformer(IHostEnvironment env, PretrainedModel
207209

208210
_modelKind = modelKind;
209211
_modelFileNameWithPath = EnsureModelFile(env, out _linesToSkip, (PretrainedModelKind)_modelKind);
210-
_currentVocab = GetVocabularyDictionary();
212+
_currentVocab = GetVocabularyDictionary(env);
211213
}
212214

213215
/// <summary>
@@ -225,7 +227,7 @@ public WordEmbeddingsExtractingTransformer(IHostEnvironment env, string customMo
225227
_modelKind = null;
226228
_customLookup = true;
227229
_modelFileNameWithPath = customModelFile;
228-
_currentVocab = GetVocabularyDictionary();
230+
_currentVocab = GetVocabularyDictionary(env);
229231
}
230232

231233
private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns)
@@ -281,7 +283,7 @@ private WordEmbeddingsExtractingTransformer(IHost host, ModelLoadContext ctx)
281283
}
282284

283285
Host.CheckNonWhiteSpace(_modelFileNameWithPath, nameof(_modelFileNameWithPath));
284-
_currentVocab = GetVocabularyDictionary();
286+
_currentVocab = GetVocabularyDictionary(host);
285287
}
286288

287289
public static WordEmbeddingsExtractingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
@@ -696,7 +698,7 @@ private string EnsureModelFile(IHostEnvironment env, out int linesToSkip, Pretra
696698
throw Host.Except($"Can't map model kind = {kind} to specific file, please refer to https://aka.ms/MLNetIssue for assistance");
697699
}
698700

699-
private Model GetVocabularyDictionary()
701+
private Model GetVocabularyDictionary(IHostEnvironment hostEnvironment)
700702
{
701703
int dimension = 0;
702704
if (!File.Exists(_modelFileNameWithPath))
@@ -722,62 +724,64 @@ private Model GetVocabularyDictionary()
722724
}
723725
}
724726

725-
Model model = null;
726-
using (StreamReader sr = File.OpenText(_modelFileNameWithPath))
727+
using (var ch = Host.Start(LoaderSignature))
728+
using (var pch = Host.StartProgressChannel("Building Vocabulary from Model File for Word Embeddings Transform"))
727729
{
728-
string line;
729-
int lineNumber = 1;
730-
char[] delimiters = { ' ', '\t' };
731-
using (var ch = Host.Start(LoaderSignature))
732-
using (var pch = Host.StartProgressChannel("Building Vocabulary from Model File for Word Embeddings Transform"))
733-
{
734-
var header = new ProgressHeader(new[] { "lines" });
735-
pch.SetHeader(header, e => e.SetProgress(0, lineNumber));
736-
string firstLine = sr.ReadLine();
737-
while ((line = sr.ReadLine()) != null)
730+
var parsedData = new ConcurrentBag<(string key, float[] values, long lineNumber)>();
731+
int skippedLinesCount = Math.Max(1, _linesToSkip);
732+
733+
Parallel.ForEach(File.ReadLines(_modelFileNameWithPath).Skip(skippedLinesCount), GetParallelOptions(hostEnvironment),
734+
(line, parallelState, lineNumber) =>
738735
{
739-
if (lineNumber >= _linesToSkip)
740-
{
741-
string[] words = line.TrimEnd().Split(delimiters);
742-
dimension = words.Length - 1;
743-
if (model == null)
744-
model = new Model(dimension);
745-
if (model.Dimension != dimension)
746-
ch.Warning($"Dimension mismatch while reading model file: '{_modelFileNameWithPath}', line number {lineNumber + 1}, expected dimension = {model.Dimension}, received dimension = {dimension}");
747-
else
748-
{
749-
float tmp;
750-
string key = words[0];
751-
float[] value = words.Skip(1).Select(x => float.TryParse(x, out tmp) ? tmp : Single.NaN).ToArray();
752-
if (!value.Contains(Single.NaN))
753-
model.AddWordVector(ch, key, value);
754-
else
755-
ch.Warning($"Parsing error while reading model file: '{_modelFileNameWithPath}', line number {lineNumber + 1}");
756-
}
757-
}
758-
lineNumber++;
759-
}
736+
(bool isSuccess, string key, float[] values) = LineParser.ParseKeyThenNumbers(line);
737+
738+
if (isSuccess)
739+
parsedData.Add((key, values, lineNumber + skippedLinesCount));
740+
else // we use shared state here (ch) but it's not our hot path and we don't care about unhappy-path performance
741+
ch.Warning($"Parsing error while reading model file: '{_modelFileNameWithPath}', line number {lineNumber + skippedLinesCount}");
742+
});
760743

761-
// Handle first line of the embedding file separately since some embedding files including fastText have a single-line header
762-
string[] wordsInFirstLine = firstLine.TrimEnd().Split(delimiters);
763-
dimension = wordsInFirstLine.Length - 1;
744+
Model model = null;
745+
foreach (var parsedLine in parsedData.OrderBy(parsedLine => parsedLine.lineNumber))
746+
{
747+
dimension = parsedLine.values.Length;
764748
if (model == null)
765749
model = new Model(dimension);
766-
if (model.Dimension == dimension)
767-
{
768-
float temp;
769-
string firstKey = wordsInFirstLine[0];
770-
float[] firstValue = wordsInFirstLine.Skip(1).Select(x => float.TryParse(x, out temp) ? temp : Single.NaN).ToArray();
771-
if (!firstValue.Contains(Single.NaN))
772-
model.AddWordVector(ch, firstKey, firstValue);
773-
}
774-
pch.Checkpoint(lineNumber);
750+
if (model.Dimension != dimension)
751+
ch.Warning($"Dimension mismatch while reading model file: '{_modelFileNameWithPath}', line number {parsedLine.lineNumber}, expected dimension = {model.Dimension}, received dimension = {dimension}");
752+
else
753+
model.AddWordVector(ch, parsedLine.key, parsedLine.values);
775754
}
755+
756+
// Handle first line of the embedding file separately since some embedding files including fastText have a single-line header
757+
var firstLine = File.ReadLines(_modelFileNameWithPath).First();
758+
string[] wordsInFirstLine = firstLine.TrimEnd().Split(' ', '\t');
759+
dimension = wordsInFirstLine.Length - 1;
760+
if (model == null)
761+
model = new Model(dimension);
762+
if (model.Dimension == dimension)
763+
{
764+
float temp;
765+
string firstKey = wordsInFirstLine[0];
766+
float[] firstValue = wordsInFirstLine.Skip(1).Select(x => float.TryParse(x, out temp) ? temp : Single.NaN).ToArray();
767+
if (!firstValue.Contains(Single.NaN))
768+
model.AddWordVector(ch, firstKey, firstValue);
769+
}
770+
771+
_vocab[_modelFileNameWithPath] = new WeakReference<Model>(model, false);
772+
return model;
776773
}
777-
_vocab[_modelFileNameWithPath] = new WeakReference<Model>(model, false);
778-
return model;
779774
}
780775
}
776+
777+
/// <summary>
/// Translates the host environment's ConcurrencyFactor into <see cref="ParallelOptions"/>.
/// </summary>
/// <param name="hostEnvironment">The environment whose ConcurrencyFactor is read.</param>
/// <returns>
/// Default options when ConcurrencyFactor is less than 1; otherwise options whose
/// MaxDegreeOfParallelism equals ConcurrencyFactor.
/// </returns>
private static ParallelOptions GetParallelOptions(IHostEnvironment hostEnvironment)
{
    // Start from the defaults so that Parallel decides on its own unless told otherwise.
    var options = new ParallelOptions();

    // About ConcurrencyFactor: "Less than 1 means whatever the component views as ideal."
    // Only an explicit value of 1 or more is used as the upper bound.
    if (!(hostEnvironment.ConcurrencyFactor < 1))
        options.MaxDegreeOfParallelism = hostEnvironment.ConcurrencyFactor;

    return options;
}
781785
}
782786

783787
/// <include file='doc.xml' path='doc/members/member[@name="WordEmbeddings"]/*' />

test/Microsoft.ML.Benchmarks/Harness/Configs.cs

+4-14
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,21 @@ public RecommendedConfig()
1818
.With(CreateToolchain())); // toolchain is responsible for generating, building and running dedicated executable per benchmark
1919

2020
Add(new ExtraMetricColumn()); // an extra colum that can display additional metric reported by the benchmarks
21-
22-
UnionRule = ConfigUnionRule.AlwaysUseLocal; // global config can be overwritten with local (the one set via [ConfigAttribute])
2321
}
2422

2523
protected virtual Job GetJobDefinition()
2624
=> Job.Default
2725
.WithWarmupCount(1) // ML.NET benchmarks are typically CPU-heavy benchmarks, 1 warmup is usually enough
28-
.WithMaxIterationCount(20);
26+
.WithMaxIterationCount(20)
27+
.AsDefault(); // this way we tell BDN that it's a default config which can be overwritten
2928

3029
/// <summary>
3130
/// we need our own toolchain because MSBuild by default does not copy recursive native dependencies to the output
3231
/// </summary>
3332
private IToolchain CreateToolchain()
3433
{
35-
var tfm = GetTargetFrameworkMoniker();
36-
var csProj = CsProjCoreToolchain.From(new NetCoreAppSettings(targetFrameworkMoniker: tfm, runtimeFrameworkVersion: null, name: tfm));
34+
var tfm = NetCoreAppSettings.Current.Value.TargetFrameworkMoniker;
35+
var csProj = CsProjCoreToolchain.Current.Value;
3736

3837
return new Toolchain(
3938
tfm,
@@ -42,15 +41,6 @@ private IToolchain CreateToolchain()
4241
csProj.Executor);
4342
}
4443

45-
private static string GetTargetFrameworkMoniker()
46-
{
47-
#if NETCOREAPP3_0 // todo: remove the #IF DEFINES when BDN 0.11.2 gets released (BDN gains the 3.0 support)
48-
return "netcoreapp3.0";
49-
#else
50-
return NetCoreAppSettings.Current.Value.TargetFrameworkMoniker;
51-
#endif
52-
}
53-
5444
private static string GetBuildConfigurationName()
5545
{
5646
#if NETCOREAPP3_0

test/Microsoft.ML.Benchmarks/Harness/ProjectGenerator.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ namespace Microsoft.ML.Benchmarks.Harness
2727
/// </summary>
2828
public class ProjectGenerator : CsProjGenerator
2929
{
30-
public ProjectGenerator(string targetFrameworkMoniker) : base(targetFrameworkMoniker, platform => platform.ToConfig(), null)
30+
public ProjectGenerator(string targetFrameworkMoniker) : base(targetFrameworkMoniker, null, null, null)
3131
{
3232
}
3333

test/Microsoft.ML.Benchmarks/Microsoft.ML.Benchmarks.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
</ItemGroup>
1313
<ItemGroup>
1414
<PackageReference Include="BenchmarkDotNet" Version="$(BenchmarkDotNetVersion)" />
15+
<PackageReference Include="BenchmarkDotNet.Diagnostics.Windows" Version="$(BenchmarkDotNetVersion)" />
1516
</ItemGroup>
1617
<ItemGroup>
1718
<ProjectReference Include="..\..\src\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />

test/Microsoft.ML.Benchmarks/README.md

+5-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@ This project contains performance benchmarks.
44

55
## Run the Performance Tests
66

7-
**Pre-requisite:** On a clean repo, `build.cmd` at the root installs the right version of dotnet.exe and builds the solution. You need to build the solution in `Release` with native dependencies.
7+
**Pre-requisite:** In order to fetch dependencies which come through Git submodules the following command needs to be run before building:
8+
9+
git submodule update --init
10+
11+
**Pre-requisite:** On a clean repo with initialized submodules, `build.cmd` at the root installs the right version of dotnet.exe and builds the solution. You need to build the solution in `Release` with native dependencies.
812

913
build.cmd -release -buildNative
1014

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime.Internal.Utilities;
6+
using System.Collections.Generic;
7+
using Xunit;
8+
9+
namespace Microsoft.ML.Tests.Transformers
10+
{
11+
/// <summary>
/// Unit tests for <c>LineParser.ParseKeyThenNumbers</c>.
/// </summary>
public class LineParserTests
{
    /// <summary>
    /// Lines that must parse successfully. Each entry is: input line, expected key, expected values.
    /// </summary>
    public static IEnumerable<object[]> ValidInputs()
        => new[]
        {
            ValidCase("key 0.1 0.2 0.3"),
            ValidCase("key 0.1 0.2 0.3 "),
            ValidCase("key\t0.1\t0.2\t0.3"), // tab can also be a separator
            ValidCase("key\t0.1\t0.2\t0.3\t"),
        };

    // Every valid case shares the same expected key and values; only the separators and trailing whitespace differ.
    private static object[] ValidCase(string line)
        => new object[] { line, "key", new float[] { 0.1f, 0.2f, 0.3f } };

    [Theory]
    [MemberData(nameof(ValidInputs))]
    public void WhenProvidedAValidInputParserParsesKeyAndValues(string input, string expectedKey, float[] expectedValues)
    {
        (bool isSuccess, string key, float[] values) = LineParser.ParseKeyThenNumbers(input);

        Assert.True(isSuccess);
        Assert.Equal(expectedKey, key);
        Assert.Equal(expectedValues, values);
    }

    [Theory]
    [InlineData("")]
    [InlineData("key 0.1 NOT_A_NUMBER")] // invalid number
    public void WhenProvidedAnInvalidInputParserReturnsFailure(string input)
    {
        var result = LineParser.ParseKeyThenNumbers(input);

        Assert.False(result.isSuccess);
    }
}
40+
}

0 commit comments

Comments
 (0)