Skip to content

Commit c832e27

Browse files
authored
fixed path bug and regression metrics correction (#3504)
1 parent da3a403 commit c832e27

6 files changed

+79
-46
lines changed

src/mlnet/Templates/Console/ModelBuilder.cs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -256,21 +256,21 @@ public static string GetAbsolutePath(string relativePath)
256256
"s)\r\n {\r\n var L1 = crossValidationResults.Select(r => r.Metrics" +
257257
".MeanAbsoluteError);\r\n var L2 = crossValidationResults.Select(r => r." +
258258
"Metrics.MeanSquaredError);\r\n var RMS = crossValidationResults.Select(" +
259-
"r => r.Metrics.MeanAbsoluteError);\r\n var lossFunction = crossValidati" +
260-
"onResults.Select(r => r.Metrics.LossFunction);\r\n var R2 = crossValida" +
261-
"tionResults.Select(r => r.Metrics.RSquared);\r\n\r\n Console.WriteLine($\"" +
262-
"********************************************************************************" +
263-
"*****************************\");\r\n Console.WriteLine($\"* Metric" +
264-
"s for Regression model \");\r\n Console.WriteLine($\"*--------------" +
259+
"r => r.Metrics.RootMeanSquaredError);\r\n var lossFunction = crossValid" +
260+
"ationResults.Select(r => r.Metrics.LossFunction);\r\n var R2 = crossVal" +
261+
"idationResults.Select(r => r.Metrics.RSquared);\r\n\r\n Console.WriteLine" +
262+
"($\"*****************************************************************************" +
263+
"********************************\");\r\n Console.WriteLine($\"* Met" +
264+
"rics for Regression model \");\r\n Console.WriteLine($\"*-----------" +
265265
"--------------------------------------------------------------------------------" +
266-
"--------------\");\r\n Console.WriteLine($\"* Average L1 Loss: {" +
267-
"L1.Average():0.###} \");\r\n Console.WriteLine($\"* Average L2 Loss" +
268-
": {L2.Average():0.###} \");\r\n Console.WriteLine($\"* Average " +
269-
"RMS: {RMS.Average():0.###} \");\r\n Console.WriteLine($\"* " +
270-
" Average Loss Function: {lossFunction.Average():0.###} \");\r\n Consol" +
271-
"e.WriteLine($\"* Average R-squared: {R2.Average():0.###} \");\r\n " +
272-
"Console.WriteLine($\"************************************************************" +
273-
"*************************************************\");\r\n }\r\n");
266+
"-----------------\");\r\n Console.WriteLine($\"* Average L1 Loss: " +
267+
" {L1.Average():0.###} \");\r\n Console.WriteLine($\"* Average L" +
268+
"2 Loss: {L2.Average():0.###} \");\r\n Console.WriteLine($\"* " +
269+
" Average RMS: {RMS.Average():0.###} \");\r\n Console.WriteLin" +
270+
"e($\"* Average Loss Function: {lossFunction.Average():0.###} \");\r\n " +
271+
" Console.WriteLine($\"* Average R-squared: {R2.Average():0.###} \");" +
272+
"\r\n Console.WriteLine($\"**********************************************" +
273+
"***************************************************************\");\r\n }\r\n");
274274
} if("BinaryClassification".Equals(TaskType)){
275275
this.Write(" public static void PrintBinaryClassificationMetrics(BinaryClassificationM" +
276276
"etrics metrics)\r\n {\r\n Console.WriteLine($\"********************" +

src/mlnet/Templates/Console/ModelBuilder.tt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,18 +188,18 @@ else{#>
188188
{
189189
var L1 = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);
190190
var L2 = crossValidationResults.Select(r => r.Metrics.MeanSquaredError);
191-
var RMS = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);
191+
var RMS = crossValidationResults.Select(r => r.Metrics.RootMeanSquaredError);
192192
var lossFunction = crossValidationResults.Select(r => r.Metrics.LossFunction);
193193
var R2 = crossValidationResults.Select(r => r.Metrics.RSquared);
194194

195195
Console.WriteLine($"*************************************************************************************************************");
196196
Console.WriteLine($"* Metrics for Regression model ");
197197
Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
198-
Console.WriteLine($"* Average L1 Loss: {L1.Average():0.###} ");
199-
Console.WriteLine($"* Average L2 Loss: {L2.Average():0.###} ");
200-
Console.WriteLine($"* Average RMS: {RMS.Average():0.###} ");
198+
Console.WriteLine($"* Average L1 Loss: {L1.Average():0.###} ");
199+
Console.WriteLine($"* Average L2 Loss: {L2.Average():0.###} ");
200+
Console.WriteLine($"* Average RMS: {RMS.Average():0.###} ");
201201
Console.WriteLine($"* Average Loss Function: {lossFunction.Average():0.###} ");
202-
Console.WriteLine($"* Average R-squared: {R2.Average():0.###} ");
202+
Console.WriteLine($"* Average R-squared: {R2.Average():0.###} ");
203203
Console.WriteLine($"*************************************************************************************************************");
204204
}
205205
<# } if("BinaryClassification".Equals(TaskType)){ #>

src/mlnet/Templates/Console/PredictProgram.cs

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,19 @@ public virtual string TransformText()
3737
//*****************************************************************************************
3838
3939
using System;
40+
using System.IO;
4041
using System.Linq;
4142
using Microsoft.ML;
4243
using ");
4344

44-
#line 17 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
45+
#line 18 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
4546
this.Write(this.ToStringHelper.ToStringWithCulture(Namespace));
4647

4748
#line default
4849
#line hidden
4950
this.Write(".Model.DataModels;\r\n\r\n\r\nnamespace ");
5051

51-
#line 20 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
52+
#line 21 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
5253
this.Write(this.ToStringHelper.ToStringWithCulture(Namespace));
5354

5455
#line default
@@ -57,35 +58,35 @@ public virtual string TransformText()
5758
"d and use for predictions\r\n private const string MODEL_FILEPATH = @\"MLMod" +
5859
"el.zip\";\r\n\r\n //Dataset to use for predictions \r\n");
5960

