@@ -19,6 +19,8 @@ internal class CodeGenerator : IProjectGenerator
19
19
private readonly Pipeline pipeline ;
20
20
private readonly CodeGeneratorSettings settings ;
21
21
private readonly ColumnInferenceResults columnInferenceResult ;
22
+ private readonly HashSet < string > LightGBMTrainers = new HashSet < string > ( ) { TrainerName . LightGbmBinary . ToString ( ) , TrainerName . LightGbmMulti . ToString ( ) , TrainerName . LightGbmRegression . ToString ( ) } ;
23
+ private readonly HashSet < string > mklComponentsTrainers = new HashSet < string > ( ) { TrainerName . OlsRegression . ToString ( ) , TrainerName . SymbolicSgdLogisticRegressionBinary . ToString ( ) } ;
22
24
23
25
internal CodeGenerator ( Pipeline pipeline , ColumnInferenceResults columnInferenceResult , CodeGeneratorSettings settings )
24
26
{
@@ -29,25 +31,32 @@ internal CodeGenerator(Pipeline pipeline, ColumnInferenceResults columnInference
29
31
30
32
public void GenerateOutput ( )
31
33
{
34
+ // Get the extra nuget packages to be included in the generated project.
35
+ var trainerNodes = pipeline . Nodes . Where ( t => t . NodeType == PipelineNodeType . Trainer ) ;
36
+
37
+ bool includeLightGbmPackage = false ;
38
+ bool includeMklComponentsPackage = false ;
39
+ SetRequiredNugetPackages ( trainerNodes , ref includeLightGbmPackage , ref includeMklComponentsPackage ) ;
40
+
32
41
// Get Namespace
33
42
var namespaceValue = Utils . Normalize ( settings . OutputName ) ;
34
43
var labelType = columnInferenceResult . TextLoaderOptions . Columns . Where ( t => t . Name == columnInferenceResult . ColumnInformation . LabelColumnName ) . First ( ) . DataKind ;
35
44
Type labelTypeCsharp = Utils . GetCSharpType ( labelType ) ;
36
45
37
46
// Generate Model Project
38
- var modelProjectContents = GenerateModelProjectContents ( namespaceValue , labelTypeCsharp ) ;
47
+ var modelProjectContents = GenerateModelProjectContents ( namespaceValue , labelTypeCsharp , includeLightGbmPackage , includeMklComponentsPackage ) ;
39
48
40
49
// Write files to disk.
41
50
var modelprojectDir = Path . Combine ( settings . OutputBaseDir , $ "{ settings . OutputName } .Model") ;
42
51
var dataModelsDir = Path . Combine ( modelprojectDir , "DataModels" ) ;
43
52
var modelProjectName = $ "{ settings . OutputName } .Model.csproj";
44
53
45
- Utils . WriteOutputToFiles ( modelProjectContents . ObservationCSFileContent , "Observation .cs" , dataModelsDir ) ;
46
- Utils . WriteOutputToFiles ( modelProjectContents . PredictionCSFileContent , "Prediction .cs" , dataModelsDir ) ;
54
+ Utils . WriteOutputToFiles ( modelProjectContents . ObservationCSFileContent , "SampleObservation .cs" , dataModelsDir ) ;
55
+ Utils . WriteOutputToFiles ( modelProjectContents . PredictionCSFileContent , "SamplePrediction .cs" , dataModelsDir ) ;
47
56
Utils . WriteOutputToFiles ( modelProjectContents . ModelProjectFileContent , modelProjectName , modelprojectDir ) ;
48
57
49
58
// Generate ConsoleApp Project
50
- var consoleAppProjectContents = GenerateConsoleAppProjectContents ( namespaceValue , labelTypeCsharp ) ;
59
+ var consoleAppProjectContents = GenerateConsoleAppProjectContents ( namespaceValue , labelTypeCsharp , includeLightGbmPackage , includeMklComponentsPackage ) ;
51
60
52
61
// Write files to disk.
53
62
var consoleAppProjectDir = Path . Combine ( settings . OutputBaseDir , $ "{ settings . OutputName } .ConsoleApp") ;
@@ -65,12 +74,33 @@ public void GenerateOutput()
65
74
Utils . AddProjectsToSolution ( modelprojectDir , modelProjectName , consoleAppProjectDir , consoleAppProjectName , solutionPath ) ;
66
75
}
67
76
68
- internal ( string ConsoleAppProgramCSFileContent , string ConsoleAppProjectFileContent , string modelBuilderCSFileContent ) GenerateConsoleAppProjectContents ( string namespaceValue , Type labelTypeCsharp )
77
+ private void SetRequiredNugetPackages ( IEnumerable < PipelineNode > trainerNodes , ref bool includeLightGbmPackage , ref bool includeMklComponentsPackage )
78
+ {
79
+ foreach ( var node in trainerNodes )
80
+ {
81
+ PipelineNode currentNode = node ;
82
+ if ( currentNode . Name == TrainerName . Ova . ToString ( ) )
83
+ {
84
+ currentNode = ( PipelineNode ) currentNode . Properties [ "BinaryTrainer" ] ;
85
+ }
86
+
87
+ if ( LightGBMTrainers . Contains ( currentNode . Name ) )
88
+ {
89
+ includeLightGbmPackage = true ;
90
+ }
91
+ else if ( mklComponentsTrainers . Contains ( currentNode . Name ) )
92
+ {
93
+ includeMklComponentsPackage = true ;
94
+ }
95
+ }
96
+ }
97
+
98
+ internal ( string ConsoleAppProgramCSFileContent , string ConsoleAppProjectFileContent , string modelBuilderCSFileContent ) GenerateConsoleAppProjectContents ( string namespaceValue , Type labelTypeCsharp , bool includeLightGbmPackage , bool includeMklComponentsPackage )
69
99
{
70
100
var predictProgramCSFileContent = GeneratePredictProgramCSFileContent ( namespaceValue ) ;
71
101
predictProgramCSFileContent = Utils . FormatCode ( predictProgramCSFileContent ) ;
72
102
73
- var predictProjectFileContent = GeneratPredictProjectFileContent ( namespaceValue , true , true ) ;
103
+ var predictProjectFileContent = GeneratPredictProjectFileContent ( namespaceValue , includeLightGbmPackage , includeMklComponentsPackage ) ;
74
104
75
105
var transformsAndTrainers = GenerateTransformsAndTrainers ( ) ;
76
106
var modelBuilderCSFileContent = GenerateModelBuilderCSFileContent ( transformsAndTrainers . Usings , transformsAndTrainers . TrainerMethod , transformsAndTrainers . PreTrainerTransforms , transformsAndTrainers . PostTrainerTransforms , namespaceValue , pipeline . CacheBeforeTrainer , labelTypeCsharp . Name ) ;
@@ -79,14 +109,14 @@ public void GenerateOutput()
79
109
return ( predictProgramCSFileContent , predictProjectFileContent , modelBuilderCSFileContent ) ;
80
110
}
81
111
82
- internal ( string ObservationCSFileContent , string PredictionCSFileContent , string ModelProjectFileContent ) GenerateModelProjectContents ( string namespaceValue , Type labelTypeCsharp )
112
+ internal ( string ObservationCSFileContent , string PredictionCSFileContent , string ModelProjectFileContent ) GenerateModelProjectContents ( string namespaceValue , Type labelTypeCsharp , bool includeLightGbmPackage , bool includeMklComponentsPackage )
83
113
{
84
114
var classLabels = this . GenerateClassLabels ( ) ;
85
115
var observationCSFileContent = GenerateObservationCSFileContent ( namespaceValue , classLabels ) ;
86
116
observationCSFileContent = Utils . FormatCode ( observationCSFileContent ) ;
87
117
var predictionCSFileContent = GeneratePredictionCSFileContent ( labelTypeCsharp . Name , namespaceValue ) ;
88
118
predictionCSFileContent = Utils . FormatCode ( predictionCSFileContent ) ;
89
- var modelProjectFileContent = GenerateModelProjectFileContent ( ) ;
119
+ var modelProjectFileContent = GenerateModelProjectFileContent ( includeLightGbmPackage , includeMklComponentsPackage ) ;
90
120
return ( observationCSFileContent , predictionCSFileContent , modelProjectFileContent ) ;
91
121
}
92
122
@@ -218,9 +248,9 @@ internal IList<string> GenerateClassLabels()
218
248
}
219
249
220
250
#region Model project
221
- private static string GenerateModelProjectFileContent ( )
251
+ private static string GenerateModelProjectFileContent ( bool includeLightGbmPackage , bool includeMklComponentsPackage )
222
252
{
223
- ModelProject modelProject = new ModelProject ( ) ;
253
+ ModelProject modelProject = new ModelProject ( ) { IncludeLightGBMPackage = includeLightGbmPackage , IncludeMklComponentsPackage = includeMklComponentsPackage } ;
224
254
return modelProject . TransformText ( ) ;
225
255
}
226
256
@@ -238,9 +268,9 @@ private string GenerateObservationCSFileContent(string namespaceValue, IList<str
238
268
#endregion
239
269
240
270
#region Predict Project
241
- private static string GeneratPredictProjectFileContent ( string namespaceValue , bool includeMklComponentsPackage , bool includeLightGBMPackage )
271
+ private static string GeneratPredictProjectFileContent ( string namespaceValue , bool includeLightGbmPackage , bool includeMklComponentsPackage )
242
272
{
243
- var predictProjectFileContent = new PredictProject ( ) { Namespace = namespaceValue , IncludeMklComponentsPackage = includeMklComponentsPackage , IncludeLightGBMPackage = includeLightGBMPackage } ;
273
+ var predictProjectFileContent = new PredictProject ( ) { Namespace = namespaceValue , IncludeMklComponentsPackage = includeMklComponentsPackage , IncludeLightGBMPackage = includeLightGbmPackage } ;
244
274
return predictProjectFileContent . TransformText ( ) ;
245
275
}
246
276
@@ -290,6 +320,5 @@ private string GenerateModelBuilderCSFileContent(string usings,
290
320
return modelBuilder . TransformText ( ) ;
291
321
}
292
322
#endregion
293
-
294
323
}
295
324
}
0 commit comments