
Commit 9a8c46c

Fixed memory leaks from OnnxTransformer (#5518)
* Fixed memory leak from OnnxTransformer and related x86 build fixes
* Reverting x86 build related fixes to focus only on the memory leaks
* Updated docs
* Reverted OnnxRuntimeOutputCatcher to private class
* Addressed code review comments
* Refactored OnnxTransform back to using MapperBase based on code review comments
1 parent f151a4a commit 9a8c46c

File tree

5 files changed: +105 -57 lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/ModelOperations/OnnxConversion.cs

+2 -1

@@ -82,7 +82,8 @@ public static void Example()
             //Create the pipeline using onnx file.
             var onnxModelPath = "your_path_to_sample_onnx_conversion_1.onnx";
             var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath);
-            var onnxTransformer = onnxEstimator.Fit(trainTestOriginalData.TrainSet);
+            //Make sure to either use the 'using' clause or explicitly dispose the returned onnxTransformer to prevent memory leaks
+            using var onnxTransformer = onnxEstimator.Fit(trainTestOriginalData.TrainSet);
 
             //Inference the testset
             var output = transformer.Transform(trainTestOriginalData.TestSet);
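
If a using declaration does not fit the calling code (for example, when the fitted transformer outlives the enclosing scope), explicit disposal accomplishes the same thing. A minimal sketch, assuming an MLContext named mlContext and a training IDataView named trainData already exist; the cast hedges against transformer types that do not implement IDisposable:

    // Sketch only: `mlContext`, `trainData`, and the model path are assumptions, not part of this commit.
    var onnxEstimator = mlContext.Transforms.ApplyOnnxModel("your_path_to_sample_onnx_conversion_1.onnx");
    var onnxTransformer = onnxEstimator.Fit(trainData);
    try
    {
        var scored = onnxTransformer.Transform(trainData);
        // ... consume `scored` ...
    }
    finally
    {
        // Release the native ONNX Runtime resources held by the transformer.
        (onnxTransformer as IDisposable)?.Dispose();
    }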

src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs

+5 -5

@@ -58,7 +58,7 @@ private protected abstract class MapperBase : IRowMapper
         {
             protected readonly IHost Host;
             protected readonly DataViewSchema InputSchema;
-            private readonly Lazy<DataViewSchema.DetachedColumn[]> _outputColumns;
+            protected readonly Lazy<DataViewSchema.DetachedColumn[]> OutputColumns;
             private readonly RowToRowTransformerBase _parent;
 
             protected MapperBase(IHost host, DataViewSchema inputSchema, RowToRowTransformerBase parent)
@@ -68,21 +68,21 @@ protected MapperBase(IHost host, DataViewSchema inputSchema, RowToRowTransformer
                 Host = host;
                 InputSchema = inputSchema;
                 _parent = parent;
-                _outputColumns = new Lazy<DataViewSchema.DetachedColumn[]>(GetOutputColumnsCore);
+                OutputColumns = new Lazy<DataViewSchema.DetachedColumn[]>(GetOutputColumnsCore);
             }
 
             protected abstract DataViewSchema.DetachedColumn[] GetOutputColumnsCore();
 
-            DataViewSchema.DetachedColumn[] IRowMapper.GetOutputColumns() => _outputColumns.Value;
+            DataViewSchema.DetachedColumn[] IRowMapper.GetOutputColumns() => OutputColumns.Value;
 
-            Delegate[] IRowMapper.CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
+            public virtual Delegate[] CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
             {
                 // REVIEW: it used to be that the mapper's input schema in the constructor was required to be reference-equal to the schema
                 // of the input row.
                 // It still has to be the same schema, but because we may make a transition from lazy to eager schema, the reference-equality
                 // is no longer always possible. So, we relax the assert as below.
                 Contracts.Assert(input.Schema == InputSchema);
-                int n = _outputColumns.Value.Length;
+                int n = OutputColumns.Value.Length;
                 var result = new Delegate[n];
                 var disposers = new Action[n];
                 for (int i = 0; i < n; i++)
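
The two changes above (OutputColumns made protected, CreateGetters made public virtual) exist so that a derived mapper can build its getters around a shared, disposable cache and hand the cleanup back to the cursor through the disposer argument. A rough sketch of such an override, using a hypothetical cache type and per-column helper rather than anything from the ML.NET codebase:

    // Hypothetical override in a derived mapper; `MyOutputCache` and `CreateGetter` are illustrative names.
    public override Delegate[] CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
    {
        var cache = new MyOutputCache();                 // shared by every getter of this cursor
        int n = OutputColumns.Value.Length;
        var getters = new Delegate[n];
        for (int i = 0; i < n; i++)
        {
            if (activeOutput(i))
                getters[i] = CreateGetter(input, i, cache);
        }
        // The cursor invokes this when it is disposed, releasing the shared cache exactly once.
        disposer = () => cache.Dispose();
        return getters;
    }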

src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs

+56 -18

@@ -484,8 +484,10 @@ private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> a
             private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);
 
             protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
+                => throw new NotImplementedException("This should never be called!");
+
+            private Delegate CreateGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, OnnxRuntimeOutputCacher outputCacher)
             {
-                disposer = null;
                 Host.AssertValue(input);
 
                 var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray();
@@ -495,26 +497,59 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
                     var elemRawType = vectorType.ItemType.RawType;
                     var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes);
                     if (vectorType.ItemType is TextDataViewType)
-                        return MakeStringTensorGetter(input, iinfo, srcNamedValueGetters, activeOutputColNames);
+                        return MakeStringTensorGetter(input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher);
                     else
-                        return Utils.MarshalInvoke(MakeTensorGetter<int>, elemRawType, input, iinfo, srcNamedValueGetters, activeOutputColNames);
+                        return Utils.MarshalInvoke(MakeTensorGetter<int>, elemRawType, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher);
                 }
                 else
                 {
                     var type = _parent.Model.ModelInfo.OutputsInfo[_parent.MapDataViewColumnToOnnxOutputTensor(iinfo)].DataViewType.RawType;
                     var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes);
-                    return Utils.MarshalInvoke(MakeObjectGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames);
+                    return Utils.MarshalInvoke(MakeObjectGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher);
+                }
+            }
+
+            public override Delegate[] CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
+            {
+                Contracts.Assert(input.Schema == InputSchema);
+
+                OnnxRuntimeOutputCacher outputCacher = new OnnxRuntimeOutputCacher();
+
+                int n = OutputColumns.Value.Length;
+                var result = new Delegate[n];
+                for (int i = 0; i < n; i++)
+                {
+                    if (!activeOutput(i))
+                        continue;
+                    result[i] = CreateGetter(input, i, activeOutput, outputCacher);
                 }
+                disposer = () =>
+                {
+                    outputCacher.Dispose();
+                };
+                return result;
             }
 
-            private class OnnxRuntimeOutputCacher
+            private sealed class OnnxRuntimeOutputCacher : IDisposable
             {
                 public long Position;
-                public Dictionary<string, NamedOnnxValue> Outputs;
+                public Dictionary<string, DisposableNamedOnnxValue> Outputs;
+                public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> OutputOnnxValues;
+
                 public OnnxRuntimeOutputCacher()
                 {
                     Position = -1;
-                    Outputs = new Dictionary<string, NamedOnnxValue>();
+                    Outputs = new Dictionary<string, DisposableNamedOnnxValue>();
+                }
+
+                private bool _isDisposed;
+
+                public void Dispose()
+                {
+                    if (_isDisposed)
+                        return;
+                    OutputOnnxValues?.Dispose();
+                    _isDisposed = true;
                 }
             }
 
@@ -529,21 +564,22 @@ private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamed
                         inputNameOnnxValues.Add(srcNamedOnnxValueGetters[i].GetNamedOnnxValue());
                     }
 
-                    var outputNamedOnnxValues = _parent.Model.Run(inputNameOnnxValues);
-                    Contracts.Assert(outputNamedOnnxValues.Count > 0);
+                    outputCache.OutputOnnxValues?.Dispose();
+                    outputCache.OutputOnnxValues = _parent.Model.Run(inputNameOnnxValues);
+                    Contracts.Assert(outputCache.OutputOnnxValues.Count > 0);
 
-                    foreach (var outputNameOnnxValue in outputNamedOnnxValues)
+                    foreach (var outputNameOnnxValue in outputCache.OutputOnnxValues)
                     {
                         outputCache.Outputs[outputNameOnnxValue.Name] = outputNameOnnxValue;
                     }
                     outputCache.Position = position;
                 }
             }
 
-            private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
+            private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters,
+                string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher)
             {
                 Host.AssertValue(input);
-                var outputCacher = new OnnxRuntimeOutputCacher();
                 ValueGetter<VBuffer<T>> valueGetter = (ref VBuffer<T> dst) =>
                 {
                     UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher);
@@ -558,10 +594,11 @@ private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxVal
                 return valueGetter;
             }
 
-            private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
+            private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters,
+                string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher)
             {
                 Host.AssertValue(input);
-                var outputCacher = new OnnxRuntimeOutputCacher();
+
                 ValueGetter<VBuffer<ReadOnlyMemory<char>>> valueGetter = (ref VBuffer<ReadOnlyMemory<char>> dst) =>
                 {
                     UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher);
@@ -580,14 +617,15 @@ private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnx
                 return valueGetter;
             }
 
-            private Delegate MakeObjectGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
+            private Delegate MakeObjectGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters,
+                string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher)
            {
                 Host.AssertValue(input);
-                var outputCache = new OnnxRuntimeOutputCacher();
+
                 ValueGetter<T> valueGetter = (ref T dst) =>
                 {
-                    UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCache);
-                    var namedOnnxValue = outputCache.Outputs[_parent.Outputs[iinfo]];
+                    UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher);
+                    var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]];
                     var trueValue = namedOnnxValue.AsEnumerable<NamedOnnxValue>().Select(value => value.AsDictionary<string, float>());
                     var caster = _parent.Model.ModelInfo.OutputsInfo[_parent.MapDataViewColumnToOnnxOutputTensor(iinfo)].Caster;
                     dst = (T)caster(namedOnnxValue);
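
This is the core of the leak fix: ONNX Runtime returns its outputs as disposable wrappers over native memory, so the per-cursor cache now takes ownership of them, disposes the previous batch whenever a newer row is scored, and is itself disposed by the delegate returned from CreateGetters. A simplified, self-contained sketch of that ownership pattern; the class and member names are illustrative, not the actual ML.NET types:

    using System;

    // Simplified sketch of the caching pattern above: the cache owns the disposable results
    // of the most recent model run, frees the previous batch when a new row is scored, and
    // is itself disposed by the delegate returned from CreateGetters.
    sealed class RowResultCache : IDisposable
    {
        public long Position = -1;
        public IDisposable CurrentResults;     // stand-in for the disposable output collection of a model run
        private bool _disposed;

        public void Refresh(long position, Func<IDisposable> runModel)
        {
            if (Position == position)
                return;                        // outputs for this row are already cached
            CurrentResults?.Dispose();         // free the previous row's native buffers
            CurrentResults = runModel();       // take ownership of the new outputs
            Position = position;
        }

        public void Dispose()
        {
            if (_disposed)
                return;
            CurrentResults?.Dispose();         // free whatever is still held when the cursor ends
            _disposed = true;
        }
    }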

src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs

+40 -31

@@ -198,40 +198,49 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =
                _session = new InferenceSession(modelFile);
            }
 
-            // Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime
-            // doesn't expose full type information via its C# APIs.
-            ModelFile = modelFile;
-            var model = new OnnxCSharpToProtoWrapper.ModelProto();
-            using (var modelStream = File.OpenRead(modelFile))
-            using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10))
-                model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);
-
-            // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
-            var inputTypePool = new Dictionary<string, DataViewType>();
-            foreach (var valueInfo in model.Graph.Input)
-                inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
-
-            var initializerTypePool = new Dictionary<string, DataViewType>();
-            foreach (var valueInfo in model.Graph.Initializer)
-                initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType);
-
-            var outputTypePool = new Dictionary<string, DataViewType>();
-            // Build casters which maps NamedOnnxValue to .NET objects.
-            var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();
-            foreach (var valueInfo in model.Graph.Output)
+            try
            {
-                outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
-                casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType);
-            }
+                // Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime
+                // doesn't expose full type information via its C# APIs.
+                ModelFile = modelFile;
+                var model = new OnnxCSharpToProtoWrapper.ModelProto();
+                using (var modelStream = File.OpenRead(modelFile))
+                using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10))
+                    model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);
+
+                // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
+                var inputTypePool = new Dictionary<string, DataViewType>();
+                foreach (var valueInfo in model.Graph.Input)
+                    inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
+
+                var initializerTypePool = new Dictionary<string, DataViewType>();
+                foreach (var valueInfo in model.Graph.Initializer)
+                    initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType);
+
+                var outputTypePool = new Dictionary<string, DataViewType>();
+                // Build casters which maps NamedOnnxValue to .NET objects.
+                var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();
+                foreach (var valueInfo in model.Graph.Output)
+                {
+                    outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
+                    casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType);
+                }
 
-            var inputInfos = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null);
-            var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool);
-            var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null);
+                var inputInfos = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null);
+                var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool);
+                var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null);
 
-            // Create a view to the used ONNX model from ONNXRuntime's perspective.
-            ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);
+                // Create a view to the used ONNX model from ONNXRuntime's perspective.
+                ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);
 
-            Graph = model.Graph;
+                Graph = model.Graph;
+            }
+            catch
+            {
+                _session.Dispose();
+                _session = null;
+                throw;
+            }
        }
 
        private List<OnnxVariableInfo> GetOnnxVariablesFromMetadata(IReadOnlyDictionary<string, NodeMetadata> nodeMetadata,
@@ -350,7 +359,7 @@ public static OnnxModel CreateFromBytes(byte[] modelBytes, int? gpuDeviceId = nu
        /// </summary>
        /// <param name="inputNamedOnnxValues">The NamedOnnxValues to score.</param>
        /// <returns>Resulting output NamedOnnxValues list.</returns>
-        public IReadOnlyCollection<NamedOnnxValue> Run(List<NamedOnnxValue> inputNamedOnnxValues)
+        public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> Run(List<NamedOnnxValue> inputNamedOnnxValues)
        {
            return _session.Run(inputNamedOnnxValues);
        }
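
The try/catch added to the constructor matters because the InferenceSession is created before the parsing work that can throw; if the constructor fails, no caller ever receives an object to dispose, so the native session would leak. The same pattern in isolation, sketched with a FileStream standing in for the native session and illustrative names throughout:

    using System;
    using System.IO;

    // Sketch of the constructor pattern above: dispose an already-acquired resource if later
    // initialization throws, because the caller never gets an instance it could dispose itself.
    public sealed class ModelWrapper : IDisposable
    {
        private FileStream _modelStream;   // stand-in for the native session held by OnnxModel

        public ModelWrapper(string modelFile)
        {
            _modelStream = File.OpenRead(modelFile);   // resource acquired first
            try
            {
                // ... parse the model and build schema information (any of this can throw) ...
            }
            catch
            {
                _modelStream.Dispose();
                _modelStream = null;
                throw;
            }
        }

        public void Dispose() => _modelStream?.Dispose();
    }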

test/Microsoft.ML.Tests/OnnxConversionTest.cs

+2 -2

@@ -788,7 +788,7 @@ public void RemoveVariablesInPipelineTest()
                .Append(mlContext.Transforms.NormalizeMinMax("Features"))
                .Append(mlContext.BinaryClassification.Trainers.FastTree(labelColumnName: "Label", featureColumnName: "Features", numberOfLeaves: 2, numberOfTrees: 1, minimumExampleCountPerLeaf: 2));
 
-            var model = pipeline.Fit(data);
+            using var model = pipeline.Fit(data);
            var transformedData = model.Transform(data);
 
            var onnxConversionContext = new OnnxContextImpl(mlContext, "A Simple Pipeline", "ML.NET", "0", 0, "machinelearning.dotnet", OnnxVersion.Stable);
@@ -2029,7 +2029,7 @@ private void TestPipeline<TLastTransformer, TRow>(EstimatorChain<TLastTransforme
        private void TestPipeline<TLastTransformer>(EstimatorChain<TLastTransformer> pipeline, IDataView dataView, string onnxFileName, ColumnComparison[] columnsToCompare, string onnxTxtName = null, string onnxTxtSubDir = null)
            where TLastTransformer : class, ITransformer
        {
-            var model = pipeline.Fit(dataView);
+            using var model = pipeline.Fit(dataView);
            var transformedData = model.Transform(dataView);
            var onnxModel = ML.Model.ConvertToOnnxProtobuf(model, dataView);
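
The test changes mirror the sample: a fitted chain that contains an OnnxTransformer now holds disposable native state, so the tests release it with a using declaration. On language versions without using declarations, an equivalent pattern (a sketch assuming the same pipeline and dataView as in the test above) would be:

    // Equivalent cleanup without a C# 8 using declaration; `pipeline` and `dataView` as above.
    var model = pipeline.Fit(dataView);
    try
    {
        var transformedData = model.Transform(dataView);
        // ... assertions on transformedData ...
    }
    finally
    {
        // Dispose the fitted chain if it implements IDisposable (the using declaration above relies on it doing so).
        (model as IDisposable)?.Dispose();
    }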
