Skip to content

Commit dba3828

Browse files
authored
Add API option to store models on disk (instead of in memory); fix IEstimator memory leak (dotnet#269)
1 parent 001b8df commit dba3828

File tree

6 files changed

+136
-26
lines changed

6 files changed

+136
-26
lines changed

src/Microsoft.ML.Auto/API/ExperimentSettings.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System.IO;
56
using System.Threading;
67

78
namespace Microsoft.ML.Auto
@@ -12,6 +13,13 @@ public class ExperimentSettings
1213
public CancellationToken CancellationToken { get; set; } = default;
1314

1415
/// <summary>
16+
/// The directory in which all models trained during the AutoML experiment will be saved.
17+
/// If null, models will be kept in memory instead of written to disk.
18+
/// (Please note: for an experiment with high runtime operating on a large dataset, opting to keep models in
19+
/// memory could cause a system to run out of memory.)
20+
/// </summary>
21+
public DirectoryInfo ModelDirectory { get; set; } = null;
22+
1523
/// This setting controls whether or not an AutoML experiment will make use of ML.NET-provided caching.
1624
/// If set to true, caching will be forced on for all pipelines. If set to false, caching will be forced off.
1725
/// If set to null (default value), AutoML will decide whether to enable caching for each model.

src/Microsoft.ML.Auto/API/RunResult.cs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.IO;
67
using System.Linq;
8+
using Microsoft.ML.Data;
79

810
namespace Microsoft.ML.Auto
911
{
1012
public sealed class RunResult<T>
1113
{
1214
public T ValidationMetrics { get; private set; }
13-
public ITransformer Model { get; private set; }
15+
/// <summary>
/// The trained model for this run, obtained from the run's model container.
/// Null when the run has no container, which is how a run whose pipeline threw is constructed.
/// </summary>
public ITransformer Model { get { return _modelContainer?.GetModel(); } }
1416
public Exception Exception { get; private set; }
1517
public string TrainerName { get; private set; }
1618
public int RuntimeInSeconds { get; private set; }
@@ -19,16 +21,17 @@ public sealed class RunResult<T>
1921
internal Pipeline Pipeline { get; private set; }
2022
internal int PipelineInferenceTimeInSeconds { get; private set; }
2123

22-
internal RunResult(
23-
ITransformer model,
24+
private readonly ModelContainer _modelContainer;
25+
26+
internal RunResult(ModelContainer modelContainer,
2427
T metrics,
2528
IEstimator<ITransformer> estimator,
2629
Pipeline pipeline,
2730
Exception exception,
2831
int runtimeInSeconds,
2932
int pipelineInferenceTimeInSeconds)
3033
{
31-
Model = model;
34+
_modelContainer = modelContainer;
3235
ValidationMetrics = metrics;
3336
Pipeline = pipeline;
3437
Estimator = estimator;

src/Microsoft.ML.Auto/Experiment/Experiment.cs

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
using System;
66
using System.Collections.Generic;
77
using System.Diagnostics;
8+
using System.IO;
9+
using System.Linq;
810
using System.Text;
911
using Microsoft.Data.DataView;
1012

@@ -22,9 +24,11 @@ internal class Experiment<T> where T : class
2224
private readonly ExperimentSettings _experimentSettings;
2325
private readonly IMetricsAgent<T> _metricsAgent;
2426
private readonly IEnumerable<TrainerName> _trainerWhitelist;
27+
private readonly DirectoryInfo _modelDirectory;
2528

2629
private IDataView _trainData;
2730
private IDataView _validationData;
31+
private ITransformer _preprocessorTransform;
2832

2933
List<RunResult<T>> iterationResults = new List<RunResult<T>>();
3034

@@ -57,17 +61,17 @@ public Experiment(MLContext context,
5761
_experimentSettings = experimentSettings;
5862
_metricsAgent = metricsAgent;
5963
_trainerWhitelist = trainerWhitelist;
64+
_modelDirectory = GetModelDirectory(_experimentSettings.ModelDirectory);
6065
}
6166

6267
public List<RunResult<T>> Execute()
6368
{
64-
ITransformer preprocessorTransform = null;
6569
if (_preFeaturizers != null)
6670
{
6771
// preprocess train and validation data
68-
preprocessorTransform = _preFeaturizers.Fit(_trainData);
69-
_trainData = preprocessorTransform.Transform(_trainData);
70-
_validationData = preprocessorTransform.Transform(_validationData);
72+
_preprocessorTransform = _preFeaturizers.Fit(_trainData);
73+
_trainData = _preprocessorTransform.Transform(_trainData);
74+
_validationData = _preprocessorTransform.Transform(_validationData);
7175
}
7276

7377
var stopwatch = Stopwatch.StartNew();
@@ -97,12 +101,6 @@ public List<RunResult<T>> Execute()
97101
// evaluate pipeline
98102
runResult = ProcessPipeline(pipeline);
99103

100-
if (_preFeaturizers != null)
101-
{
102-
runResult.Estimator = _preFeaturizers.Append(runResult.Estimator);
103-
runResult.Model = preprocessorTransform.Append(runResult.Model);
104-
}
105-
106104
runResult.RuntimeInSeconds = (int)iterationStopwatch.Elapsed.TotalSeconds;
107105
runResult.PipelineInferenceTimeInSeconds = (int)getPiplelineStopwatch.Elapsed.TotalSeconds;
108106
}
@@ -129,6 +127,33 @@ public List<RunResult<T>> Execute()
129127
return iterationResults;
130128
}
131129

130+
/// <summary>
/// Creates and returns a fresh "experiment{N}" subdirectory of <paramref name="rootDir"/>,
/// where N is the smallest non-negative integer whose name is not already taken.
/// </summary>
/// <param name="rootDir">Root directory for saved models; may be null and need not exist yet.</param>
/// <returns>
/// The newly created experiment directory, or null when <paramref name="rootDir"/> is null.
/// </returns>
private static DirectoryInfo GetModelDirectory(DirectoryInfo rootDir)
{
    if (rootDir == null)
    {
        return null;
    }

    // Probe the file system directly for each candidate instead of relying on
    // rootDir.Exists: DirectoryInfo caches that value, so a reused instance can
    // report a stale "does not exist" and make a second experiment pick a
    // directory that is already in use. Probing also honours the file system's
    // own case rules and notices a plain file occupying the candidate name.
    for (var i = 0; ; i++)
    {
        var candidatePath = Path.Combine(rootDir.FullName, $"experiment{i}");
        if (!Directory.Exists(candidatePath) && !File.Exists(candidatePath))
        {
            // CreateDirectory also creates rootDir itself when it is missing.
            return Directory.CreateDirectory(candidatePath);
        }
    }
}
156+
132157
private void ReportProgress(RunResult<T> iterationResult)
133158
{
134159
try
@@ -141,6 +166,17 @@ private void ReportProgress(RunResult<T> iterationResult)
141166
}
142167
}
143168

169+
/// <summary>
/// Builds the file that the next trained model should be written to.
/// </summary>
/// <returns>
/// "Model{N}.zip" inside this experiment's model directory, where N is one more than the
/// current history count; null when no model directory is in use.
/// </returns>
private FileInfo GetNextModelFileInfo()
{
    // Guard on the directory resolved in the constructor, which is the value
    // dereferenced below, rather than on the settings property, whose public
    // setter lets it change after this experiment was created.
    if (_modelDirectory == null)
    {
        return null;
    }

    var fileName = $"Model{_history.Count + 1}.zip";
    return new FileInfo(Path.Combine(_modelDirectory.FullName, fileName));
}
179+
144180
private SuggestedPipelineResult<T> ProcessPipeline(SuggestedPipeline pipeline)
145181
{
146182
// run pipeline
@@ -150,22 +186,33 @@ private SuggestedPipelineResult<T> ProcessPipeline(SuggestedPipeline pipeline)
150186

151187
WriteDebugLog(DebugStream.RunResult, $"Processing pipeline {commandLineStr}.");
152188

153-
var pipelineEstimator = pipeline.ToEstimator();
154-
155189
SuggestedPipelineResult<T> runResult;
156190

157191
try
158192
{
159-
var pipelineModel = pipelineEstimator.Fit(_trainData);
160-
var scoredValidationData = pipelineModel.Transform(_validationData);
193+
var model = pipeline.ToEstimator().Fit(_trainData);
194+
var scoredValidationData = model.Transform(_validationData);
161195
var metrics = GetEvaluatedMetrics(scoredValidationData);
162196
var score = _metricsAgent.GetScore(metrics);
163-
runResult = new SuggestedPipelineResult<T>(metrics, pipelineEstimator, pipelineModel, pipeline, score, null);
197+
198+
var estimator = pipeline.ToEstimator();
199+
if (_preFeaturizers != null)
200+
{
201+
estimator = _preFeaturizers.Append(estimator);
202+
model = _preprocessorTransform.Append(model);
203+
}
204+
205+
var modelFileInfo = GetNextModelFileInfo();
206+
var modelContainer = modelFileInfo == null ?
207+
new ModelContainer(_context, model) :
208+
new ModelContainer(_context, modelFileInfo, model);
209+
210+
runResult = new SuggestedPipelineResult<T>(metrics, estimator, modelContainer, pipeline, score, null);
164211
}
165212
catch(Exception ex)
166213
{
167214
WriteDebugLog(DebugStream.Exception, $"{pipeline.Trainer} Crashed {ex}");
168-
runResult = new SuggestedPipelineResult<T>(null, pipelineEstimator, null, pipeline, 0, ex);
215+
runResult = new SuggestedPipelineResult<T>(null, pipeline.ToEstimator(), null, pipeline, 0, ex);
169216
}
170217

171218
// save pipeline run
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.IO;
6+
using Microsoft.ML.Data;
7+
8+
namespace Microsoft.ML.Auto
9+
{
10+
/// <summary>
/// Holds a trained model either in memory or as a file on disk, and returns it on request.
/// </summary>
internal class ModelContainer
{
    private readonly MLContext _mlContext;
    private readonly FileInfo _fileInfo;
    private readonly ITransformer _model;

    /// <summary>
    /// Creates a container that keeps <paramref name="model"/> in memory.
    /// </summary>
    internal ModelContainer(MLContext mlContext, ITransformer model)
    {
        _mlContext = mlContext;
        _model = model;
    }

    /// <summary>
    /// Creates a container that writes <paramref name="model"/> to <paramref name="fileInfo"/>
    /// and remembers only the file location; the model reference itself is not stored.
    /// </summary>
    internal ModelContainer(MLContext mlContext, FileInfo fileInfo, ITransformer model)
    {
        _mlContext = mlContext;
        _fileInfo = fileInfo;

        // Write model to disk (File.Create creates the file or overwrites an existing one)
        using (var fs = File.Create(fileInfo.FullName))
        {
            model.SaveTo(mlContext, fs);
        }
    }

    /// <summary>
    /// Returns the contained model: the in-memory instance when this container has no
    /// backing file, otherwise a model loaded from the backing file on each call.
    /// </summary>
    public ITransformer GetModel()
    {
        // No backing file means this is an in-memory container. Deciding on the file
        // rather than on the model keeps a null in-memory model from falling through
        // to the disk path and dereferencing a null FileInfo.
        if (_fileInfo == null)
        {
            return _model;
        }

        // Load model from disk
        using (var stream = new FileStream(_fileInfo.FullName, FileMode.Open, FileAccess.Read, FileShare.Read))
        {
            return _mlContext.Model.Load(stream);
        }
    }
}
51+
}

src/Microsoft.ML.Auto/Experiment/SuggestedPipelineResult.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.IO;
67

78
namespace Microsoft.ML.Auto
89
{
@@ -34,25 +35,25 @@ internal class SuggestedPipelineResult<T> : SuggestedPipelineResult
3435
{
3536
public readonly T EvaluatedMetrics;
3637
public IEstimator<ITransformer> Estimator { get; set; }
37-
public ITransformer Model { get; set; }
38+
public ModelContainer ModelContainer { get; set; }
3839
public Exception Exception { get; set; }
3940

4041
public int RuntimeInSeconds { get; set; }
4142
public int PipelineInferenceTimeInSeconds { get; set; }
4243

43-
public SuggestedPipelineResult(T evaluatedMetrics, IEstimator<ITransformer> estimator,
44-
ITransformer model, SuggestedPipeline pipeline, double score, Exception exception)
44+
public SuggestedPipelineResult(T evaluatedMetrics, IEstimator<ITransformer> estimator,
45+
ModelContainer modelContainer, SuggestedPipeline pipeline, double score, Exception exception)
4546
: base(pipeline, score, exception == null)
4647
{
4748
EvaluatedMetrics = evaluatedMetrics;
4849
Estimator = estimator;
49-
Model = model;
50+
ModelContainer = modelContainer;
5051
Exception = exception;
5152
}
5253

5354
public RunResult<T> ToIterationResult()
5455
{
55-
return new RunResult<T>(Model, EvaluatedMetrics, Estimator, Pipeline.ToPipeline(), Exception, RuntimeInSeconds, PipelineInferenceTimeInSeconds);
56+
return new RunResult<T>(ModelContainer, EvaluatedMetrics, Estimator, Pipeline.ToPipeline(), Exception, RuntimeInSeconds, PipelineInferenceTimeInSeconds);
5657
}
5758
}
5859
}

src/Test/AutoFitTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public void AutoFitMultiTest()
4141
.CreateMulticlassClassificationExperiment(0)
4242
.Execute(trainData, validationData, new ColumnInformation() { LabelColumn = DatasetUtil.TrivialMulticlassDatasetLabel });
4343

44-
Assert.IsTrue(result.Max(i => i.ValidationMetrics.AccuracyMacro) > 0.80);
44+
Assert.IsTrue(result.Max(i => i.ValidationMetrics.AccuracyMicro) >= 0.8);
4545
}
4646

4747
[TestMethod]

0 commit comments

Comments
 (0)