Skip to content

Commit ffeed1b

Browse files
authored
[AutoML] Minor changes to generated project in CLI based on feedback (#3371)
* nitpicks for generated project * revert back the target framework
1 parent 4ead03d commit ffeed1b

14 files changed

+132
-99
lines changed

src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ internal class CodeGenerator : IProjectGenerator
1919
private readonly Pipeline pipeline;
2020
private readonly CodeGeneratorSettings settings;
2121
private readonly ColumnInferenceResults columnInferenceResult;
22+
private readonly HashSet<string> LightGBMTrainers = new HashSet<string>() { TrainerName.LightGbmBinary.ToString(), TrainerName.LightGbmMulti.ToString(), TrainerName.LightGbmRegression.ToString() };
23+
private readonly HashSet<string> mklComponentsTrainers = new HashSet<string>() { TrainerName.OlsRegression.ToString(), TrainerName.SymbolicSgdLogisticRegressionBinary.ToString() };
2224

2325
internal CodeGenerator(Pipeline pipeline, ColumnInferenceResults columnInferenceResult, CodeGeneratorSettings settings)
2426
{
@@ -29,25 +31,32 @@ internal CodeGenerator(Pipeline pipeline, ColumnInferenceResults columnInference
2931

3032
public void GenerateOutput()
3133
{
34+
// Get the extra nuget packages to be included in the generated project.
35+
var trainerNodes = pipeline.Nodes.Where(t => t.NodeType == PipelineNodeType.Trainer);
36+
37+
bool includeLightGbmPackage = false;
38+
bool includeMklComponentsPackage = false;
39+
SetRequiredNugetPackages(trainerNodes, ref includeLightGbmPackage, ref includeMklComponentsPackage);
40+
3241
// Get Namespace
3342
var namespaceValue = Utils.Normalize(settings.OutputName);
3443
var labelType = columnInferenceResult.TextLoaderOptions.Columns.Where(t => t.Name == columnInferenceResult.ColumnInformation.LabelColumnName).First().DataKind;
3544
Type labelTypeCsharp = Utils.GetCSharpType(labelType);
3645

3746
// Generate Model Project
38-
var modelProjectContents = GenerateModelProjectContents(namespaceValue, labelTypeCsharp);
47+
var modelProjectContents = GenerateModelProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage);
3948

4049
// Write files to disk.
4150
var modelprojectDir = Path.Combine(settings.OutputBaseDir, $"{settings.OutputName}.Model");
4251
var dataModelsDir = Path.Combine(modelprojectDir, "DataModels");
4352
var modelProjectName = $"{settings.OutputName}.Model.csproj";
4453

45-
Utils.WriteOutputToFiles(modelProjectContents.ObservationCSFileContent, "Observation.cs", dataModelsDir);
46-
Utils.WriteOutputToFiles(modelProjectContents.PredictionCSFileContent, "Prediction.cs", dataModelsDir);
54+
Utils.WriteOutputToFiles(modelProjectContents.ObservationCSFileContent, "SampleObservation.cs", dataModelsDir);
55+
Utils.WriteOutputToFiles(modelProjectContents.PredictionCSFileContent, "SamplePrediction.cs", dataModelsDir);
4756
Utils.WriteOutputToFiles(modelProjectContents.ModelProjectFileContent, modelProjectName, modelprojectDir);
4857

4958
// Generate ConsoleApp Project
50-
var consoleAppProjectContents = GenerateConsoleAppProjectContents(namespaceValue, labelTypeCsharp);
59+
var consoleAppProjectContents = GenerateConsoleAppProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage);
5160

5261
// Write files to disk.
5362
var consoleAppProjectDir = Path.Combine(settings.OutputBaseDir, $"{settings.OutputName}.ConsoleApp");
@@ -65,12 +74,33 @@ public void GenerateOutput()
6574
Utils.AddProjectsToSolution(modelprojectDir, modelProjectName, consoleAppProjectDir, consoleAppProjectName, solutionPath);
6675
}
6776

68-
internal (string ConsoleAppProgramCSFileContent, string ConsoleAppProjectFileContent, string modelBuilderCSFileContent) GenerateConsoleAppProjectContents(string namespaceValue, Type labelTypeCsharp)
77+
private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, ref bool includeLightGbmPackage, ref bool includeMklComponentsPackage)
78+
{
79+
foreach (var node in trainerNodes)
80+
{
81+
PipelineNode currentNode = node;
82+
if (currentNode.Name == TrainerName.Ova.ToString())
83+
{
84+
currentNode = (PipelineNode)currentNode.Properties["BinaryTrainer"];
85+
}
86+
87+
if (LightGBMTrainers.Contains(currentNode.Name))
88+
{
89+
includeLightGbmPackage = true;
90+
}
91+
else if (mklComponentsTrainers.Contains(currentNode.Name))
92+
{
93+
includeMklComponentsPackage = true;
94+
}
95+
}
96+
}
97+
98+
internal (string ConsoleAppProgramCSFileContent, string ConsoleAppProjectFileContent, string modelBuilderCSFileContent) GenerateConsoleAppProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage)
6999
{
70100
var predictProgramCSFileContent = GeneratePredictProgramCSFileContent(namespaceValue);
71101
predictProgramCSFileContent = Utils.FormatCode(predictProgramCSFileContent);
72102

73-
var predictProjectFileContent = GeneratPredictProjectFileContent(namespaceValue, true, true);
103+
var predictProjectFileContent = GeneratPredictProjectFileContent(namespaceValue, includeLightGbmPackage, includeMklComponentsPackage);
74104

75105
var transformsAndTrainers = GenerateTransformsAndTrainers();
76106
var modelBuilderCSFileContent = GenerateModelBuilderCSFileContent(transformsAndTrainers.Usings, transformsAndTrainers.TrainerMethod, transformsAndTrainers.PreTrainerTransforms, transformsAndTrainers.PostTrainerTransforms, namespaceValue, pipeline.CacheBeforeTrainer, labelTypeCsharp.Name);
@@ -79,14 +109,14 @@ public void GenerateOutput()
79109
return (predictProgramCSFileContent, predictProjectFileContent, modelBuilderCSFileContent);
80110
}
81111

