Skip to content

Commit 5f7527f

Browse files
ashbhandare authored and codemzs committed
Redesign DnnCatalog methods API for ease of use and consistency. (#4362)
* WIP initial change * Changed API design, changed tests and samples to use new API * Combined DnnCatalog.Options and ImageClassificationEstimator.Options, addressed review comments. * Added unit test and sample * Removed duplicate members in Options class, addresses PR comments * Removed preview remark for ImageClassification.
1 parent 8d07a53 commit 5f7527f

File tree

9 files changed

+610
-458
lines changed

9 files changed

+610
-458
lines changed
Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+

2+
using System;
3+
using System.Collections.Generic;
4+
using System.IO;
5+
using System.IO.Compression;
6+
using System.Linq;
7+
using System.Net;
8+
using System.Threading;
9+
using System.Threading.Tasks;
10+
using Microsoft.ML;
11+
using Microsoft.ML.Data;
12+
using Microsoft.ML.Transforms;
13+
using static Microsoft.ML.DataOperationsCatalog;
14+
15+
namespace Samples.Dynamic
16+
{
17+
public class ImageClassificationDefault
18+
{
19+
/// <summary>
/// Runs the sample end to end: downloads and extracts the image set,
/// builds the dataset, trains an image classifier, saves and reloads the
/// model, evaluates it and finally scores a single in-memory image.
/// Any exception raised after the download step is written to the
/// console instead of being propagated.
/// </summary>
public static void Example()
{
    string assetsPath = GetAbsolutePath(@"../../../assets");

    // Kept from the original sample. Note that the model below is saved
    // to "model.zip", not to this path.
    var outputMlNetModelFilePath = Path.Combine(assetsPath, "outputs",
        "imageClassifier.zip");

    string imagesDownloadFolderPath = Path.Combine(assetsPath, "inputs",
        "images");

    // Download the image set and unzip it.
    string imageSetFolderName = DownloadImageSet(imagesDownloadFolderPath);
    string imageSetFolderPath = Path.Combine(imagesDownloadFolderPath,
        imageSetFolderName);

    try
    {
        MLContext mlContext = new MLContext(seed: 1);

        // Describe every image on disk, labelled by its parent folder.
        IEnumerable<ImageData> images = LoadImagesFromDirectory(
            folder: imageSetFolderPath, useFolderNameAsLabel: true);

        IDataView loadedImages = mlContext.Data.LoadFromEnumerable(images);
        IDataView fullDataset = mlContext.Data.ShuffleRows(loadedImages);

        // Turn the string label into a key and add the "Image" column
        // from the "ImagePath" column.
        var preprocessing = mlContext.Transforms.Conversion
            .MapValueToKey("Label")
            .Append(mlContext.Transforms.LoadImages("Image",
                imageSetFolderPath, false, "ImagePath"));
        var preprocessingModel = preprocessing.Fit(fullDataset);
        fullDataset = preprocessingModel.Transform(fullDataset);

        // 90:10 train/test split. The test part is also handed to the
        // trainer as its validation set.
        TrainTestData split = mlContext.Data.TrainTestSplit(
            fullDataset, testFraction: 0.1, seed: 1);
        IDataView trainDataset = split.TrainSet;
        IDataView testDataset = split.TestSet;

        var trainer = mlContext.Model.ImageClassification("Image", "Label", validationSet: testDataset);
        var keyToValue = mlContext.Transforms.Conversion.MapKeyToValue(
            outputColumnName: "PredictedLabel",
            inputColumnName: "PredictedLabel");
        var pipeline = trainer.Append(keyToValue);

        Console.WriteLine("*** Training the image classification model " +
            "with DNN Transfer Learning on top of the selected " +
            "pre-trained model/architecture ***");

        // Time the training step.
        var trainingTimer = System.Diagnostics.Stopwatch.StartNew();
        var trainedModel = pipeline.Fit(trainDataset);
        trainingTimer.Stop();

        long trainingSeconds = trainingTimer.ElapsedMilliseconds / 1000;
        Console.WriteLine("Training with transfer learning took: " +
            trainingSeconds.ToString() + " seconds");

        // Round-trip the model through disk before using it.
        mlContext.Model.Save(trainedModel, fullDataset.Schema,
            "model.zip");

        ITransformer loadedModel;
        DataViewSchema schema;
        using (var file = File.OpenRead("model.zip"))
        {
            loadedModel = mlContext.Model.Load(file, out schema);
        }

        EvaluateModel(mlContext, testDataset, loadedModel);

        // Time a single prediction on an in-memory image.
        var predictionTimer = System.Diagnostics.Stopwatch.StartNew();
        TrySinglePrediction(imageSetFolderPath, mlContext, loadedModel);
        predictionTimer.Stop();

        long predictionSeconds = predictionTimer.ElapsedMilliseconds / 1000;
        Console.WriteLine("Prediction engine took: " +
            predictionSeconds.ToString() + " seconds");
    }
    catch (Exception ex)
    {
        Console.WriteLine(ex.ToString());
    }

    Console.WriteLine("Press any key to finish");
    Console.ReadKey();
}
113+
114+
/// <summary>
/// Scores one image with the given model and prints the resulting
/// scores and predicted label to the console.
/// </summary>
/// <param name="imagesForPredictions">Folder searched (recursively) for images.</param>
/// <param name="mlContext">Context used to create the prediction engine.</param>
/// <param name="trainedModel">Model used for the prediction.</param>
private static void TrySinglePrediction(string imagesForPredictions,
    MLContext mlContext, ITransformer trainedModel)
{
    // Prediction engine for a single in-memory image.
    var engine = mlContext.Model
        .CreatePredictionEngine<InMemoryImageData, ImagePrediction>(trainedModel);

    IEnumerable<InMemoryImageData> candidates =
        LoadInMemoryImagesFromDirectory(imagesForPredictions, false);

    // Only the first image yielded by the loader is scored; its label is
    // deliberately not copied over.
    var imageToPredict = new InMemoryImageData();
    imageToPredict.Image = candidates.First().Image;

    ImagePrediction prediction = engine.Predict(imageToPredict);

    Console.WriteLine($"Scores : [{string.Join(",", prediction.Score)}], " +
        $"Predicted Label : {prediction.PredictedLabel}");
}
134+
135+
136+
/// <summary>
/// Transforms the test set with the model, computes multiclass
/// classification metrics on the result and prints the micro/macro
/// accuracy together with the elapsed time in whole seconds.
/// </summary>
/// <param name="mlContext">Context providing the evaluator.</param>
/// <param name="testDataset">Data to score and evaluate.</param>
/// <param name="trainedModel">Model under evaluation.</param>
private static void EvaluateModel(MLContext mlContext,
    IDataView testDataset, ITransformer trainedModel)
{
    Console.WriteLine("Making bulk predictions and evaluating model's " +
        "quality...");

    // The timer covers scoring, evaluation and the metrics printout.
    var timer = System.Diagnostics.Stopwatch.StartNew();

    var metrics = mlContext.MulticlassClassification.Evaluate(
        trainedModel.Transform(testDataset));

    Console.WriteLine($"Micro-accuracy: {metrics.MicroAccuracy}," +
        $"macro-accuracy = {metrics.MacroAccuracy}");

    timer.Stop();
    long elapsedSeconds = timer.ElapsedMilliseconds / 1000;

    Console.WriteLine("Predicting and Evaluation took: " +
        elapsedSeconds.ToString() + " seconds");
}
157+
158+
/// <summary>
/// Lazily enumerates the JPEG files (<c>.jpg</c>, any letter case) found
/// under <paramref name="folder"/>, including sub-folders, and yields one
/// <see cref="ImageData"/> per file.
/// </summary>
/// <param name="folder">Root folder to search.</param>
/// <param name="useFolderNameAsLabel">
/// When true the label is the name of the file's parent folder; otherwise
/// it is the leading run of letters of the file name.
/// </param>
/// <returns>A sequence of image path / label pairs.</returns>
public static IEnumerable<ImageData> LoadImagesFromDirectory(string folder,
    bool useFolderNameAsLabel = true)
{
    var files = Directory.GetFiles(folder, "*",
        searchOption: SearchOption.AllDirectories);
    foreach (var file in files)
    {
        // Compare the extension case-insensitively so that files such as
        // "IMG_01.JPG" are not silently dropped.
        if (!string.Equals(Path.GetExtension(file), ".jpg",
                StringComparison.OrdinalIgnoreCase))
        {
            continue;
        }

        string label;
        if (useFolderNameAsLabel)
        {
            label = Directory.GetParent(file).Name;
        }
        else
        {
            // Keep only the letters that precede the first non-letter
            // character of the file name.
            string fileName = Path.GetFileName(file);
            int prefixLength = 0;
            while (prefixLength < fileName.Length
                && char.IsLetter(fileName[prefixLength]))
            {
                prefixLength++;
            }

            label = fileName.Substring(0, prefixLength);
        }

        yield return new ImageData()
        {
            ImagePath = file,
            Label = label
        };
    }
}
191+
192+
/// <summary>
/// Lazily enumerates the JPEG files (<c>.jpg</c>, any letter case) found
/// under <paramref name="folder"/>, including sub-folders, and yields one
/// <see cref="InMemoryImageData"/> per file with the file's bytes loaded.
/// Each file is read only when its element is requested.
/// </summary>
/// <param name="folder">Root folder to search.</param>
/// <param name="useFolderNameAsLabel">
/// When true the label is the name of the file's parent folder; otherwise
/// it is the leading run of letters of the file name.
/// </param>
/// <returns>A sequence of image bytes / label pairs.</returns>
public static IEnumerable<InMemoryImageData>
    LoadInMemoryImagesFromDirectory(string folder,
        bool useFolderNameAsLabel = true)
{
    var files = Directory.GetFiles(folder, "*",
        searchOption: SearchOption.AllDirectories);
    foreach (var file in files)
    {
        // Compare the extension case-insensitively so that files such as
        // "IMG_01.JPG" are not silently dropped.
        if (!string.Equals(Path.GetExtension(file), ".jpg",
                StringComparison.OrdinalIgnoreCase))
        {
            continue;
        }

        string label;
        if (useFolderNameAsLabel)
        {
            label = Directory.GetParent(file).Name;
        }
        else
        {
            // Keep only the letters that precede the first non-letter
            // character of the file name.
            string fileName = Path.GetFileName(file);
            int prefixLength = 0;
            while (prefixLength < fileName.Length
                && char.IsLetter(fileName[prefixLength]))
            {
                prefixLength++;
            }

            label = fileName.Substring(0, prefixLength);
        }

        yield return new InMemoryImageData()
        {
            Image = File.ReadAllBytes(file),
            Label = label
        };
    }
}
226+
227+
/// <summary>
/// Downloads the small flowers image archive into
/// <paramref name="imagesDownloadFolder"/> and extracts it there.
/// </summary>
/// <param name="imagesDownloadFolder">Folder receiving the archive and its contents.</param>
/// <returns>The archive's file name without its extension.</returns>
public static string DownloadImageSet(string imagesDownloadFolder)
{
    // Images used to teach the network about the new classes:
    // single small flowers image set (200 files).
    const string archiveName = "flower_photos_small_set.zip";
    const string archiveUrl = "https://mlnetfilestorage.file.core.windows.net/" +
        "imagesets/flower_images/flower_photos_small_set.zip?st=2019-08-" +
        "07T21%3A27%3A44Z&se=2030-08-08T21%3A27%3A00Z&sp=rl&sv=2018-03-" +
        "28&sr=f&sig=SZ0UBX47pXD0F1rmrOM%2BfcwbPVob8hlgFtIlN89micM%3D";

    Download(archiveUrl, imagesDownloadFolder, archiveName);

    string archivePath = Path.Combine(imagesDownloadFolder, archiveName);
    UnZip(archivePath, imagesDownloadFolder);

    return Path.GetFileNameWithoutExtension(archiveName);
}
243+
244+
/// <summary>
/// Downloads <paramref name="url"/> into <paramref name="destDir"/>,
/// creating the directory if needed and printing progress dots to the
/// console while the transfer runs.
/// </summary>
/// <param name="url">Address of the file to download.</param>
/// <param name="destDir">Directory that receives the file.</param>
/// <param name="destFileName">
/// Name of the local file. When null, the last segment of the URL path
/// (without any query string) is used.
/// </param>
/// <returns>
/// True when the file was downloaded; false when a file with that name
/// already exists in <paramref name="destDir"/> and nothing was done.
/// </returns>
/// <remarks>
/// A failed download is rethrown to the caller (for example as a
/// <see cref="WebException"/>) after the partial file is removed, instead
/// of being reported as a success.
/// </remarks>
public static bool Download(string url, string destDir, string destFileName)
{
    if (destFileName == null)
    {
        // URLs always use '/', whatever the platform's directory
        // separator is, and the query string is not part of the name.
        destFileName = Path.GetFileName(new Uri(url).LocalPath);
    }

    Directory.CreateDirectory(destDir);

    string relativeFilePath = Path.Combine(destDir, destFileName);

    if (File.Exists(relativeFilePath))
    {
        Console.WriteLine($"{relativeFilePath} already exists.");
        return false;
    }

    Console.WriteLine($"Downloading {relativeFilePath}");
    using (var wc = new WebClient())
    {
        var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath));
        while (!download.IsCompleted)
        {
            Thread.Sleep(1000);
            Console.Write(".");
        }
        Console.WriteLine("");

        try
        {
            // The task is already complete; this only surfaces a failure
            // that would otherwise go unobserved.
            download.GetAwaiter().GetResult();
        }
        catch
        {
            // Do not leave a partial file behind: it would make the next
            // call believe the download already happened.
            if (File.Exists(relativeFilePath))
                File.Delete(relativeFilePath);
            throw;
        }
    }

    Console.WriteLine($"Downloaded {relativeFilePath}");

    return true;
}
272+
273+
/// <summary>
/// Extracts a zip archive into <paramref name="destFolder"/> once. After a
/// successful extraction an empty marker file named after the archive
/// (text before the first '.', plus ".bin") is created in the destination
/// folder; if that marker already exists the call does nothing.
/// </summary>
/// <param name="gzArchiveName">Path of the zip archive to extract.</param>
/// <param name="destFolder">Folder receiving the extracted files and the marker.</param>
public static void UnZip(String gzArchiveName, String destFolder)
{
    // Path.GetFileName copes with both directory separators, unlike
    // splitting on Path.DirectorySeparatorChar only.
    var flag = Path.GetFileName(gzArchiveName)
        .Split('.')
        .First() + ".bin";
    string flagPath = Path.Combine(destFolder, flag);

    if (File.Exists(flagPath)) return;

    Console.WriteLine($"Extracting.");
    ZipFile.ExtractToDirectory(gzArchiveName, destFolder);

    // File.Create returns an open FileStream; dispose it right away so the
    // handle is not leaked and the marker file is not left locked.
    File.Create(flagPath).Dispose();
    Console.WriteLine("");
    Console.WriteLine("Extracting is completed.");
}
289+
290+
/// <summary>
/// Combines <paramref name="relativePath"/> with the folder that contains
/// the assembly this sample is compiled into.
/// </summary>
/// <param name="relativePath">Path relative to the assembly's folder.</param>
/// <returns>
/// The combined path. Segments such as ".." are kept as given; the path
/// is not normalized.
/// </returns>
public static string GetAbsolutePath(string relativePath)
{
    // Locate the assembly through the executing code itself rather than
    // through an unrelated sample type, so this sample does not depend on
    // another class being present.
    // NOTE(review): yields the same folder as before provided that type
    // lives in this same assembly - confirm.
    string assemblyLocation =
        System.Reflection.Assembly.GetExecutingAssembly().Location;

    string assemblyFolderPath =
        new FileInfo(assemblyLocation).Directory.FullName;

    return Path.Combine(assemblyFolderPath, relativePath);
}
301+
302+
/// <summary>
/// Input row carrying an image as a byte array together with its label.
/// </summary>
public class InMemoryImageData
{
    /// <summary>Image content as bytes (column 0).</summary>
    [LoadColumn(0)] public byte[] Image;

    /// <summary>Class label of the image (column 1).</summary>
    [LoadColumn(1)] public string Label;
}
310+
311+
/// <summary>
/// Input row referring to an image by its file path together with its label.
/// </summary>
public class ImageData
{
    /// <summary>Path of the image file (column 0).</summary>
    [LoadColumn(0)] public string ImagePath;

    /// <summary>Class label of the image (column 1).</summary>
    [LoadColumn(1)] public string Label;
}
319+
320+
/// <summary>
/// Output row of a prediction, bound to the "Score" and "PredictedLabel"
/// columns.
/// </summary>
public class ImagePrediction
{
    /// <summary>Values of the "Score" column.</summary>
    [ColumnName("Score")] public float[] Score;

    /// <summary>Value of the "PredictedLabel" column.</summary>
    [ColumnName("PredictedLabel")] public string PredictedLabel;
}
328+
}
329+
}

0 commit comments

Comments
 (0)