Skip to content

Commit e1fe0a8

Browse files
Let ImageLoadingTransformer dispose the last image it loads (#5056)
* Let the image loader dispose the last image it loaded, instead of the image resizer
* Added tests
* Added a PredictionEngine to the tests
* Changed the using statement to reuse the test class
1 parent db84060 commit e1fe0a8

File tree

4 files changed

+142
-2
lines changed

4 files changed

+142
-2
lines changed

src/Microsoft.ML.ImageAnalytics/ImageLoader.cs

+13-1
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,17 @@ private Delegate MakeGetterImageDataViewType(DataViewRow input, int iinfo, Func<
217217
{
218218
Contracts.AssertValue(input);
219219
Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);
220-
disposer = null;
220+
var lastImage = default(Bitmap);
221+
222+
disposer = () =>
223+
{
224+
if (lastImage != null)
225+
{
226+
lastImage.Dispose();
227+
lastImage = null;
228+
}
229+
};
230+
221231
var getSrc = input.GetGetter<ReadOnlyMemory<char>>(input.Schema[ColMapNewToOld[iinfo]]);
222232
ReadOnlyMemory<char> src = default;
223233
ValueGetter<Bitmap> del =
@@ -247,6 +257,8 @@ private Delegate MakeGetterImageDataViewType(DataViewRow input, int iinfo, Func<
247257
if (dst.PixelFormat == System.Drawing.Imaging.PixelFormat.DontCare)
248258
throw Host.Except($"Failed to load image {src.ToString()}.");
249259
}
260+
261+
lastImage = dst;
250262
};
251263

252264
return del;

src/Microsoft.ML.ImageAnalytics/ImageResizer.cs

-1
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,6 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
285285
{
286286
if (src != null)
287287
{
288-
src.Dispose();
289288
src = null;
290289
}
291290
};

test/Microsoft.ML.Tests/ImagesTests.cs

+98
Original file line numberDiff line numberDiff line change
@@ -908,5 +908,103 @@ private class DataPoint
908908
[VectorType(InputSize)]
909909
public double[] Features { get; set; }
910910
}
911+
912+
/// <summary>
/// Test data row that carries an already-loaded image together with its label.
/// </summary>
public class InMemoryImage
{
    [ImageType(229, 299)]
    public Bitmap LoadedImage;
    public string Label;

    /// <summary>
    /// Reads the text file at <paramref name="tsvPath"/> (column 0: image file name,
    /// column 1: label) and loads every referenced image from
    /// <paramref name="imageFolder"/> into memory.
    /// </summary>
    /// <param name="mlContext">Context used to create the text loader.</param>
    /// <param name="tsvPath">Path of the file listing image names and labels.</param>
    /// <param name="imageFolder">Folder the listed image file names are relative to.</param>
    /// <returns>One <see cref="InMemoryImage"/> per row, in file order.</returns>
    public static List<InMemoryImage> LoadFromTsv(MLContext mlContext, string tsvPath, string imageFolder)
    {
        var columns = new[]
        {
            new TextLoader.Column("ImagePath", DataKind.String, 0),
            new TextLoader.Column("Label", DataKind.String, 1),
        };
        var rows = mlContext.Data.LoadFromTextFile(tsvPath, columns: columns);
        var result = new List<InMemoryImage>();

        using (var rowCursor = rows.GetRowCursorForAllColumns())
        {
            var readPath = rowCursor.GetGetter<ReadOnlyMemory<char>>(rows.Schema["ImagePath"]);
            var readLabel = rowCursor.GetGetter<ReadOnlyMemory<char>>(rows.Schema["Label"]);
            ReadOnlyMemory<char> rawPath = default;
            ReadOnlyMemory<char> rawLabel = default;

            while (rowCursor.MoveNext())
            {
                readPath(ref rawPath);
                readLabel(ref rawLabel);

                // The caller owns the returned bitmaps; nothing here disposes them.
                var item = new InMemoryImage();
                item.Label = rawLabel.ToString();
                item.LoadedImage = (Bitmap)Image.FromFile(Path.Combine(imageFolder, rawPath.ToString()));
                result.Add(item);
            }
        }

        return result;
    }
}
957+
958+
/// <summary>
/// Prediction output row: the fields of <see cref="InMemoryImage"/> plus the
/// image read back from the "ResizedImage" column.
/// </summary>
public class InMemoryImageOutput : InMemoryImage
{
    [ImageType(100, 100)]
    public Bitmap ResizedImage;
}
963+
964+
/// <summary>
/// Verifies that resizing in-memory images yields new bitmaps of height 100 while the
/// caller-owned source bitmaps are neither resized, replaced, nor disposed
/// (see https://github.com/dotnet/machinelearning/issues/4126).
/// </summary>
[Fact]
public void ResizeInMemoryImages()
{
    var mlContext = new MLContext(seed: 1);
    var dataFile = GetDataPath("images/images.tsv");
    var imageFolder = Path.GetDirectoryName(dataFile);
    var dataObjects = InMemoryImage.LoadFromTsv(mlContext, dataFile, imageFolder);

    try
    {
        var dataView = mlContext.Data.LoadFromEnumerable<InMemoryImage>(dataObjects);
        var pipeline = mlContext.Transforms.ResizeImages("ResizedImage", 100, 100, nameof(InMemoryImage.LoadedImage));

        // Check that the output is resized, and that it didn't resize the original image object.
        var model = pipeline.Fit(dataView);
        var resizedDV = model.Transform(dataView);
        var rowView = resizedDV.Preview().RowView;
        var resizedImage = (Bitmap)rowView.First().Values.Last().Value;
        Assert.Equal(100, resizedImage.Height);
        Assert.NotEqual(100, dataObjects[0].LoadedImage.Height);

        // Also check usage of a prediction engine,
        // and that the references to the original image objects aren't lost.
        var predEngine = mlContext.Model.CreatePredictionEngine<InMemoryImage, InMemoryImageOutput>(model);
        for (int i = 0; i < dataObjects.Count; i++)
        {
            var prediction = predEngine.Predict(dataObjects[i]);
            Assert.Equal(100, prediction.ResizedImage.Height);
            Assert.NotEqual(100, prediction.LoadedImage.Height);
            Assert.True(prediction.LoadedImage == dataObjects[i].LoadedImage);
            Assert.False(prediction.ResizedImage == dataObjects[i].LoadedImage);
        }

        // Check that the last in-memory image hasn't been disposed by running
        // ResizeImageTransformer (see https://github.com/dotnet/machinelearning/issues/4126).
        // Reading a property of the bitmap is used as the probe: any exception is taken
        // to mean the bitmap is no longer usable.
        var accessError = Record.Exception(() => { _ = dataObjects.Last().LoadedImage.Height; });
        bool disposed = accessError != null;

        Assert.False(disposed, "The last in memory image had been disposed by running ResizeImageTransformer");
    }
    finally
    {
        // The test loaded these bitmaps itself, so it is responsible for releasing them.
        foreach (var dataObject in dataObjects)
        {
            dataObject.LoadedImage?.Dispose();
        }
    }
}
9111009
}
9121010
}

