Skip to content

Commit 10c4fc6

Browse files
authored
Add V1 Introspective Training Tests (#2859)
* Adding introspective training scenario tests.
1 parent a558010 commit 10c4fc6

File tree

7 files changed

+568
-184
lines changed

7 files changed

+568
-184
lines changed

test/Microsoft.ML.Functional.Tests/Common.cs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,5 +267,35 @@ public static void AssertMetricsStatistics(RegressionMetricsStatistics metrics)
267267
AssertMetricStatistics(metrics.RSquared);
268268
AssertMetricStatistics(metrics.LossFunction);
269269
}
270+
271+
/// <summary>
272+
/// Verify that a float array has no NaNs or infinities.
273+
/// </summary>
274+
/// <param name="array">An array of doubles.</param>
275+
public static void AssertFiniteNumbers(IList<float> array, int ignoreElementAt = -1)
276+
{
277+
for (int i = 0; i < array.Count; i++)
278+
{
279+
if (i == ignoreElementAt)
280+
continue;
281+
Assert.False(float.IsNaN(array[i]));
282+
Assert.False(float.IsInfinity(array[i]));
283+
}
284+
}
285+
286+
/// <summary>
287+
/// Verify that a double array has no NaNs or infinities.
288+
/// </summary>
289+
/// <param name="array">An array of doubles.</param>
290+
public static void AssertFiniteNumbers(IList<double> array, int ignoreElementAt = -1)
291+
{
292+
for (int i = 0; i < array.Count; i++)
293+
{
294+
if (i == ignoreElementAt)
295+
continue;
296+
Assert.False(double.IsNaN(array[i]));
297+
Assert.False(double.IsInfinity(array[i]));
298+
}
299+
}
270300
}
271301
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Data;
6+
7+
namespace Microsoft.ML.Functional.Tests.Datasets
8+
{
9+
/// <summary>
10+
/// A class for the Adult test dataset.
11+
/// </summary>
12+
internal sealed class Adult
13+
{
14+
[LoadColumn(0)]
15+
public bool Label { get; set; }
16+
17+
[LoadColumn(1)]
18+
public string WorkClass { get; set; }
19+
20+
[LoadColumn(2)]
21+
public string Education { get; set; }
22+
23+
[LoadColumn(3)]
24+
public string MaritalStatus { get; set; }
25+
26+
[LoadColumn(4)]
27+
public string Occupation { get; set; }
28+
29+
[LoadColumn(5)]
30+
public string Relationship { get; set; }
31+
32+
[LoadColumn(6)]
33+
public string Ethnicity { get; set; }
34+
35+
[LoadColumn(7)]
36+
public string Sex { get; set; }
37+
38+
[LoadColumn(8)]
39+
public string NativeCountryRegion { get; set; }
40+
41+
[LoadColumn(9)]
42+
public float Age { get; set; }
43+
44+
[LoadColumn(10)]
45+
public float FinalWeight { get; set; }
46+
47+
[LoadColumn(11)]
48+
public float EducationNum { get; set; }
49+
50+
[LoadColumn(12)]
51+
public float CapitalGain { get; set; }
52+
53+
[LoadColumn(13)]
54+
public float CapitalLoss { get; set; }
55+
56+
[LoadColumn(14)]
57+
public float HoursPerWeek { get; set; }
58+
59+
/// <summary>
60+
/// The list of columns commonly used as categorical features.
61+
/// </summary>
62+
public static readonly string[] CategoricalFeatures = new string[] { "WorkClass", "Education", "MaritalStatus", "Occupation", "Relationship", "Ethnicity", "Sex", "NativeCountryRegion" };
63+
64+
/// <summary>
65+
/// The list of columns commonly used as numerical features.
66+
/// </summary>
67+
public static readonly string[] NumericalFeatures = new string[] { "Age", "FinalWeight", "EducationNum", "CapitalGain", "CapitalLoss", "HoursPerWeek" };
68+
}
69+
}

test/Microsoft.ML.Functional.Tests/Evaluation.cs

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -237,17 +237,11 @@ public void TrainAndEvaluateRegression()
237237
{
238238
var mlContext = new MLContext(seed: 1);
239239

240-
// Get the dataset.
241-
var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(),
242-
hasHeader: TestDatasets.housing.fileHasHeader, separatorChar: TestDatasets.housing.fileSeparator)
243-
.Load(GetDataPath(TestDatasets.housing.trainFilename));
244-
245-
// Create a pipeline to train on the sentiment data.
246-
var pipeline = mlContext.Transforms.Concatenate("Features", new string[] {
247-
"CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling",
248-
"PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"})
249-
.Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue"))
250-
.Append(mlContext.Regression.Trainers.FastTree(new FastTreeRegressionTrainer.Options { NumberOfThreads = 1 }));
240+
// Get the dataset
241+
var data = mlContext.Data.LoadFromTextFile<HousingRegression>(GetDataPath(TestDatasets.housing.trainFilename), hasHeader: true);
242+
// Create a pipeline to train on the housing data.
243+
var pipeline = mlContext.Transforms.Concatenate("Features", HousingRegression.Features)
244+
.Append(mlContext.Regression.Trainers.FastForest(new FastForestRegression.Options { NumberOfThreads = 1 }));
251245

252246
// Train the model.
253247
var model = pipeline.Fit(data);

0 commit comments

Comments
 (0)