60-
#line 28 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
61+
#line 29 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
6162
if(string.IsNullOrEmpty(TestDataPath)){
6263

6364
#line default
6465
#line hidden
6566
this.Write(" private const string DATA_FILEPATH = @\"");
6667

67-
#line 29 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
68+
#line 30 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
6869
this.Write(this.ToStringHelper.ToStringWithCulture(TrainDataPath));
6970

7071
#line default
7172
#line hidden
7273
this.Write("\";\r\n");
7374

74-
#line 30 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
75+
#line 31 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
7576
} else{
7677

7778
#line default
7879
#line hidden
7980
this.Write(" private const string DATA_FILEPATH = @\"");
8081

81-
#line 31 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
82+
#line 32 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
8283
this.Write(this.ToStringHelper.ToStringWithCulture(TestDataPath));
8384

8485
#line default
8586
#line hidden
8687
this.Write("\";\r\n");
8788

88-
#line 32 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
89+
#line 33 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
8990
}
9091

9192
#line default
@@ -98,7 +99,7 @@ static void Main(string[] args)
9899
// Training code used by ML.NET CLI and AutoML to generate the model
99100
//ModelBuilder.CreateModel();
100101
101-
ITransformer mlModel = mlContext.Model.Load(MODEL_FILEPATH, out DataViewSchema inputSchema);
102+
ITransformer mlModel = mlContext.Model.Load(GetAbsolutePath(MODEL_FILEPATH), out DataViewSchema inputSchema);
102103
var predEngine = mlContext.Model.CreatePredictionEngine<SampleObservation, SamplePrediction>(mlModel);
103104
104105
// Create sample data to do a single prediction with it
@@ -109,50 +110,50 @@ static void Main(string[] args)
109110
110111
");
111112

