using System;
using System.Collections.Generic;
using System.Diagnostics;
+ using System.IO;
+ using System.Linq;
using System.Text;
using Microsoft.Data.DataView;

@@ -22,9 +24,11 @@ internal class Experiment<T> where T : class
        private readonly ExperimentSettings _experimentSettings;
        private readonly IMetricsAgent<T> _metricsAgent;
        private readonly IEnumerable<TrainerName> _trainerWhitelist;
+       private readonly DirectoryInfo _modelDirectory;

        private IDataView _trainData;
        private IDataView _validationData;
+       private ITransformer _preprocessorTransform;

        List<RunResult<T>> iterationResults = new List<RunResult<T>>();

@@ -57,17 +61,17 @@ public Experiment(MLContext context,
            _experimentSettings = experimentSettings;
            _metricsAgent = metricsAgent;
            _trainerWhitelist = trainerWhitelist;
+           _modelDirectory = GetModelDirectory(_experimentSettings.ModelDirectory);
        }

        public List<RunResult<T>> Execute()
        {
-           ITransformer preprocessorTransform = null;
            if (_preFeaturizers != null)
            {
                // preprocess train and validation data
-               preprocessorTransform = _preFeaturizers.Fit(_trainData);
-               _trainData = preprocessorTransform.Transform(_trainData);
-               _validationData = preprocessorTransform.Transform(_validationData);
+               _preprocessorTransform = _preFeaturizers.Fit(_trainData);
+               _trainData = _preprocessorTransform.Transform(_trainData);
+               _validationData = _preprocessorTransform.Transform(_validationData);
            }

            var stopwatch = Stopwatch.StartNew();
@@ -97,12 +101,6 @@ public List<RunResult<T>> Execute()
                    // evaluate pipeline
                    runResult = ProcessPipeline(pipeline);

-                   if (_preFeaturizers != null)
-                   {
-                       runResult.Estimator = _preFeaturizers.Append(runResult.Estimator);
-                       runResult.Model = preprocessorTransform.Append(runResult.Model);
-                   }
-
                    runResult.RuntimeInSeconds = (int)iterationStopwatch.Elapsed.TotalSeconds;
                    runResult.PipelineInferenceTimeInSeconds = (int)getPiplelineStopwatch.Elapsed.TotalSeconds;
                }
@@ -129,6 +127,33 @@ public List<RunResult<T>> Execute()
            return iterationResults;
        }

+       private static DirectoryInfo GetModelDirectory(DirectoryInfo rootDir)
+       {
+           if (rootDir == null)
+           {
+               return null;
+           }
+           var subdirs = rootDir.Exists ?
+               new HashSet<string>(rootDir.EnumerateDirectories().Select(d => d.Name)) :
+               new HashSet<string>();
+           string experimentDir;
+           for (var i = 0; ; i++)
+           {
+               experimentDir = $"experiment{i}";
+               if (!subdirs.Contains(experimentDir))
+               {
+                   break;
+               }
+           }
+           var experimentDirFullPath = Path.Combine(rootDir.FullName, experimentDir);
+           var experimentDirInfo = new DirectoryInfo(experimentDirFullPath);
+           if (!experimentDirInfo.Exists)
+           {
+               experimentDirInfo.Create();
+           }
+           return experimentDirInfo;
+       }
+
        private void ReportProgress(RunResult<T> iterationResult)
        {
            try
@@ -141,6 +166,17 @@ private void ReportProgress(RunResult<T> iterationResult)
            }
        }

+       private FileInfo GetNextModelFileInfo()
+       {
+           if (_experimentSettings.ModelDirectory == null)
+           {
+               return null;
+           }
+
+           return new FileInfo(Path.Combine(_modelDirectory.FullName,
+               $"Model{_history.Count + 1}.zip"));
+       }
+
        private SuggestedPipelineResult<T> ProcessPipeline(SuggestedPipeline pipeline)
        {
            // run pipeline
@@ -150,22 +186,33 @@ private SuggestedPipelineResult<T> ProcessPipeline(SuggestedPipeline pipeline)

            WriteDebugLog(DebugStream.RunResult, $"Processing pipeline {commandLineStr}.");

-           var pipelineEstimator = pipeline.ToEstimator();
-
            SuggestedPipelineResult<T> runResult;

            try
            {
-               var pipelineModel = pipelineEstimator.Fit(_trainData);
-               var scoredValidationData = pipelineModel.Transform(_validationData);
+               var model = pipeline.ToEstimator().Fit(_trainData);
+               var scoredValidationData = model.Transform(_validationData);
                var metrics = GetEvaluatedMetrics(scoredValidationData);
                var score = _metricsAgent.GetScore(metrics);
-               runResult = new SuggestedPipelineResult<T>(metrics, pipelineEstimator, pipelineModel, pipeline, score, null);
+
+               var estimator = pipeline.ToEstimator();
+               if (_preFeaturizers != null)
+               {
+                   estimator = _preFeaturizers.Append(estimator);
+                   model = _preprocessorTransform.Append(model);
+               }
+
+               var modelFileInfo = GetNextModelFileInfo();
+               var modelContainer = modelFileInfo == null ?
+                   new ModelContainer(_context, model) :
+                   new ModelContainer(_context, modelFileInfo, model);
+
+               runResult = new SuggestedPipelineResult<T>(metrics, estimator, modelContainer, pipeline, score, null);
            }
            catch (Exception ex)
            {
                WriteDebugLog(DebugStream.Exception, $"{pipeline.Trainer} Crashed {ex}");
-               runResult = new SuggestedPipelineResult<T>(null, pipelineEstimator, null, pipeline, 0, ex);
+               runResult = new SuggestedPipelineResult<T>(null, pipeline.ToEstimator(), null, pipeline, 0, ex);
            }

            // save pipeline run
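
For context on the new ModelContainer usage above: when a model directory is configured, each iteration's fitted model appears to be handed to a file-backed container (Model1.zip, Model2.zip, ...) rather than kept in memory. Below is a minimal, hypothetical sketch of that in-memory-or-file-backed pattern; the type and member names (LazyModelContainer, the save/load delegates) are illustrative assumptions and not the actual ML.NET AutoML ModelContainer API.

    using System;
    using System.IO;

    // Sketch: hold a model in memory, or persist it to disk and reload it on demand.
    internal sealed class LazyModelContainer<TModel> where TModel : class
    {
        private readonly FileInfo _file;              // null => model kept in memory
        private readonly Func<Stream, TModel> _load;  // caller-supplied deserializer
        private readonly TModel _inMemoryModel;

        // In-memory variant: the fitted model stays referenced by the run result.
        public LazyModelContainer(TModel model)
        {
            _inMemoryModel = model;
        }

        // File-backed variant: persist immediately and drop the reference, so memory
        // stays flat across many AutoML iterations; reload lazily when requested.
        public LazyModelContainer(FileInfo file, TModel model,
            Action<TModel, Stream> save, Func<Stream, TModel> load)
        {
            _file = file;
            _load = load;
            using (var stream = file.Create())
            {
                save(model, stream);
            }
        }

        public TModel GetModel()
        {
            if (_inMemoryModel != null)
            {
                return _inMemoryModel;
            }
            using (var stream = _file.OpenRead())
            {
                return _load(stream);
            }
        }
    }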