4
4
5
5
using System ;
6
6
using System . IO ;
7
+ using System . IO . Compression ;
7
8
using System . Linq ;
8
9
using Microsoft . ML . Calibrators ;
9
10
using Microsoft . ML . Data ;
11
+ using Microsoft . ML . Functional . Tests . Datasets ;
10
12
using Microsoft . ML . RunTests ;
11
13
using Microsoft . ML . Trainers . FastTree ;
12
14
using Microsoft . ML . Transforms ;
15
17
16
18
namespace Microsoft . ML . Functional . Tests
17
19
{
18
- public partial class ModelLoadingTests : TestDataPipeBase
20
+ public partial class ModelFiles : TestDataPipeBase
19
21
{
20
- public ModelLoadingTests ( ITestOutputHelper output ) : base ( output )
22
+ public ModelFiles ( ITestOutputHelper output ) : base ( output )
21
23
{
22
24
}
23
25
@@ -30,6 +32,101 @@ private class InputData
30
32
public float [ ] Features { get ; set ; }
31
33
}
32
34
35
+ /// <summary>
36
+ /// Model Files: The (minimum) nuget version can be found in the model file.
37
+ /// </summary>
38
+ [ Fact ]
39
+ public void DetermineNugetVersionFromModel ( )
40
+ {
41
+ var mlContext = new MLContext ( seed : 1 ) ;
42
+
43
+ // Get the dataset.
44
+ var data = mlContext . Data . LoadFromTextFile < HousingRegression > ( GetDataPath ( TestDatasets . housing . trainFilename ) , hasHeader : true ) ;
45
+
46
+ // Create a pipeline to train on the housing data.
47
+ var pipeline = mlContext . Transforms . Concatenate ( "Features" , HousingRegression . Features )
48
+ . Append ( mlContext . Regression . Trainers . FastTree (
49
+ new FastTreeRegressionTrainer . Options { NumberOfThreads = 1 , NumberOfTrees = 10 } ) ) ;
50
+
51
+ // Fit the pipeline.
52
+ var model = pipeline . Fit ( data ) ;
53
+
54
+ // Save model to a file.
55
+ var modelPath = DeleteOutputPath ( "determineNugetVersionFromModel.zip" ) ;
56
+ mlContext . Model . Save ( model , data . Schema , modelPath ) ;
57
+
58
+ // Check that the version can be extracted from the model.
59
+ var versionFileName = @"TrainingInfo" + Path . DirectorySeparatorChar + "Version.txt" ;
60
+ using ( ZipArchive archive = ZipFile . OpenRead ( modelPath ) )
61
+ {
62
+ // The version of the entire model is kept in the version file.
63
+ var versionPath = archive . Entries . First ( x => x . FullName == versionFileName ) ;
64
+ Assert . NotNull ( versionPath ) ;
65
+ using ( var stream = versionPath . Open ( ) )
66
+ using ( var reader = new StreamReader ( stream ) )
67
+ {
68
+ // The only line in the file is the version of the model.
69
+ var line = reader . ReadLine ( ) ;
70
+ Assert . Equal ( @"1.0.0.0" , line ) ;
71
+ }
72
+ }
73
+ }
74
+
75
+ /// <summary>
76
+ /// Model Files: Save a model, including all transforms, then load and make predictions.
77
+ /// </summary>
78
+ /// <remarks>
79
+ /// Serves two scenarios:
80
+ /// 1. I can train a model and save it to a file, including transforms.
81
+ /// 2. Training and prediction happen in different processes (or even different machines).
82
+ /// The actual test will not run in different processes, but will simulate the idea that the
83
+ /// "communication pipe" is just a serialized model of some form.
84
+ /// </remarks>
85
+ [ Fact ]
86
+ public void FitPipelineSaveModelAndPredict ( )
87
+ {
88
+ var mlContext = new MLContext ( seed : 1 ) ;
89
+
90
+ // Get the dataset.
91
+ var data = mlContext . Data . LoadFromTextFile < HousingRegression > ( GetDataPath ( TestDatasets . housing . trainFilename ) , hasHeader : true ) ;
92
+
93
+ // Create a pipeline to train on the housing data.
94
+ var pipeline = mlContext . Transforms . Concatenate ( "Features" , HousingRegression . Features )
95
+ . Append ( mlContext . Regression . Trainers . FastTree (
96
+ new FastTreeRegressionTrainer . Options { NumberOfThreads = 1 , NumberOfTrees = 10 } ) ) ;
97
+
98
+ // Fit the pipeline.
99
+ var model = pipeline . Fit ( data ) ;
100
+
101
+ var modelPath = DeleteOutputPath ( "fitPipelineSaveModelAndPredict.zip" ) ;
102
+ // Save model to a file.
103
+ mlContext . Model . Save ( model , data . Schema , modelPath ) ;
104
+
105
+ // Load model from a file.
106
+ ITransformer serializedModel ;
107
+ using ( var file = File . OpenRead ( modelPath ) )
108
+ {
109
+ serializedModel = mlContext . Model . Load ( file , out var serializedSchema ) ;
110
+ CheckSameSchemas ( data . Schema , serializedSchema ) ;
111
+ }
112
+
113
+ // Create prediction engine and test predictions.
114
+ var originalPredictionEngine = mlContext . Model . CreatePredictionEngine < HousingRegression , ScoreColumn > ( model ) ;
115
+ var serializedPredictionEngine = mlContext . Model . CreatePredictionEngine < HousingRegression , ScoreColumn > ( serializedModel ) ;
116
+
117
+ // Take a handful of examples out of the dataset and compute predictions.
118
+ var dataEnumerator = mlContext . Data . CreateEnumerable < HousingRegression > ( mlContext . Data . TakeRows ( data , 5 ) , false ) ;
119
+ foreach ( var row in dataEnumerator )
120
+ {
121
+ var originalPrediction = originalPredictionEngine . Predict ( row ) ;
122
+ var serializedPrediction = serializedPredictionEngine . Predict ( row ) ;
123
+ // Check that the predictions are identical.
124
+ Assert . Equal ( originalPrediction . Score , serializedPrediction . Score ) ;
125
+ }
126
+
127
+ Done ( ) ;
128
+ }
129
+
33
130
[ Fact ]
34
131
public void LoadModelAndExtractPredictor ( )
35
132
{
0 commit comments