Skip to content

Commit 9425d1e

Browse files
committed
code review fixes
1 parent 62e650b commit 9425d1e

File tree

8 files changed

+75
-83
lines changed

8 files changed

+75
-83
lines changed

build/Dependencies.props

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
<!-- Test-only Dependencies -->
4040
<PropertyGroup>
4141
<BenchmarkDotNetVersion>0.11.2</BenchmarkDotNetVersion>
42-
<TraceEventVersion>2.0.26</TraceEventVersion>
42+
<MicrosoftDiagnosticsTracingTraceEventVersion>2.0.26</MicrosoftDiagnosticsTracingTraceEventVersion>
4343
<MicrosoftMLTestModelsPackageVersion>0.0.3-test</MicrosoftMLTestModelsPackageVersion>
4444
</PropertyGroup>
4545

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
using System;
2+
using System.Runtime.CompilerServices;
3+
4+
namespace Microsoft.ML.Runtime.Internal.Utilities
5+
{
6+
/// <summary>
/// Parses text lines of the form "key value value ...", where the key and the values
/// are separated by single space or tab characters.
/// </summary>
[BestFriend]
internal static class LineParser
{
    /// <summary>
    /// Splits <paramref name="line"/> into a leading key and the numbers that follow it.
    /// </summary>
    /// <param name="line">The line to parse; trailing whitespace is ignored.</param>
    /// <returns>
    /// <c>(true, key, values)</c> when the line contains a key followed by at least one separator
    /// and every token after the key is accepted by <c>DoubleParser.TryParse</c>;
    /// <c>(false, null, null)</c> when the line is null/whitespace, contains no separator at all,
    /// or any token after the key fails to parse.
    /// </returns>
    [BestFriend] // required because the return type is a value tuple..
    internal static (bool isSuccess, string key, float[] values) ParseKeyThenNumbers(string line)
    {
        if (string.IsNullOrWhiteSpace(line))
            return (false, null, null);

        ReadOnlySpan<char> trimmedLine = line.AsSpan().TrimEnd(); // TrimEnd creates a Span, no allocations

        int firstSeparatorIndex = trimmedLine.IndexOfAny(' ', '\t'); // the first word is the key, we just skip it

        // No separator means the line is a single token: there is no key/values split to make.
        // Without this guard a lone numeric token (e.g. "0.5") would be parsed as a value and the
        // key slice below would be requested with length -1, which throws instead of reporting failure.
        if (firstSeparatorIndex < 0)
            return (false, null, null);

        ReadOnlySpan<char> valuesToParse = trimmedLine.Slice(start: firstSeparatorIndex + 1);

        float[] values = AllocateFixedSizeArrayToStoreParsedValues(valuesToParse);

        int toParseStartIndex = 0;
        int valueIndex = 0;

        // i == Length is visited on purpose: the end of the span terminates the last token.
        for (int i = 0; i <= valuesToParse.Length; i++)
        {
            if (i == valuesToParse.Length || valuesToParse[i] == ' ' || valuesToParse[i] == '\t')
            {
                if (DoubleParser.TryParse(valuesToParse.Slice(toParseStartIndex, i - toParseStartIndex), out float parsed))
                    values[valueIndex++] = parsed;
                else
                    return (false, null, null);

                toParseStartIndex = i + 1;
            }
        }

        return (true, trimmedLine.Slice(0, firstSeparatorIndex).ToString(), values);
    }

    /// <summary>
    /// We count the number of values first, so that a single array of the proper size can be allocated.
    /// </summary>
    /// <param name="valuesToParse">The part of the line that follows the key, with no trailing whitespace.</param>
    /// <returns>A new array with one slot per separator-delimited token in <paramref name="valuesToParse"/>.</returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static float[] AllocateFixedSizeArrayToStoreParsedValues(ReadOnlySpan<char> valuesToParse)
    {
        int valuesCount = 0;

        for (int i = 0; i < valuesToParse.Length; i++)
            if (valuesToParse[i] == ' ' || valuesToParse[i] == '\t')
                valuesCount++;

        return new float[valuesCount + 1]; // + 1 because the line is trimmed and there is no whitespace at the end
    }
}
55+
}

src/Microsoft.ML.Transforms/Microsoft.ML.Transforms.csproj

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
<Project Sdk="Microsoft.NET.Sdk">
22

33
<PropertyGroup>
4-
<TargetFrameworks>netstandard2.0;netcoreapp2.1</TargetFrameworks>
4+
<TargetFramework>netstandard2.0</TargetFramework>
55
<IncludeInPackage>Microsoft.ML</IncludeInPackage>
66
<DefineConstants>CORECLR</DefineConstants>
77
</PropertyGroup>

src/Microsoft.ML.Transforms/Text/LineParser.cs

-70
This file was deleted.

src/Microsoft.ML.Transforms/Text/WordEmbeddingsTransform.cs

+14-5
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ public WordEmbeddingsTransform(IHostEnvironment env, PretrainedModelKind modelKi
209209

210210
_modelKind = modelKind;
211211
_modelFileNameWithPath = EnsureModelFile(env, out _linesToSkip, (PretrainedModelKind)_modelKind);
212-
_currentVocab = GetVocabularyDictionary();
212+
_currentVocab = GetVocabularyDictionary(env);
213213
}
214214

215215
/// <summary>
@@ -227,7 +227,7 @@ public WordEmbeddingsTransform(IHostEnvironment env, string customModelFile, par
227227
_modelKind = null;
228228
_customLookup = true;
229229
_modelFileNameWithPath = customModelFile;
230-
_currentVocab = GetVocabularyDictionary();
230+
_currentVocab = GetVocabularyDictionary(env);
231231
}
232232