112-
#line 50 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
113+
#line 51 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
113114
if("BinaryClassification".Equals(TaskType)){
114115

115116
#line default
116117
#line hidden
117118
this.Write(" Console.WriteLine($\"Single Prediction --> Actual value: {sampleData.");
118119

119-
#line 51 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
120+
#line 52 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
120121
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));
121122

122123
#line default
123124
#line hidden
124125
this.Write("} | Predicted value: {predictionResult.Prediction}\");\r\n");
125126

126-
#line 52 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
127+
#line 53 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
127128
}else if("Regression".Equals(TaskType)){
128129

129130
#line default
130131
#line hidden
131132
this.Write(" Console.WriteLine($\"Single Prediction --> Actual value: {sampleData.");
132133

133-
#line 53 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
134+
#line 54 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
134135
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));
135136

136137
#line default
137138
#line hidden
138139
this.Write("} | Predicted value: {predictionResult.Score}\");\r\n");
139140

140-
#line 54 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
141+
#line 55 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
141142
} else if("MulticlassClassification".Equals(TaskType)){
142143

143144
#line default
144145
#line hidden
145146
this.Write(" Console.WriteLine($\"Single Prediction --> Actual value: {sampleData.");
146147

147-
#line 55 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
148+
#line 56 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
148149
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));
149150

150151
#line default
151152
#line hidden
152153
this.Write("} | Predicted value: {predictionResult.Prediction} | Predicted scores: [{String.J" +
153154
"oin(\",\", predictionResult.Score)}]\");\r\n");
154155

155-
#line 56 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
156+
#line 57 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
156157
}
157158

158159
#line default
@@ -171,28 +172,28 @@ private static SampleObservation CreateSingleDataSample(MLContext mlContext, str
171172
path: dataFilePath,
172173
hasHeader : ");
173174

174-
#line 69 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
175+
#line 70 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
175176
this.Write(this.ToStringHelper.ToStringWithCulture(HasHeader.ToString().ToLowerInvariant()));
176177

177178
#line default
178179
#line hidden
179180
this.Write(",\r\n separatorChar : \'");
180181

181-
#line 70 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
182+
#line 71 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
182183
this.Write(this.ToStringHelper.ToStringWithCulture(Regex.Escape(Separator.ToString())));
183184

184185
#line default
185186
#line hidden
186187
this.Write("\',\r\n allowQuoting : ");
187188

188-
#line 71 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
189+
#line 72 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
189190
this.Write(this.ToStringHelper.ToStringWithCulture(AllowQuoting.ToString().ToLowerInvariant()));
190191

191192
#line default
192193
#line hidden
193194
this.Write(",\r\n allowSparse: ");
194195

195-
#line 72 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
196+
#line 73 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
196197
this.Write(this.ToStringHelper.ToStringWithCulture(AllowSparse.ToString().ToLowerInvariant()));
197198

198199
#line default
@@ -204,13 +205,23 @@ private static SampleObservation CreateSingleDataSample(MLContext mlContext, str
204205
.First();
205206
return sampleForPrediction;
206207
}
208+
209+
public static string GetAbsolutePath(string relativePath)
210+
{
211+
FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
212+
string assemblyFolderPath = _dataRoot.Directory.FullName;
213+
214+
string fullPath = Path.Combine(assemblyFolderPath, relativePath);
215+
216+
return fullPath;
217+
}
207218
}
208219
}
209220
");
210221
return this.GenerationEnvironment.ToString();
211222
}
212223

213-
#line 81 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
224+
#line 92 "E:\src\machinelearning\src\mlnet\Templates\Console\PredictProgram.tt"
214225

215226
public string TaskType {get;set;}
216227
public string Namespace {get;set;}

src/mlnet/Templates/Console/PredictProgram.tt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//*****************************************************************************************
1313

1414
using System;
15+
using System.IO;
1516
using System.Linq;
1617
using Microsoft.ML;
1718
using <#= Namespace #>.Model.DataModels;
@@ -38,7 +39,7 @@ namespace <#= Namespace #>.ConsoleApp
3839
// Training code used by ML.NET CLI and AutoML to generate the model
3940
//ModelBuilder.CreateModel();
4041

