diff --git a/src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs b/src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs index 908ef6d9d3..bb0fc079d5 100644 --- a/src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs +++ b/src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs @@ -19,6 +19,8 @@ internal class CodeGenerator : IProjectGenerator private readonly Pipeline pipeline; private readonly CodeGeneratorSettings settings; private readonly ColumnInferenceResults columnInferenceResult; + private readonly HashSet LightGBMTrainers = new HashSet() { TrainerName.LightGbmBinary.ToString(), TrainerName.LightGbmMulti.ToString(), TrainerName.LightGbmRegression.ToString() }; + private readonly HashSet mklComponentsTrainers = new HashSet() { TrainerName.OlsRegression.ToString(), TrainerName.SymbolicSgdLogisticRegressionBinary.ToString() }; internal CodeGenerator(Pipeline pipeline, ColumnInferenceResults columnInferenceResult, CodeGeneratorSettings settings) { @@ -29,25 +31,32 @@ internal CodeGenerator(Pipeline pipeline, ColumnInferenceResults columnInference public void GenerateOutput() { + // Get the extra nuget packages to be included in the generated project. + var trainerNodes = pipeline.Nodes.Where(t => t.NodeType == PipelineNodeType.Trainer); + + bool includeLightGbmPackage = false; + bool includeMklComponentsPackage = false; + SetRequiredNugetPackages(trainerNodes, ref includeLightGbmPackage, ref includeMklComponentsPackage); + // Get Namespace var namespaceValue = Utils.Normalize(settings.OutputName); var labelType = columnInferenceResult.TextLoaderOptions.Columns.Where(t => t.Name == columnInferenceResult.ColumnInformation.LabelColumnName).First().DataKind; Type labelTypeCsharp = Utils.GetCSharpType(labelType); // Generate Model Project - var modelProjectContents = GenerateModelProjectContents(namespaceValue, labelTypeCsharp); + var modelProjectContents = GenerateModelProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage); // Write files to disk. var modelprojectDir = Path.Combine(settings.OutputBaseDir, $"{settings.OutputName}.Model"); var dataModelsDir = Path.Combine(modelprojectDir, "DataModels"); var modelProjectName = $"{settings.OutputName}.Model.csproj"; - Utils.WriteOutputToFiles(modelProjectContents.ObservationCSFileContent, "Observation.cs", dataModelsDir); - Utils.WriteOutputToFiles(modelProjectContents.PredictionCSFileContent, "Prediction.cs", dataModelsDir); + Utils.WriteOutputToFiles(modelProjectContents.ObservationCSFileContent, "SampleObservation.cs", dataModelsDir); + Utils.WriteOutputToFiles(modelProjectContents.PredictionCSFileContent, "SamplePrediction.cs", dataModelsDir); Utils.WriteOutputToFiles(modelProjectContents.ModelProjectFileContent, modelProjectName, modelprojectDir); // Generate ConsoleApp Project - var consoleAppProjectContents = GenerateConsoleAppProjectContents(namespaceValue, labelTypeCsharp); + var consoleAppProjectContents = GenerateConsoleAppProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage); // Write files to disk. var consoleAppProjectDir = Path.Combine(settings.OutputBaseDir, $"{settings.OutputName}.ConsoleApp"); @@ -65,12 +74,33 @@ public void GenerateOutput() Utils.AddProjectsToSolution(modelprojectDir, modelProjectName, consoleAppProjectDir, consoleAppProjectName, solutionPath); } - internal (string ConsoleAppProgramCSFileContent, string ConsoleAppProjectFileContent, string modelBuilderCSFileContent) GenerateConsoleAppProjectContents(string namespaceValue, Type labelTypeCsharp) + private void SetRequiredNugetPackages(IEnumerable trainerNodes, ref bool includeLightGbmPackage, ref bool includeMklComponentsPackage) + { + foreach (var node in trainerNodes) + { + PipelineNode currentNode = node; + if (currentNode.Name == TrainerName.Ova.ToString()) + { + currentNode = (PipelineNode)currentNode.Properties["BinaryTrainer"]; + } + + if (LightGBMTrainers.Contains(currentNode.Name)) + { + includeLightGbmPackage = true; + } + else if (mklComponentsTrainers.Contains(currentNode.Name)) + { + includeMklComponentsPackage = true; + } + } + } + + internal (string ConsoleAppProgramCSFileContent, string ConsoleAppProjectFileContent, string modelBuilderCSFileContent) GenerateConsoleAppProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage) { var predictProgramCSFileContent = GeneratePredictProgramCSFileContent(namespaceValue); predictProgramCSFileContent = Utils.FormatCode(predictProgramCSFileContent); - var predictProjectFileContent = GeneratPredictProjectFileContent(namespaceValue, true, true); + var predictProjectFileContent = GeneratPredictProjectFileContent(namespaceValue, includeLightGbmPackage, includeMklComponentsPackage); var transformsAndTrainers = GenerateTransformsAndTrainers(); var modelBuilderCSFileContent = GenerateModelBuilderCSFileContent(transformsAndTrainers.Usings, transformsAndTrainers.TrainerMethod, transformsAndTrainers.PreTrainerTransforms, transformsAndTrainers.PostTrainerTransforms, namespaceValue, pipeline.CacheBeforeTrainer, labelTypeCsharp.Name); @@ -79,14 +109,14 @@ public void GenerateOutput() return (predictProgramCSFileContent, predictProjectFileContent, modelBuilderCSFileContent); } - internal (string ObservationCSFileContent, string PredictionCSFileContent, string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue, Type labelTypeCsharp) + internal (string ObservationCSFileContent, string PredictionCSFileContent, string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage) { var classLabels = this.GenerateClassLabels(); var observationCSFileContent = GenerateObservationCSFileContent(namespaceValue, classLabels); observationCSFileContent = Utils.FormatCode(observationCSFileContent); var predictionCSFileContent = GeneratePredictionCSFileContent(labelTypeCsharp.Name, namespaceValue); predictionCSFileContent = Utils.FormatCode(predictionCSFileContent); - var modelProjectFileContent = GenerateModelProjectFileContent(); + var modelProjectFileContent = GenerateModelProjectFileContent(includeLightGbmPackage, includeMklComponentsPackage); return (observationCSFileContent, predictionCSFileContent, modelProjectFileContent); } @@ -218,9 +248,9 @@ internal IList GenerateClassLabels() } #region Model project - private static string GenerateModelProjectFileContent() + private static string GenerateModelProjectFileContent(bool includeLightGbmPackage, bool includeMklComponentsPackage) { - ModelProject modelProject = new ModelProject(); + ModelProject modelProject = new ModelProject() { IncludeLightGBMPackage = includeLightGbmPackage, IncludeMklComponentsPackage = includeMklComponentsPackage }; return modelProject.TransformText(); } @@ -238,9 +268,9 @@ private string GenerateObservationCSFileContent(string namespaceValue, IList { // Save/persist the trained model to a .ZIP file Console.WriteLine($"=============== Saving the model ==============="); - using (var fs = new FileStream(GetAbsolutePath(modelRelativePath), FileMode.Create, FileAccess.Write, FileShare.Write)) - mlContext.Model.Save(mlModel, modelInputSchema, fs); - + mlContext.Model.Save(mlModel, modelInputSchema, GetAbsolutePath(modelRelativePath)); Console.WriteLine("The model is saved to {0}", GetAbsolutePath(modelRelativePath)); } diff --git a/src/mlnet/Templates/Console/ModelProject.cs b/src/mlnet/Templates/Console/ModelProject.cs index 5a9b788408..ce20728401 100644 --- a/src/mlnet/Templates/Console/ModelProject.cs +++ b/src/mlnet/Templates/Console/ModelProject.cs @@ -18,7 +18,7 @@ namespace Microsoft.ML.CLI.Templates.Console /// Class to produce the template output /// - #line 1 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\ModelProject.tt" + #line 1 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt" [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")] public partial class ModelProject : ModelProjectBase { @@ -28,30 +28,51 @@ public partial class ModelProject : ModelProjectBase /// public virtual string TransformText() { - this.Write(@" - - - netcoreapp2.1 - - - - https://api.nuget.org/v3/index.json; - - - - - - - - - PreserveNewest - - - - -"); + this.Write("\r\n\r\n \r\n netc" + + "oreapp2.1\r\n \r\n \r\n \r\n"); + + #line 13 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt" + if(IncludeLightGBMPackage){ + + #line default + #line hidden + this.Write(" \r" + + "\n"); + + #line 15 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt" +} + + #line default + #line hidden + + #line 16 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt" + if(IncludeMklComponentsPackage){ + + #line default + #line hidden + this.Write(" \r\n"); + + #line 18 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt" +} + + #line default + #line hidden + this.Write(" \r\n\r\n \r\n \r\n PreserveNewest\r\n \r\n \r\n \r\n\r\n"); return this.GenerationEnvironment.ToString(); } + + #line 28 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt" + +public bool IncludeLightGBMPackage {get;set;} +public bool IncludeMklComponentsPackage {get;set;} + + + #line default + #line hidden } #line default diff --git a/src/mlnet/Templates/Console/ModelProject.tt b/src/mlnet/Templates/Console/ModelProject.tt index 7ca417d9d1..72edb243d7 100644 --- a/src/mlnet/Templates/Console/ModelProject.tt +++ b/src/mlnet/Templates/Console/ModelProject.tt @@ -7,14 +7,15 @@ netcoreapp2.1 - - - - https://api.nuget.org/v3/index.json; - +<# if(IncludeLightGBMPackage){ #> + +<#}#> +<# if(IncludeMklComponentsPackage){ #> + +<#}#> @@ -24,3 +25,7 @@ +<#+ +public bool IncludeLightGBMPackage {get;set;} +public bool IncludeMklComponentsPackage {get;set;} +#> diff --git a/src/mlnet/Templates/Console/PredictProgram.cs b/src/mlnet/Templates/Console/PredictProgram.cs index 1bc67d8265..63e8039073 100644 --- a/src/mlnet/Templates/Console/PredictProgram.cs +++ b/src/mlnet/Templates/Console/PredictProgram.cs @@ -20,7 +20,7 @@ namespace Microsoft.ML.CLI.Templates.Console /// Class to produce the template output /// - #line 1 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + #line 1 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt" [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")] public partial class PredictProgram : PredictProgramBase { @@ -37,21 +37,18 @@ public virtual string TransformText() //***************************************************************************************** using System; -using System.IO; using System.Linq; -using System.Collections.Generic; using Microsoft.ML; -using Microsoft.ML.Data; using "); - #line 20 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + #line 17 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt" this.Write(this.ToStringHelper.ToStringWithCulture(Namespace)); #line default #line hidden this.Write(".Model.DataModels;\r\n\r\n\r\nnamespace "); - #line 23 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + #line 20 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt" this.Write(this.ToStringHelper.ToStringWithCulture(Namespace)); #line default @@ -60,35 +57,35 @@ public virtual string TransformText() "d and use for predictions\r\n private const string MODEL_FILEPATH = @\"MLMod" + "el.zip\";\r\n\r\n //Dataset to use for predictions \r\n"); - #line 31 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + #line 28 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt" if(string.IsNullOrEmpty(TestDataPath)){ #line default #line hidden this.Write(" private const string DATA_FILEPATH = @\""); - #line 32 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + #line 29 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt" this.Write(this.ToStringHelper.ToStringWithCulture(TrainDataPath)); #line default #line hidden this.Write("\";\r\n"); - #line 33 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + #line 30 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt" } else{ #line default #line hidden this.Write(" private const string DATA_FILEPATH = @\""); - #line 34 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + #line 31 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt" this.Write(this.ToStringHelper.ToStringWithCulture(TestDataPath)); #line default #line hidden this.Write("\";\r\n"); - #line 35 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + #line 32 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt" } #line default @@ -112,42 +109,42 @@ static void Main(string[] args) "); - #line 53 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + #line 50 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt" if("BinaryClassification".Equals(TaskType)){ #line default #line hidden this.Write(" Console.WriteLine($\"Single Prediction --> Actual value: {sampleData."); - #line 54 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + #line 51 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt" this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName))); #line default #line hidden this.Write("} | Predicted value: {predictionResult.Prediction}\");\r\n"); - #line 55 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + #line 52 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt" }else if("Regression".Equals(TaskType)){ #line default #line hidden this.Write(" Console.WriteLine($\"Single Prediction --> Actual value: {sampleData."); - #line 56 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + #line 53 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt" this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName))); #line default #line hidden this.Write("} | Predicted value: {predictionResult.Score}\");\r\n"); - #line 57 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + #line 54 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt" } else if("MulticlassClassification".Equals(TaskType)){ #line default #line hidden this.Write(" Console.WriteLine($\"Single Prediction --> Actual value: {sampleData."); - #line 58 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + #line 55 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt" this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName))); #line default @@ -155,7 +152,7 @@ static void Main(string[] args) this.Write("} | Predicted value: {predictionResult.Prediction} | Predicted scores: [{String.J" + "oin(\",\", predictionResult.Score)}]\");\r\n"); - #line 59 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + #line 56 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt" } #line default @@ -174,28 +171,28 @@ private static SampleObservation CreateSingleDataSample(MLContext mlContext, str path: dataFilePath, hasHeader : "); - #line 72 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + #line 69 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt" this.Write(this.ToStringHelper.ToStringWithCulture(HasHeader.ToString().ToLowerInvariant())); #line default #line hidden this.Write(",\r\n separatorChar : \'"); - #line 73 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + #line 70 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt" this.Write(this.ToStringHelper.ToStringWithCulture(Regex.Escape(Separator.ToString()))); #line default #line hidden this.Write("\',\r\n allowQuoting : "); - #line 74 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + #line 71 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt" this.Write(this.ToStringHelper.ToStringWithCulture(AllowQuoting.ToString().ToLowerInvariant())); #line default #line hidden this.Write(",\r\n allowSparse: "); - #line 75 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + #line 72 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt" this.Write(this.ToStringHelper.ToStringWithCulture(AllowSparse.ToString().ToLowerInvariant())); #line default @@ -213,7 +210,7 @@ private static SampleObservation CreateSingleDataSample(MLContext mlContext, str return this.GenerationEnvironment.ToString(); } - #line 84 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + #line 81 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt" public string TaskType {get;set;} public string Namespace {get;set;} diff --git a/src/mlnet/Templates/Console/PredictProgram.tt b/src/mlnet/Templates/Console/PredictProgram.tt index fc9ed43172..57f9722218 100644 --- a/src/mlnet/Templates/Console/PredictProgram.tt +++ b/src/mlnet/Templates/Console/PredictProgram.tt @@ -12,11 +12,8 @@ //***************************************************************************************** using System; -using System.IO; using System.Linq; -using System.Collections.Generic; using Microsoft.ML; -using Microsoft.ML.Data; using <#= Namespace #>.Model.DataModels; diff --git a/src/mlnet/Utilities/ProgressHandlers.cs b/src/mlnet/Utilities/ProgressHandlers.cs index 98519c210c..d5e08fa41b 100644 --- a/src/mlnet/Utilities/ProgressHandlers.cs +++ b/src/mlnet/Utilities/ProgressHandlers.cs @@ -40,7 +40,7 @@ public void Report(RunDetail iterationResult) iterationIndex++; UpdateBestResult(iterationResult); if (progressBar != null) - progressBar.Message = $"Best {this.optimizationMetric}: {GetScore(bestResult):F4}, Best Algorithm: {bestResult?.TrainerName}, Last Algorithm: {iterationResult?.TrainerName}"; + progressBar.Message = $"Best quality({this.optimizationMetric}): {GetScore(bestResult):F4}, Best Algorithm: {bestResult?.TrainerName}, Last Algorithm: {iterationResult?.TrainerName}"; ConsolePrinter.PrintMetrics(iterationIndex, iterationResult?.TrainerName, iterationResult?.ValidationMetrics, GetScore(bestResult), iterationResult?.RuntimeInSeconds, LogLevel.Trace); if (iterationResult.Exception != null) { diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentBinaryTest.approved.txt b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentBinaryTest.approved.txt index e25d5e46a3..7a4f93beff 100644 --- a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentBinaryTest.approved.txt +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentBinaryTest.approved.txt @@ -88,9 +88,7 @@ namespace TestNamespace.ConsoleApp { // Save/persist the trained model to a .ZIP file Console.WriteLine($"=============== Saving the model ==============="); - using (var fs = new FileStream(GetAbsolutePath(modelRelativePath), FileMode.Create, FileAccess.Write, FileShare.Write)) - mlContext.Model.Save(mlModel, modelInputSchema, fs); - + mlContext.Model.Save(mlModel, modelInputSchema, GetAbsolutePath(modelRelativePath)); Console.WriteLine("The model is saved to {0}", GetAbsolutePath(modelRelativePath)); } diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentOvaTest.approved.txt b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentOvaTest.approved.txt index 36b7deff19..7a8649b242 100644 --- a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentOvaTest.approved.txt +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentOvaTest.approved.txt @@ -88,9 +88,7 @@ namespace TestNamespace.ConsoleApp { // Save/persist the trained model to a .ZIP file Console.WriteLine($"=============== Saving the model ==============="); - using (var fs = new FileStream(GetAbsolutePath(modelRelativePath), FileMode.Create, FileAccess.Write, FileShare.Write)) - mlContext.Model.Save(mlModel, modelInputSchema, fs); - + mlContext.Model.Save(mlModel, modelInputSchema, GetAbsolutePath(modelRelativePath)); Console.WriteLine("The model is saved to {0}", GetAbsolutePath(modelRelativePath)); } diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentRegressionTest.approved.txt b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentRegressionTest.approved.txt index 122634ff08..8211bd7e61 100644 --- a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentRegressionTest.approved.txt +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentRegressionTest.approved.txt @@ -88,9 +88,7 @@ namespace TestNamespace.ConsoleApp { // Save/persist the trained model to a .ZIP file Console.WriteLine($"=============== Saving the model ==============="); - using (var fs = new FileStream(GetAbsolutePath(modelRelativePath), FileMode.Create, FileAccess.Write, FileShare.Write)) - mlContext.Model.Save(mlModel, modelInputSchema, fs); - + mlContext.Model.Save(mlModel, modelInputSchema, GetAbsolutePath(modelRelativePath)); Console.WriteLine("The model is saved to {0}", GetAbsolutePath(modelRelativePath)); } diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppProgramCSFileContentTest.approved.txt b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppProgramCSFileContentTest.approved.txt index 34ce3713fa..4798394512 100644 --- a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppProgramCSFileContentTest.approved.txt +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppProgramCSFileContentTest.approved.txt @@ -5,11 +5,8 @@ //***************************************************************************************** using System; -using System.IO; using System.Linq; -using System.Collections.Generic; using Microsoft.ML; -using Microsoft.ML.Data; using TestNamespace.Model.DataModels; diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ModelProjectFileContentTest.approved.txt b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ModelProjectFileContentTest.approved.txt index 8f1acbadb8..32cfc8d2ec 100644 --- a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ModelProjectFileContentTest.approved.txt +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ModelProjectFileContentTest.approved.txt @@ -2,14 +2,11 @@ netcoreapp2.1 - - - - https://api.nuget.org/v3/index.json; - + + diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.cs b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.cs index d902ff9a64..6a94fc6156 100644 --- a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.cs +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.cs @@ -42,7 +42,7 @@ public void ConsoleAppModelBuilderCSFileContentOvaTest() LabelName = "Label", ModelPath = "x:\\models\\model.zip" }); - var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float)); + var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true); Approvals.Verify(result.modelBuilderCSFileContent); } @@ -65,7 +65,7 @@ public void ConsoleAppModelBuilderCSFileContentBinaryTest() LabelName = "Label", ModelPath = "x:\\models\\model.zip" }); - var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float)); + var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true); Approvals.Verify(result.modelBuilderCSFileContent); } @@ -88,7 +88,7 @@ public void ConsoleAppModelBuilderCSFileContentRegressionTest() LabelName = "Label", ModelPath = "x:\\models\\model.zip" }); - var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float)); + var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true); Approvals.Verify(result.modelBuilderCSFileContent); } @@ -111,7 +111,7 @@ public void ModelProjectFileContentTest() LabelName = "Label", ModelPath = "x:\\models\\model.zip" }); - var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float)); + var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true); Approvals.Verify(result.ModelProjectFileContent); } @@ -134,7 +134,7 @@ public void ObservationCSFileContentTest() LabelName = "Label", ModelPath = "x:\\models\\model.zip" }); - var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float)); + var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true); Approvals.Verify(result.ObservationCSFileContent); } @@ -158,7 +158,7 @@ public void PredictionCSFileContentTest() LabelName = "Label", ModelPath = "x:\\models\\model.zip" }); - var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float)); + var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true); Approvals.Verify(result.PredictionCSFileContent); } @@ -181,7 +181,7 @@ public void ConsoleAppProgramCSFileContentTest() LabelName = "Label", ModelPath = "x:\\models\\model.zip" }); - var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float)); + var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true); Approvals.Verify(result.ConsoleAppProgramCSFileContent); } @@ -204,7 +204,7 @@ public void ConsoleAppProjectFileContentTest() LabelName = "Label", ModelPath = "x:\\models\\model.zip" }); - var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float)); + var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true); Approvals.Verify(result.ConsoleAppProjectFileContent); }