82-
internal (string ObservationCSFileContent, string PredictionCSFileContent, string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue, Type labelTypeCsharp)
112+
internal (string ObservationCSFileContent, string PredictionCSFileContent, string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage)
83113
{
84114
var classLabels = this.GenerateClassLabels();
85115
var observationCSFileContent = GenerateObservationCSFileContent(namespaceValue, classLabels);
86116
observationCSFileContent = Utils.FormatCode(observationCSFileContent);
87117
var predictionCSFileContent = GeneratePredictionCSFileContent(labelTypeCsharp.Name, namespaceValue);
88118
predictionCSFileContent = Utils.FormatCode(predictionCSFileContent);
89-
var modelProjectFileContent = GenerateModelProjectFileContent();
119+
var modelProjectFileContent = GenerateModelProjectFileContent(includeLightGbmPackage, includeMklComponentsPackage);
90120
return (observationCSFileContent, predictionCSFileContent, modelProjectFileContent);
91121
}
92122

@@ -218,9 +248,9 @@ internal IList<string> GenerateClassLabels()
218248
}
219249

220250
#region Model project
221-
private static string GenerateModelProjectFileContent()
251+
private static string GenerateModelProjectFileContent(bool includeLightGbmPackage, bool includeMklComponentsPackage)
222252
{
223-
ModelProject modelProject = new ModelProject();
253+
ModelProject modelProject = new ModelProject() { IncludeLightGBMPackage = includeLightGbmPackage, IncludeMklComponentsPackage = includeMklComponentsPackage };
224254
return modelProject.TransformText();
225255
}
226256

@@ -238,9 +268,9 @@ private string GenerateObservationCSFileContent(string namespaceValue, IList<str
238268
#endregion
239269

240270
#region Predict Project
241-
private static string GeneratPredictProjectFileContent(string namespaceValue, bool includeMklComponentsPackage, bool includeLightGBMPackage)
271+
private static string GeneratPredictProjectFileContent(string namespaceValue, bool includeLightGbmPackage, bool includeMklComponentsPackage)
242272
{
243-
var predictProjectFileContent = new PredictProject() { Namespace = namespaceValue, IncludeMklComponentsPackage = includeMklComponentsPackage, IncludeLightGBMPackage = includeLightGBMPackage };
273+
var predictProjectFileContent = new PredictProject() { Namespace = namespaceValue, IncludeMklComponentsPackage = includeMklComponentsPackage, IncludeLightGBMPackage = includeLightGbmPackage };
244274
return predictProjectFileContent.TransformText();
245275
}
246276

@@ -290,6 +320,5 @@ private string GenerateModelBuilderCSFileContent(string usings,
290320
return modelBuilder.TransformText();
291321
}
292322
#endregion
293-
294323
}
295324
}

src/mlnet/Templates/Console/ModelBuilder.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,9 +224,7 @@ public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDat
224224
{
225225
// Save/persist the trained model to a .ZIP file
226226
Console.WriteLine($""=============== Saving the model ==============="");
227-
using (var fs = new FileStream(GetAbsolutePath(modelRelativePath), FileMode.Create, FileAccess.Write, FileShare.Write))
228-
mlContext.Model.Save(mlModel, modelInputSchema, fs);
229-
227+
mlContext.Model.Save(mlModel, modelInputSchema, GetAbsolutePath(modelRelativePath));
230228
Console.WriteLine(""The model is saved to {0}"", GetAbsolutePath(modelRelativePath));
231229
}
232230