41-
ITransformer mlModel = mlContext.Model.Load(MODEL_FILEPATH, out DataViewSchema inputSchema);
42+
ITransformer mlModel = mlContext.Model.Load(GetAbsolutePath(MODEL_FILEPATH), out DataViewSchema inputSchema);
4243
var predEngine = mlContext.Model.CreatePredictionEngine<SampleObservation, SamplePrediction>(mlModel);
4344

4445
// Create sample data to do a single prediction with it
@@ -76,6 +77,16 @@ namespace <#= Namespace #>.ConsoleApp
7677
.First();
7778
return sampleForPrediction;
7879
}
80+
81+
public static string GetAbsolutePath(string relativePath)
82+
{
83+
FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
84+
string assemblyFolderPath = _dataRoot.Directory.FullName;
85+
86+
string fullPath = Path.Combine(assemblyFolderPath, relativePath);
87+
88+
return fullPath;
89+
}
7990
}
8091
}
8192
<#+

test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentRegressionTest.approved.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,18 +119,18 @@ namespace TestNamespace.ConsoleApp
119119
{
120120
var L1 = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);
121121
var L2 = crossValidationResults.Select(r => r.Metrics.MeanSquaredError);
122-
var RMS = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);
122+
var RMS = crossValidationResults.Select(r => r.Metrics.RootMeanSquaredError);
123123
var lossFunction = crossValidationResults.Select(r => r.Metrics.LossFunction);
124124
var R2 = crossValidationResults.Select(r => r.Metrics.RSquared);
125125

126126
Console.WriteLine($"*************************************************************************************************************");
127127
Console.WriteLine($"* Metrics for Regression model ");
128128
Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
129-
Console.WriteLine($"* Average L1 Loss: {L1.Average():0.###} ");
130-
Console.WriteLine($"* Average L2 Loss: {L2.Average():0.###} ");
131-
Console.WriteLine($"* Average RMS: {RMS.Average():0.###} ");
129+
Console.WriteLine($"* Average L1 Loss: {L1.Average():0.###} ");
130+
Console.WriteLine($"* Average L2 Loss: {L2.Average():0.###} ");
131+
Console.WriteLine($"* Average RMS: {RMS.Average():0.###} ");
132132
Console.WriteLine($"* Average Loss Function: {lossFunction.Average():0.###} ");
133-
Console.WriteLine($"* Average R-squared: {R2.Average():0.###} ");
133+
Console.WriteLine($"* Average R-squared: {R2.Average():0.###} ");
134134
Console.WriteLine($"*************************************************************************************************************");
135135
}
136136
}

test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppProgramCSFileContentTest.approved.txt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
//*****************************************************************************************
66

77
using System;
8+
using System.IO;
89
using System.Linq;
910
using Microsoft.ML;
1011
using TestNamespace.Model.DataModels;
@@ -27,7 +28,7 @@ namespace TestNamespace.ConsoleApp
2728
// Training code used by ML.NET CLI and AutoML to generate the model
2829
//ModelBuilder.CreateModel();
2930

30-
ITransformer mlModel = mlContext.Model.Load(MODEL_FILEPATH, out DataViewSchema inputSchema);
31+
ITransformer mlModel = mlContext.Model.Load(GetAbsolutePath(MODEL_FILEPATH), out DataViewSchema inputSchema);
3132
var predEngine = mlContext.Model.CreatePredictionEngine<SampleObservation, SamplePrediction>(mlModel);
3233

3334
// Create sample data to do a single prediction with it
@@ -59,5 +60,15 @@ namespace TestNamespace.ConsoleApp
5960
.First();
6061
return sampleForPrediction;
6162
}
63+
64+
public static string GetAbsolutePath(string relativePath)
65+
{
66+
FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
67+
string assemblyFolderPath = _dataRoot.Directory.FullName;
68+
69+
string fullPath = Path.Combine(assemblyFolderPath, relativePath);
70+
71+
return fullPath;
72+
}
6273
}
6374
}

0 commit comments

Comments
 (0)