233233
private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns)
@@ -283,7 +283,7 @@ private WordEmbeddingsTransform(IHost host, ModelLoadContext ctx)
283283
}
284284

285285
Host.CheckNonWhiteSpace(_modelFileNameWithPath, nameof(_modelFileNameWithPath));
286-
_currentVocab = GetVocabularyDictionary();
286+
_currentVocab = GetVocabularyDictionary(host);
287287
}
288288

289289
public static WordEmbeddingsTransform Create(IHostEnvironment env, ModelLoadContext ctx)
@@ -699,7 +699,7 @@ private string EnsureModelFile(IHostEnvironment env, out int linesToSkip, Pretra
699699
throw Host.Except($"Can't map model kind = {kind} to specific file, please refer to https://aka.ms/MLNetIssue for assistance");
700700
}
701701

702-
private Model GetVocabularyDictionary()
702+
private Model GetVocabularyDictionary(IHostEnvironment hostEnvironment)
703703
{
704704
int dimension = 0;
705705
if (!File.Exists(_modelFileNameWithPath))
@@ -731,7 +731,7 @@ private Model GetVocabularyDictionary()
731731
var parsedData = new ConcurrentBag<(string key, float[] values, long lineNumber)>();
732732
int skippedLinesCount = Math.Max(1, _linesToSkip);
733733

734-
Parallel.ForEach(File.ReadLines(_modelFileNameWithPath).Skip(skippedLinesCount),
734+
Parallel.ForEach(File.ReadLines(_modelFileNameWithPath).Skip(skippedLinesCount), GetParallelOptions(hostEnvironment),
735735
(line, parallelState, lineNumber) =>
736736
{
737737
(bool isSuccess, string key, float[] values) = LineParser.ParseKeyThenNumbers(line);
@@ -774,6 +774,15 @@ private Model GetVocabularyDictionary()
774774
}
775775
}
776776
}
777+
778+
/// <summary>
/// Builds the <see cref="ParallelOptions"/> used when reading the model file in parallel,
/// based on the concurrency factor of the given host environment.
/// </summary>
/// <param name="hostEnvironment">The environment whose <c>ConcurrencyFactor</c> is consulted.</param>
/// <returns>
/// Default options when the concurrency factor is less than 1; otherwise options whose
/// <see cref="ParallelOptions.MaxDegreeOfParallelism"/> equals the concurrency factor.
/// </returns>
private static ParallelOptions GetParallelOptions(IHostEnvironment hostEnvironment)
{
    // Start from the defaults, which let Parallel decide on its own.
    var options = new ParallelOptions();

    // "Less than 1 means whatever the component views as ideal." (about ConcurrencyFactor)
    // so the limit is applied only when an explicit positive factor was requested.
    int requestedConcurrency = hostEnvironment.ConcurrencyFactor;
    if (requestedConcurrency >= 1)
        options.MaxDegreeOfParallelism = requestedConcurrency;

    return options;
}
777786
}
778787

779788
/// <include file='doc.xml' path='doc/members/member[@name="WordEmbeddings"]/*' />

test/Microsoft.ML.Benchmarks/Microsoft.ML.Benchmarks.csproj

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
<ItemGroup>
1515
<PackageReference Include="BenchmarkDotNet" Version="$(BenchmarkDotNetVersion)" />
1616
<PackageReference Include="BenchmarkDotNet.Diagnostics.Windows" Version="$(BenchmarkDotNetVersion)" />
17-
<PackageReference Include="Microsoft.Diagnostics.Tracing.TraceEvent" Version="$(TraceEventVersion)" />
17+
<PackageReference Include="Microsoft.Diagnostics.Tracing.TraceEvent" Version="$(MicrosoftDiagnosticsTracingTraceEventVersion)" />
1818
</ItemGroup>
1919
<ItemGroup>
2020
<ProjectReference Include="..\..\src\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />

test/Microsoft.ML.Benchmarks/Text/MultiClassClassification.cs

-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
using Microsoft.ML.Trainers;
1212
using Microsoft.ML.Transforms.Categorical;
1313
using System.IO;
14-
using Microsoft.ML.Transforms.Text;
1514

1615
namespace Microsoft.ML.Benchmarks
1716
{

test/Microsoft.ML.Tests/Transformers/LineParserTests.cs

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
using System;
1+
using Microsoft.ML.Runtime.Internal.Utilities;
22
using System.Collections.Generic;
3-
using System.Linq;
43
using Xunit;
54

65
namespace Microsoft.ML.Tests.Transformers
@@ -23,7 +22,7 @@ public static IEnumerable<object[]> ValidInputs()
2322
[MemberData(nameof(ValidInputs))]
2423
public void WhenProvidedAValidInputParserParsesKeyAndValues(string input, string expectedKey, float[] expectedValues)
2524
{
26-
var result = Transforms.Text.LineParser.ParseKeyThenNumbers(input);
25+
var result = LineParser.ParseKeyThenNumbers(input);
2726

2827
Assert.True(result.isSuccess);
2928
Assert.Equal(expectedKey, result.key);
@@ -35,7 +34,7 @@ public void WhenProvidedAValidInputParserParsesKeyAndValues(string input, string
3534
[InlineData("key 0.1 NOT_A_NUMBER")] // invalid number
3635
public void WhenProvidedAnInvalidInputParserReturnsFailure(string input)
3736
{
38-
Assert.False(Transforms.Text.LineParser.ParseKeyThenNumbers(input).isSuccess);
37+
Assert.False(LineParser.ParseKeyThenNumbers(input).isSuccess);
3938
}
4039
}
4140
}

0 commit comments

Comments
 (0)