src/mlnet/Templates/Console/ModelBuilder.tt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,7 @@ else{#>
156156
{
157157
// Save/persist the trained model to a .ZIP file
158158
Console.WriteLine($"=============== Saving the model ===============");
159-
using (var fs = new FileStream(GetAbsolutePath(modelRelativePath), FileMode.Create, FileAccess.Write, FileShare.Write))
160-
mlContext.Model.Save(mlModel, modelInputSchema, fs);
161-
159+
mlContext.Model.Save(mlModel, modelInputSchema, GetAbsolutePath(modelRelativePath));
162160
Console.WriteLine("The model is saved to {0}", GetAbsolutePath(modelRelativePath));
163161
}
164162

src/mlnet/Templates/Console/ModelProject.cs

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace Microsoft.ML.CLI.Templates.Console
1818
/// Class to produce the template output
1919
/// </summary>
2020

21-
#line 1 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\ModelProject.tt"
21+
#line 1 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
2222
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")]
2323
public partial class ModelProject : ModelProjectBase
2424
{
@@ -28,30 +28,51 @@ public partial class ModelProject : ModelProjectBase
2828
/// </summary>
2929
public virtual string TransformText()
3030
{
31-
this.Write(@"<Project Sdk=""Microsoft.NET.Sdk"">
32-
33-
<PropertyGroup>
34-
<TargetFramework>netcoreapp2.1</TargetFramework>
35-
</PropertyGroup>
36-
<PropertyGroup>
37-
<RestoreSources>
38-
https://api.nuget.org/v3/index.json;
39-
</RestoreSources>
40-
</PropertyGroup>
41-
<ItemGroup>
42-
<PackageReference Include=""Microsoft.ML"" Version=""1.0.0-preview"" />
43-
</ItemGroup>
44-
45-
<ItemGroup>
46-
<None Update=""MLModel.zip"">
47-
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
48-
</None>
49-
</ItemGroup>
50-
51-
</Project>
52-
");
31+
this.Write("<Project Sdk=\"Microsoft.NET.Sdk\">\r\n\r\n <PropertyGroup>\r\n <TargetFramework>netc" +
32+
"oreapp2.1</TargetFramework>\r\n </PropertyGroup>\r\n <ItemGroup>\r\n <PackageRefe" +
33+
"rence Include=\"Microsoft.ML\" Version=\"1.0.0-preview\" />\r\n");
34+
35+
#line 13 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
36+
if(IncludeLightGBMPackage){
37+
38+
#line default
39+
#line hidden
40+
this.Write(" <PackageReference Include=\"Microsoft.ML.LightGBM\" Version=\"1.0.0-preview\" />\r" +
41+
"\n");
42+
43+
#line 15 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
44+
}
45+
46+
#line default
47+
#line hidden
48+
49+
#line 16 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
50+
if(IncludeMklComponentsPackage){
51+
52+
#line default
53+
#line hidden
54+
this.Write(" <PackageReference Include=\"Microsoft.ML.Mkl.Components\" Version=\"1.0.0-previe" +
55+
"w\" />\r\n");
56+
57+
#line 18 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
58+
}
59+
60+
#line default
61+
#line hidden
62+
this.Write(" </ItemGroup>\r\n\r\n <ItemGroup>\r\n <None Update=\"MLModel.zip\">\r\n <CopyToOu" +
63+
"tputDirectory>PreserveNewest</CopyToOutputDirectory>\r\n </None>\r\n </ItemGroup" +
64+
">\r\n \r\n</Project>\r\n");
5365
return this.GenerationEnvironment.ToString();
5466
}
67+
68+
#line 28 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
69+
70+
public bool IncludeLightGBMPackage {get;set;}
71+
public bool IncludeMklComponentsPackage {get;set;}
72+
73+
74+
#line default
75+
#line hidden
5576
}
5677

5778
#line default

src/mlnet/Templates/Console/ModelProject.tt

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77

88
<PropertyGroup>
99
<TargetFramework>netcoreapp2.1</TargetFramework>
10-
</PropertyGroup>
11-
<PropertyGroup>
12-
<RestoreSources>
13-
https://api.nuget.org/v3/index.json;
14-
</RestoreSources>
1510
</PropertyGroup>
1611
<ItemGroup>
1712
<PackageReference Include="Microsoft.ML" Version="1.0.0-preview" />
13+
<# if(IncludeLightGBMPackage){ #>
14+
<PackageReference Include="Microsoft.ML.LightGBM" Version="1.0.0-preview" />
15+
<#}#>
16+
<# if(IncludeMklComponentsPackage){ #>
17+
<PackageReference Include="Microsoft.ML.Mkl.Components" Version="1.0.0-preview" />
18+
<#}#>
1819
</ItemGroup>
1920

2021
<ItemGroup>
@@ -24,3 +25,7 @@
2425
</ItemGroup>
2526

2627
</Project>
28+
<#+
29+
public bool IncludeLightGBMPackage {get;set;}
30+
public bool IncludeMklComponentsPackage {get;set;}
31+
#>

0 commit comments

Comments
 (0)