test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs

+31
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Collections.Generic;
7+
using System.Drawing;
78
using System.IO;
89
using System.IO.Compression;
910
using System.Linq;
@@ -18,6 +19,7 @@
1819
using Microsoft.ML.Transforms;
1920
using Microsoft.ML.Transforms.Image;
2021
using Microsoft.ML.TensorFlow;
22+
using InMemoryImage = Microsoft.ML.Tests.ImageTests.InMemoryImage;
2123
using Xunit;
2224
using Xunit.Abstractions;
2325
using static Microsoft.ML.DataOperationsCatalog;
@@ -1126,6 +1128,35 @@ public void TensorFlowTransformCifarSavedModel()
11261128
}
11271129
}
11281130

1131+
// This test doesn't really check the values of the results.
// It simply checks that cross-validation is doable with in-memory images.
// See issue https://github.com/dotnet/machinelearning/issues/4126
[TensorFlowFact]
public void TensorFlowTransformCifarCrossValidationWithInMemoryImages()
{
    var modelLocation = "cifar_saved_model";
    var mlContext = new MLContext(seed: 1);
    using var tensorFlowModel = mlContext.Model.LoadTensorFlowModel(modelLocation);
    var schema = tensorFlowModel.GetInputSchema();
    Assert.True(schema.TryGetColumnIndex("Input", out int column));
    var type = (VectorDataViewType)schema[column].Type;
    var imageHeight = type.Dimensions[0];
    var imageWidth = type.Dimensions[1];
    var dataFile = GetDataPath("images/images.tsv");
    var imageFolder = Path.GetDirectoryName(dataFile);
    var dataObjects = InMemoryImage.LoadFromTsv(mlContext, dataFile, imageFolder);

    try
    {
        var dataView = mlContext.Data.LoadFromEnumerable<InMemoryImage>(dataObjects);
        var pipeline = mlContext.Transforms.ResizeImages("ResizedImage", imageWidth, imageHeight, nameof(InMemoryImage.LoadedImage))
            .Append(mlContext.Transforms.ExtractPixels("Input", "ResizedImage", interleavePixelColors: true))
            .Append(tensorFlowModel.ScoreTensorFlowModel("Output", "Input"))
            .Append(mlContext.Transforms.Conversion.MapValueToKey("Label"))
            .Append(mlContext.MulticlassClassification.Trainers.NaiveBayes("Label", "Output"));

        // Two folds requested, so two results are expected back.
        var cross = mlContext.MulticlassClassification.CrossValidate(dataView, pipeline, 2);
        Assert.Equal(2, cross.Count);
    }
    finally
    {
        // The test loaded these bitmaps itself, so it is responsible for releasing them.
        foreach (var dataObject in dataObjects)
        {
            dataObject.LoadedImage?.Dispose();
        }
    }
}
1159+
11291160
// This test has been created as result of https://github.com/dotnet/machinelearning/issues/2156.
11301161
[TensorFlowFact]
11311162
public void TensorFlowGettingSchemaMultipleTimes()

0 commit comments

Comments
 (0)