@@ -209,7 +209,7 @@ public WordEmbeddingsTransform(IHostEnvironment env, PretrainedModelKind modelKi
209
209
210
210
_modelKind = modelKind ;
211
211
_modelFileNameWithPath = EnsureModelFile ( env , out _linesToSkip , ( PretrainedModelKind ) _modelKind ) ;
212
- _currentVocab = GetVocabularyDictionary ( ) ;
212
+ _currentVocab = GetVocabularyDictionary ( env ) ;
213
213
}
214
214
215
215
/// <summary>
@@ -227,7 +227,7 @@ public WordEmbeddingsTransform(IHostEnvironment env, string customModelFile, par
227
227
_modelKind = null ;
228
228
_customLookup = true ;
229
229
_modelFileNameWithPath = customModelFile ;
230
- _currentVocab = GetVocabularyDictionary ( ) ;
230
+ _currentVocab = GetVocabularyDictionary ( env ) ;
231
231
}
232
232
233
233
private static ( string input , string output ) [ ] GetColumnPairs ( ColumnInfo [ ] columns )
@@ -283,7 +283,7 @@ private WordEmbeddingsTransform(IHost host, ModelLoadContext ctx)
283
283
}
284
284
285
285
Host . CheckNonWhiteSpace ( _modelFileNameWithPath , nameof ( _modelFileNameWithPath ) ) ;
286
- _currentVocab = GetVocabularyDictionary ( ) ;
286
+ _currentVocab = GetVocabularyDictionary ( host ) ;
287
287
}
288
288
289
289
public static WordEmbeddingsTransform Create ( IHostEnvironment env , ModelLoadContext ctx )
@@ -699,7 +699,7 @@ private string EnsureModelFile(IHostEnvironment env, out int linesToSkip, Pretra
699
699
throw Host . Except ( $ "Can't map model kind = { kind } to specific file, please refer to https://aka.ms/MLNetIssue for assistance") ;
700
700
}
701
701
702
- private Model GetVocabularyDictionary ( )
702
+ private Model GetVocabularyDictionary ( IHostEnvironment hostEnvironment )
703
703
{
704
704
int dimension = 0 ;
705
705
if ( ! File . Exists ( _modelFileNameWithPath ) )
@@ -731,7 +731,7 @@ private Model GetVocabularyDictionary()
731
731
var parsedData = new ConcurrentBag < ( string key , float [ ] values , long lineNumber ) > ( ) ;
732
732
int skippedLinesCount = Math . Max ( 1 , _linesToSkip ) ;
733
733
734
- Parallel . ForEach ( File . ReadLines ( _modelFileNameWithPath ) . Skip ( skippedLinesCount ) ,
734
+ Parallel . ForEach ( File . ReadLines ( _modelFileNameWithPath ) . Skip ( skippedLinesCount ) , GetParallelOptions ( hostEnvironment ) ,
735
735
( line , parallelState , lineNumber ) =>
736
736
{
737
737
( bool isSuccess , string key , float [ ] values ) = LineParser . ParseKeyThenNumbers ( line ) ;
@@ -774,6 +774,15 @@ private Model GetVocabularyDictionary()
774
774
}
775
775
}
776
776
}
777
+
778
+ private static ParallelOptions GetParallelOptions ( IHostEnvironment hostEnvironment )
779
+ {
780
+ // "Less than 1 means whatever the component views as ideal." (about ConcurrencyFactor)
781
+ if ( hostEnvironment . ConcurrencyFactor < 1 )
782
+ return new ParallelOptions ( ) ; // we provide default options and let the Parallel decide
783
+ else
784
+ return new ParallelOptions ( ) { MaxDegreeOfParallelism = hostEnvironment . ConcurrencyFactor } ;
785
+ }
777
786
}
778
787
779
788
/// <include file='doc.xml' path='doc/members/member[@name="WordEmbeddings"]/*' />
0 commit comments