Skip to content

Commit f342403

Browse files
authored
Create model file V1 scenario tests (#2899)
* Combining ModelFile scenario with ModelLoading.
1 parent 909721e commit f342403

File tree

3 files changed

+102
-63
lines changed

3 files changed

+102
-63
lines changed

test/Microsoft.ML.Functional.Tests/Datasets/CommonColumns.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,23 @@ internal sealed class FeatureColumn
1313
}
1414

1515
/// <summary>
16-
/// A class to hold the output of FeatureContributionCalculator.
16+
/// A class to hold the output of FeatureContributionCalculator
1717
/// </summary>
1818
internal sealed class FeatureContributionOutput
1919
{
2020
public float[] FeatureContributions { get; set; }
2121
}
2222

2323
/// <summary>
24-
/// A class to hold the Score column.
24+
/// A class to hold a score column.
2525
/// </summary>
2626
internal sealed class ScoreColumn
2727
{
2828
public float Score { get; set; }
2929
}
3030

3131
/// <summary>
32-
/// A class to hold a vector Score column.
32+
/// A class to hold a vector score column.
3333
/// </summary>
3434
internal sealed class VectorScoreColumn
3535
{

test/Microsoft.ML.Functional.Tests/ModelLoading.cs renamed to test/Microsoft.ML.Functional.Tests/ModelFiles.cs

Lines changed: 99 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44

55
using System;
66
using System.IO;
7+
using System.IO.Compression;
78
using System.Linq;
89
using Microsoft.ML.Calibrators;
910
using Microsoft.ML.Data;
11+
using Microsoft.ML.Functional.Tests.Datasets;
1012
using Microsoft.ML.RunTests;
1113
using Microsoft.ML.Trainers.FastTree;
1214
using Microsoft.ML.Transforms;
@@ -15,9 +17,9 @@
1517

1618
namespace Microsoft.ML.Functional.Tests
1719
{
18-
public partial class ModelLoadingTests : TestDataPipeBase
20+
public partial class ModelFiles : TestDataPipeBase
1921
{
20-
public ModelLoadingTests(ITestOutputHelper output) : base(output)
22+
public ModelFiles(ITestOutputHelper output) : base(output)
2123
{
2224
}
2325

@@ -30,6 +32,101 @@ private class InputData
3032
public float[] Features { get; set; }
3133
}
3234

35+
/// <summary>
36+
/// Model Files: The (minimum) nuget version can be found in the model file.
37+
/// </summary>
38+
[Fact]
39+
public void DetermineNugetVersionFromModel()
40+
{
41+
var mlContext = new MLContext(seed: 1);
42+
43+
// Get the dataset.
44+
var data = mlContext.Data.LoadFromTextFile<HousingRegression>(GetDataPath(TestDatasets.housing.trainFilename), hasHeader: true);
45+
46+
// Create a pipeline to train on the housing data.
47+
var pipeline = mlContext.Transforms.Concatenate("Features", HousingRegression.Features)
48+
.Append(mlContext.Regression.Trainers.FastTree(
49+
new FastTreeRegressionTrainer.Options { NumberOfThreads = 1, NumberOfTrees = 10 }));
50+
51+
// Fit the pipeline.
52+
var model = pipeline.Fit(data);
53+
54+
// Save model to a file.
55+
var modelPath = DeleteOutputPath("determineNugetVersionFromModel.zip");
56+
mlContext.Model.Save(model, data.Schema, modelPath);
57+
58+
// Check that the version can be extracted from the model.
59+
var versionFileName = @"TrainingInfo" + Path.DirectorySeparatorChar + "Version.txt";
60+
using (ZipArchive archive = ZipFile.OpenRead(modelPath))
61+
{
62+
// The version of the entire model is kept in the version file.
63+
var versionPath = archive.Entries.First(x => x.FullName == versionFileName);
64+
Assert.NotNull(versionPath);
65+
using (var stream = versionPath.Open())
66+
using (var reader = new StreamReader(stream))
67+
{
68+
// The only line in the file is the version of the model.
69+
var line = reader.ReadLine();
70+
Assert.Equal(@"1.0.0.0", line);
71+
}
72+
}
73+
}
74+
75+
/// <summary>
76+
/// Model Files: Save a model, including all transforms, then load and make predictions.
77+
/// </summary>
78+
/// <remarks>
79+
/// Serves two scenarios:
80+
/// 1. I can train a model and save it to a file, including transforms.
81+
/// 2. Training and prediction happen in different processes (or even different machines).
82+
/// The actual test will not run in different processes, but will simulate the idea that the
83+
/// "communication pipe" is just a serialized model of some form.
84+
/// </remarks>
85+
[Fact]
86+
public void FitPipelineSaveModelAndPredict()
87+
{
88+
var mlContext = new MLContext(seed: 1);
89+
90+
// Get the dataset.
91+
var data = mlContext.Data.LoadFromTextFile<HousingRegression>(GetDataPath(TestDatasets.housing.trainFilename), hasHeader: true);
92+
93+
// Create a pipeline to train on the housing data.
94+
var pipeline = mlContext.Transforms.Concatenate("Features", HousingRegression.Features)
95+
.Append(mlContext.Regression.Trainers.FastTree(
96+
new FastTreeRegressionTrainer.Options { NumberOfThreads = 1, NumberOfTrees = 10 }));
97+
98+
// Fit the pipeline.
99+
var model = pipeline.Fit(data);
100+
101+
var modelPath = DeleteOutputPath("fitPipelineSaveModelAndPredict.zip");
102+
// Save model to a file.
103+
mlContext.Model.Save(model, data.Schema, modelPath);
104+
105+
// Load model from a file.
106+
ITransformer serializedModel;
107+
using (var file = File.OpenRead(modelPath))
108+
{
109+
serializedModel = mlContext.Model.Load(file, out var serializedSchema);
110+
CheckSameSchemas(data.Schema, serializedSchema);
111+
}
112+
113+
// Create prediction engine and test predictions.
114+
var originalPredictionEngine = mlContext.Model.CreatePredictionEngine<HousingRegression, ScoreColumn>(model);
115+
var serializedPredictionEngine = mlContext.Model.CreatePredictionEngine<HousingRegression, ScoreColumn>(serializedModel);
116+
117+
// Take a handful of examples out of the dataset and compute predictions.
118+
var dataEnumerator = mlContext.Data.CreateEnumerable<HousingRegression>(mlContext.Data.TakeRows(data, 5), false);
119+
foreach (var row in dataEnumerator)
120+
{
121+
var originalPrediction = originalPredictionEngine.Predict(row);
122+
var serializedPrediction = serializedPredictionEngine.Predict(row);
123+
// Check that the predictions are identical.
124+
Assert.Equal(originalPrediction.Score, serializedPrediction.Score);
125+
}
126+
127+
Done();
128+
}
129+
33130
[Fact]
34131
public void LoadModelAndExtractPredictor()
35132
{

test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs

Lines changed: 0 additions & 58 deletions
This file was deleted.

0 commit comments

Comments
 (0)