diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/InceptionV3TransferLearning.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/InceptionV3TransferLearning.cs
deleted file mode 100644
index 0e1ec80973..0000000000
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/InceptionV3TransferLearning.cs
+++ /dev/null
@@ -1,109 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.IO;
-using System.Linq;
-using Microsoft.ML;
-using Microsoft.ML.Data;
-using Microsoft.ML.Transforms;
-
-namespace Samples.Dynamic
-{
- public static class InceptionV3TransferLearning
- {
- ///
- /// Example use of Image classification API in a ML.NET pipeline.
- ///
- public static void Example()
- {
- var mlContext = new MLContext(seed: 1);
-
- var imagesDataFile = Path.GetDirectoryName(
- Microsoft.ML.SamplesUtils.DatasetUtils.DownloadImages());
-
- var data = mlContext.Data.LoadFromEnumerable(
- ImageNetData.LoadImagesFromDirectory(imagesDataFile, 4));
-
- data = mlContext.Data.ShuffleRows(data, 5);
- var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label")
- .Append(mlContext.Transforms.LoadImages("ImageObject", null,
- "ImagePath"))
- .Append(mlContext.Transforms.ResizeImages("Image",
- inputColumnName: "ImageObject", imageWidth: 299,
- imageHeight: 299))
- .Append(mlContext.Transforms.ExtractPixels("Image",
- interleavePixelColors: true))
- .Append(mlContext.Model.ImageClassification("Image",
- "Label", arch: DnnEstimator.Architecture.InceptionV3, epoch: 4,
- batchSize: 4));
-
- var trainedModel = pipeline.Fit(data);
- var predicted = trainedModel.Transform(data);
- var metrics = mlContext.MulticlassClassification.Evaluate(predicted);
-
- Console.WriteLine($"Micro-accuracy: {metrics.MicroAccuracy}," +
- $"macro-accuracy = {metrics.MacroAccuracy}");
-
- // Create prediction function and test prediction
- var predictFunction = mlContext.Model
- .CreatePredictionEngine(trainedModel);
-
- var prediction = predictFunction
- .Predict(ImageNetData.LoadImagesFromDirectory(imagesDataFile, 4)
- .First());
-
- Console.WriteLine($"Scores : [{string.Join(",", prediction.Score)}], " +
- $"Predicted Label : {prediction.PredictedLabel}");
-
- }
- }
-
- public class ImageNetData
- {
- [LoadColumn(0)]
- public string ImagePath;
-
- [LoadColumn(1)]
- public string Label;
-
- public static IEnumerable LoadImagesFromDirectory(
- string folder, int repeat = 1, bool useFolderNameasLabel = false)
- {
- var files = Directory.GetFiles(folder, "*",
- searchOption: SearchOption.AllDirectories);
-
- foreach (var file in files)
- {
- if (Path.GetExtension(file) != ".jpg")
- continue;
-
- var label = Path.GetFileName(file);
- if (useFolderNameasLabel)
- label = Directory.GetParent(file).Name;
- else
- {
- for (int index = 0; index < label.Length; index++)
- {
- if (!char.IsLetter(label[index]))
- {
- label = label.Substring(0, index);
- break;
- }
- }
- }
-
- for (int index = 0; index < repeat; index++)
- yield return new ImageNetData() {
- ImagePath = file,Label = label };
- }
- }
- }
-
- public class ImagePrediction
- {
- [ColumnName("Score")]
- public float[] Score;
-
- [ColumnName("PredictedLabel")]
- public Int64 PredictedLabel;
- }
-}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearning.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearning.cs
deleted file mode 100644
index 9d3136b01b..0000000000
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearning.cs
+++ /dev/null
@@ -1,123 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Diagnostics;
-using System.IO;
-using System.Linq;
-using Microsoft.ML;
-using Microsoft.ML.Data;
-using Microsoft.ML.Transforms;
-
-namespace Samples.Dynamic
-{
- public static class ResnetV2101TransferLearning
- {
- ///
- /// Example use of Image classification API in a ML.NET pipeline.
- ///
- public static void Example()
- {
- var mlContext = new MLContext(seed: 1);
-
- var imagesDataFile = Path.GetDirectoryName(
- Microsoft.ML.SamplesUtils.DatasetUtils.DownloadImages());
-
- var data = mlContext.Data.LoadFromEnumerable(
- ImageNetData.LoadImagesFromDirectory(imagesDataFile, 4));
-
- data = mlContext.Data.ShuffleRows(data, 5);
- var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label")
- .Append(mlContext.Transforms.LoadImages("ImageObject", null,
- "ImagePath"))
- .Append(mlContext.Transforms.ResizeImages("Image",
- inputColumnName: "ImageObject", imageWidth: 299,
- imageHeight: 299))
- .Append(mlContext.Transforms.ExtractPixels("Image",
- interleavePixelColors: true))
- .Append(mlContext.Model.ImageClassification("Image",
- "Label", arch: DnnEstimator.Architecture.ResnetV2101, epoch: 4,
- batchSize: 4));
-
- var trainedModel = pipeline.Fit(data);
- var predicted = trainedModel.Transform(data);
- var metrics = mlContext.MulticlassClassification.Evaluate(predicted);
-
- Console.WriteLine($"Micro-accuracy: {metrics.MicroAccuracy}," +
- $"macro-accuracy = {metrics.MacroAccuracy}");
-
- mlContext.Model.Save(trainedModel, data.Schema, "model.zip");
-
- ITransformer loadedModel;
- using (var file = File.OpenRead("model.zip"))
- loadedModel = mlContext.Model.Load(file, out DataViewSchema schema);
-
- // Create prediction function and test prediction
- var predictFunction = mlContext.Model
- .CreatePredictionEngine(loadedModel);
-
- var prediction = predictFunction
- .Predict(ImageNetData.LoadImagesFromDirectory(imagesDataFile, 4)
- .First());
-
- Console.WriteLine($"Scores : [{string.Join(",", prediction.Score)}], " +
- $"Predicted Label : {prediction.PredictedLabel}");
- }
-
- private const int imageHeight = 224;
- private const int imageWidth = 224;
- private const int numChannels = 3;
- private const int inputSize = imageHeight * imageWidth * numChannels;
-
- public class ImageNetData
- {
- [LoadColumn(0)]
- public string ImagePath;
-
- [LoadColumn(1)]
- public string Label;
-
- public static IEnumerable LoadImagesFromDirectory(
- string folder, int repeat = 1, bool useFolderNameasLabel = false)
- {
- var files = Directory.GetFiles(folder, "*",
- searchOption: SearchOption.AllDirectories);
-
- foreach (var file in files)
- {
- if (Path.GetExtension(file) != ".jpg")
- continue;
-
- var label = Path.GetFileName(file);
- if (useFolderNameasLabel)
- label = Directory.GetParent(file).Name;
- else
- {
- for (int index = 0; index < label.Length; index++)
- {
- if (!char.IsLetter(label[index]))
- {
- label = label.Substring(0, index);
- break;
- }
- }
- }
-
- for (int index = 0; index < repeat; index++)
- yield return new ImageNetData()
- {
- ImagePath = file,
- Label = label
- };
- }
- }
- }
-
- public class ImagePrediction
- {
- [ColumnName("Score")]
- public float[] Score;
-
- [ColumnName("PredictedLabel")]
- public Int64 PredictedLabel;
- }
- }
-}
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearningTrainTestSplit.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearningTrainTestSplit.cs
new file mode 100644
index 0000000000..ed3f4da5a4
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearningTrainTestSplit.cs
@@ -0,0 +1,307 @@
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Threading.Tasks;
+using Microsoft.ML;
+using Microsoft.ML.Transforms;
+using static Microsoft.ML.DataOperationsCatalog;
+using System.Linq;
+using Microsoft.ML.Data;
+using System.IO.Compression;
+using System.Threading;
+using System.Net;
+
+namespace Samples.Dynamic
+{
+ public class ResnetV2101TransferLearningTrainTestSplit
+ {
+        /// <summary>
+        /// Trains an image classifier by transfer learning on a downloaded flower
+        /// image set, then evaluates the reloaded model and tries one prediction.
+        /// </summary>
+        public static void Example()
+        {
+            string assetsRelativePath = @"../../../assets";
+            string assetsPath = GetAbsolutePath(assetsRelativePath);
+
+            string imagesDownloadFolderPath = Path.Combine(assetsPath, "inputs",
+                "images");
+
+            //Download the image set and unzip
+            string finalImagesFolderName = DownloadImageSet(
+                imagesDownloadFolderPath);
+
+            string fullImagesetFolderPath = Path.Combine(
+                imagesDownloadFolderPath, finalImagesFolderName);
+
+            try
+            {
+                MLContext mlContext = new MLContext(seed: 1);
+
+                //Load all the original images info
+                IEnumerable<ImageData> images = LoadImagesFromDirectory(
+                    folder: fullImagesetFolderPath, useFolderNameasLabel: true);
+
+                IDataView shuffledFullImagesDataset = mlContext.Data.ShuffleRows(
+                    mlContext.Data.LoadFromEnumerable(images));
+
+                // Map string labels to keys; the key values are read back further down.
+                shuffledFullImagesDataset = mlContext.Transforms.Conversion
+                    .MapValueToKey("Label")
+                    .Fit(shuffledFullImagesDataset)
+                    .Transform(shuffledFullImagesDataset);
+
+                // Split the data 90:10 into train and test sets, train and evaluate.
+                TrainTestData trainTestData = mlContext.Data.TrainTestSplit(
+                    shuffledFullImagesDataset, testFraction: 0.1, seed: 1);
+
+                IDataView trainDataset = trainTestData.TrainSet;
+                IDataView testDataset = trainTestData.TestSet;
+
+                var pipeline = mlContext.Model.ImageClassification(
+                    "ImagePath", "Label",
+                    // Just by changing/selecting InceptionV3 here instead of
+                    // ResnetV2101 you can try a different architecture/pre-trained
+                    // model.
+                    arch: ImageClassificationEstimator.Architecture.ResnetV2101,
+                    epoch: 50,
+                    batchSize: 10,
+                    learningRate: 0.01f,
+                    metricsCallback: (metrics) => Console.WriteLine(metrics),
+                    validationSet: testDataset);
+
+                Console.WriteLine("*** Training the image classification model with " +
+                    "DNN Transfer Learning on top of the selected pre-trained " +
+                    "model/architecture ***");
+
+                // Measuring training time
+                var watch = System.Diagnostics.Stopwatch.StartNew();
+
+                var trainedModel = pipeline.Fit(trainDataset);
+
+                watch.Stop();
+                long elapsedMs = watch.ElapsedMilliseconds;
+
+                Console.WriteLine("Training with transfer learning took: " +
+                    (elapsedMs / 1000).ToString() + " seconds");
+
+                mlContext.Model.Save(trainedModel, shuffledFullImagesDataset.Schema,
+                    "model.zip");
+
+                ITransformer loadedModel;
+                DataViewSchema schema;
+                using (var file = File.OpenRead("model.zip"))
+                    loadedModel = mlContext.Model.Load(file, out schema);
+
+                EvaluateModel(mlContext, testDataset, loadedModel);
+
+                VBuffer<ReadOnlyMemory<char>> keys = default;
+                loadedModel.GetOutputSchema(schema)["Label"].GetKeyValues(ref keys);
+
+                watch = System.Diagnostics.Stopwatch.StartNew();
+                TrySinglePrediction(fullImagesetFolderPath, mlContext, loadedModel,
+                    keys.DenseValues().ToArray());
+
+                watch.Stop();
+                elapsedMs = watch.ElapsedMilliseconds;
+
+                Console.WriteLine("Prediction engine took: " +
+                    (elapsedMs / 1000).ToString() + " seconds");
+            }
+            catch (Exception ex)
+            {
+                Console.WriteLine(ex.ToString());
+            }
+
+            Console.WriteLine("Press any key to finish");
+            Console.ReadKey();
+        }
+
+        /// <summary>
+        /// Runs one prediction on the first image found and prints the result.
+        /// </summary>
+        private static void TrySinglePrediction(string imagesForPredictions,
+            MLContext mlContext, ITransformer trainedModel,
+            ReadOnlyMemory<char>[] originalLabels)
+        {
+            // Create prediction function to try one prediction
+            var predictionEngine = mlContext.Model
+                .CreatePredictionEngine<ImageData, ImagePrediction>(trainedModel);
+            IEnumerable<ImageData> testImages = LoadImagesFromDirectory(
+                imagesForPredictions, false);
+
+            ImageData imageToPredict = new ImageData
+                { ImagePath = testImages.First().ImagePath };
+
+            var prediction = predictionEngine.Predict(imageToPredict);
+            var index = prediction.PredictedLabel;
+
+            Console.WriteLine($"ImageFile : " +
+                $"[{Path.GetFileName(imageToPredict.ImagePath)}], " +
+                $"Scores : [{string.Join(",", prediction.Score)}], " +
+                $"Predicted Label : {originalLabels[index]}");
+        }
+
+
+        /// <summary>
+        /// Scores <paramref name="testDataset"/> with the model, prints the micro
+        /// and macro accuracy and then the elapsed time in whole seconds.
+        /// </summary>
+        private static void EvaluateModel(MLContext mlContext,
+            IDataView testDataset, ITransformer trainedModel)
+        {
+            Console.WriteLine("Making bulk predictions and evaluating model's " +
+                "quality...");
+
+            // The timer covers scoring, metric computation and the metrics print-out.
+            var timer = System.Diagnostics.Stopwatch.StartNew();
+            var scored = trainedModel.Transform(testDataset);
+            var evaluation = mlContext.MulticlassClassification.Evaluate(scored);
+            Console.WriteLine($"Micro-accuracy: {evaluation.MicroAccuracy}," +
+                $"macro-accuracy = {evaluation.MacroAccuracy}");
+            timer.Stop();
+            long wholeSeconds = timer.ElapsedMilliseconds / 1000;
+            Console.WriteLine("Predicting and Evaluation took: " +
+                wholeSeconds.ToString() + " seconds");
+        }
+
+        /// <summary>
+        /// Lazily enumerates every ".jpg" file under <paramref name="folder"/>
+        /// (searching all sub-directories) as an <see cref="ImageData"/> item.
+        /// </summary>
+        /// <param name="folder">Root directory to search.</param>
+        /// <param name="useFolderNameasLabel">
+        /// When true the label is the name of the file's parent directory;
+        /// otherwise it is the leading run of letters of the file name.
+        /// </param>
+        /// <returns>One item per matching file, with its path and label.</returns>
+        public static IEnumerable<ImageData> LoadImagesFromDirectory(string folder,
+            bool useFolderNameasLabel = true)
+        {
+            var files = Directory.GetFiles(folder, "*",
+                searchOption: SearchOption.AllDirectories);
+            foreach (var file in files)
+            {
+                // Compare case-insensitively so ".JPG"/".Jpg" files are not skipped.
+                if (!string.Equals(Path.GetExtension(file), ".jpg",
+                    StringComparison.OrdinalIgnoreCase))
+                    continue;
+                // e.g. "daisy_01.jpg" -> "daisy": stop at the first non-letter.
+                string label = useFolderNameasLabel
+                    ? Directory.GetParent(file).Name
+                    : new string(Path.GetFileName(file)
+                        .TakeWhile(c => char.IsLetter(c)).ToArray());
+
+                yield return new ImageData()
+                {
+                    ImagePath = file,
+                    Label = label
+                };
+            }
+        }
+
+        /// <summary>
+        /// Downloads and unzips the flower image set; returns its extension-less name.
+        /// </summary>
+        public static string DownloadImageSet(string imagesDownloadFolder)
+        {
+            //SINGLE SMALL FLOWERS IMAGESET (200 files)
+            const string archiveName = "flower_photos_small_set.zip";
+            string archiveUrl = $"https://mlnetfilestorage.file.core.windows.net/" +
+                $"imagesets/flower_images/flower_photos_small_set.zip?st=2019-08-" +
+                $"07T21%3A27%3A44Z&se=2030-08-08T21%3A27%3A00Z&sp=rl&sv=2018-03-" +
+                $"28&sr=f&sig=SZ0UBX47pXD0F1rmrOM%2BfcwbPVob8hlgFtIlN89micM%3D";
+            string archivePath = Path.Combine(imagesDownloadFolder, archiveName);
+            Download(archiveUrl, imagesDownloadFolder, archiveName);
+            UnZip(archivePath, imagesDownloadFolder);
+            return Path.GetFileNameWithoutExtension(archiveName);
+        }
+
+        /// <summary>
+        /// Downloads <paramref name="url"/> to destDir/destFileName (defaulting to the
+        /// URL's last path segment); returns false if the file already exists.
+        /// </summary>
+        public static bool Download(string url, string destDir, string destFileName)
+        {
+            // URLs always use '/', so do not split on the OS directory separator.
+            if (destFileName == null)
+                destFileName = Path.GetFileName(new Uri(url).AbsolutePath);
+            Directory.CreateDirectory(destDir);
+            string relativeFilePath = Path.Combine(destDir, destFileName);
+            if (File.Exists(relativeFilePath))
+            {
+                Console.WriteLine($"{relativeFilePath} already exists.");
+                return false;
+            }
+            Console.WriteLine($"Downloading {relativeFilePath}");
+            using (var wc = new WebClient())
+            {
+                var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath));
+                // Wait(timeout) paces the progress dots and rethrows a failed download.
+                while (!download.Wait(1000))
+                    Console.Write(".");
+            }
+            Console.WriteLine("");
+            Console.WriteLine($"Downloaded {relativeFilePath}");
+            return true;
+        }
+
+        /// <summary>
+        /// Extracts the zip archive into <paramref name="destFolder"/>. A ".bin"
+        /// marker file named after the archive is created on success, and its
+        /// presence makes later calls return without extracting again.
+        /// </summary>
+        public static void UnZip(String gzArchiveName, String destFolder)
+        {
+            var flag = Path.GetFileName(gzArchiveName).Split('.').First() + ".bin";
+            var flagPath = Path.Combine(destFolder, flag);
+            if (File.Exists(flagPath)) return;
+
+            Console.WriteLine($"Extracting.");
+            var task = Task.Run(() =>
+                ZipFile.ExtractToDirectory(gzArchiveName, destFolder));
+
+            // Wait(timeout) paces the progress dots and rethrows extraction failures,
+            // so the marker below is only written after a successful extraction.
+            while (!task.Wait(200))
+                Console.Write(".");
+
+            // Dispose the stream File.Create returns so the marker is not left open.
+            File.Create(flagPath).Dispose();
+            Console.WriteLine("");
+            Console.WriteLine("Extracting is completed.");
+        }
+
+        /// <summary>
+        /// Combines <paramref name="relativePath"/> with the directory containing
+        /// this sample's assembly.
+        /// </summary>
+        public static string GetAbsolutePath(string relativePath)
+        {
+            var assemblyFile = new FileInfo(typeof(
+                ResnetV2101TransferLearningTrainTestSplit).Assembly.Location);
+
+            return Path.Combine(assemblyFile.Directory.FullName, relativePath);
+        }
+
+        /// <summary>Input row: an image file path and its label.</summary>
+        public class ImageData
+        {
+            /// <summary>Path of the image file; annotated as load column 0.</summary>
+            [LoadColumn(0)] public string ImagePath;
+            /// <summary>Label of the image; annotated as load column 1.</summary>
+            [LoadColumn(1)] public string Label;
+        }
+
+        /// <summary>Output row bound to the "Score" and "PredictedLabel" columns.</summary>
+        public class ImagePrediction
+        {
+            /// <summary>Values of the "Score" column.</summary>
+            [ColumnName("Score")] public float[] Score;
+            /// <summary>Value of the "PredictedLabel" column.</summary>
+            [ColumnName("PredictedLabel")] public uint PredictedLabel;
+        }
+ }
+}
+
diff --git a/src/Microsoft.ML.Dnn/DnnCatalog.cs b/src/Microsoft.ML.Dnn/DnnCatalog.cs
index 3f5ee076ff..fea3735796 100644
--- a/src/Microsoft.ML.Dnn/DnnCatalog.cs
+++ b/src/Microsoft.ML.Dnn/DnnCatalog.cs
@@ -9,11 +9,12 @@
using Microsoft.ML.Data;
using Microsoft.ML.Transforms;
using Microsoft.ML.Transforms.Dnn;
-using static Microsoft.ML.Transforms.DnnEstimator;
+using static Microsoft.ML.Transforms.ImageClassificationEstimator;
+using Options = Microsoft.ML.Transforms.DnnRetrainEstimator.Options;
namespace Microsoft.ML
{
- ///
+ ///
public static class DnnCatalog
{
@@ -36,11 +37,10 @@ public static class DnnCatalog
/// Learning rate to use during optimization (Optional).
/// Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].
/// This parameter is used to deal with models that have unknown shape but the internal operators in the model require data to have batch dimension as well.
- ///
///
/// The support for retraining is under preview.
///
- public static DnnEstimator RetrainDnnModel(
+ public static DnnRetrainEstimator RetrainDnnModel(
this ModelOperationsCatalog catalog,
string[] outputColumnNames,
string[] inputColumnNames,
@@ -54,8 +54,7 @@ public static DnnEstimator RetrainDnnModel(
string metricOperation = null,
string learningRateOperation = null,
float learningRate = 0.01f,
- bool addBatchDimensionInput = false,
- DnnFramework dnnFramework = DnnFramework.Tensorflow)
+ bool addBatchDimensionInput = false)
{
var options = new Options()
{
@@ -71,12 +70,11 @@ public static DnnEstimator RetrainDnnModel(
LearningRateOperation = learningRateOperation,
LearningRate = learningRate,
BatchSize = batchSize,
- AddBatchDimensionInputs = addBatchDimensionInput,
- ReTrain = true
+ AddBatchDimensionInputs = addBatchDimensionInput
};
var env = CatalogUtils.GetEnvironment(catalog);
- return new DnnEstimator(env, options, DnnUtils.LoadDnnModel(env, modelPath, true));
+ return new DnnRetrainEstimator(env, options, DnnUtils.LoadDnnModel(env, modelPath, true));
}
///
@@ -85,33 +83,50 @@ public static DnnEstimator RetrainDnnModel(
///
/// The name of the input features column.
/// The name of the labels column.
- /// Optional name of the path where a copy new graph should be saved. The graph will be saved as part of model.
/// The name of the output score column.
/// The name of the output predicted label columns.
- /// The name of the prefix for checkpoint files.
/// The architecture of the image recognition DNN model.
- /// The backend DNN framework to use, currently only Tensorflow is supported.
- /// Number of training epochs.
+ /// Number of training iterations. Each iteration/epoch refers to one pass over the dataset.
/// The batch size for training.
/// The learning rate for training.
+ /// Callback for reporting model statistics during training phase.
+ /// Indicates the frequency of epochs at which to report model statistics during training phase.
+ /// Indicates the choice of DNN training framework. Currently only tensorflow is supported.
+ /// Optional name of the path where a copy of the new graph should be saved. The graph will be saved as part of model.
+ /// The name of the prefix for the final model and checkpoint files.
+ /// Validation set.
+ /// Indicates to evaluate the model on train set after every epoch.
+ /// Indicates to not re-compute cached trainset bottleneck values if already available in the bin folder.
+ /// Indicates to not re-compute cached validationset bottleneck values if already available in the bin folder.
+ /// Indicates the file path to store trainset bottleneck values for caching.
+ /// Indicates the file path to store validationset bottleneck values for caching.
///
/// The support for image classification is under preview.
///
- public static DnnEstimator ImageClassification(
+ public static ImageClassificationEstimator ImageClassification(
this ModelOperationsCatalog catalog,
string featuresColumnName,
string labelColumnName,
- string outputGraphPath = null,
string scoreColumnName = "Score",
string predictedLabelColumnName = "PredictedLabel",
- string checkpointName = "_retrain_checkpoint",
- Architecture arch = Architecture.ResnetV2101,
- DnnFramework dnnFramework = DnnFramework.Tensorflow,
- int epoch = 10,
- int batchSize = 20,
- float learningRate = 0.01f)
+ Architecture arch = Architecture.InceptionV3,
+ int epoch = 100,
+ int batchSize = 10,
+ float learningRate = 0.01f,
+ ImageClassificationMetricsCallback metricsCallback = null,
+ int statisticFrequency = 1,
+ DnnFramework framework = DnnFramework.Tensorflow,
+ string modelSavePath = null,
+ string finalModelPrefix = "custom_retrained_model_based_on_",
+ IDataView validationSet = null,
+ bool testOnTrainSet = true,
+ bool reuseTrainSetBottleneckCachedValues = false,
+ bool reuseValidationSetBottleneckCachedValues = false,
+ string trainSetBottleneckCachedValuesFilePath = "trainSetBottleneckFile.csv",
+ string validationSetBottleneckCachedValuesFilePath = "validationSetBottleneckFile.csv"
+ )
{
- var options = new Options()
+ var options = new ImageClassificationEstimator.Options()
{
ModelLocation = arch == Architecture.ResnetV2101 ? @"resnet_v2_101_299.meta" : @"InceptionV3.meta",
InputColumns = new[] { featuresColumnName },
@@ -121,13 +136,20 @@ public static DnnEstimator ImageClassification(
Epoch = epoch,
LearningRate = learningRate,
BatchSize = batchSize,
- AddBatchDimensionInputs = arch == Architecture.InceptionV3 ? false : true,
- TransferLearning = true,
ScoreColumnName = scoreColumnName,
PredictedLabelColumnName = predictedLabelColumnName,
- CheckpointName = checkpointName,
+ FinalModelPrefix = finalModelPrefix,
Arch = arch,
- MeasureTrainAccuracy = false
+ MetricsCallback = metricsCallback,
+ StatisticsFrequency = statisticFrequency,
+ Framework = framework,
+ ModelSavePath = modelSavePath,
+ ValidationSet = validationSet,
+ TestOnTrainSet = testOnTrainSet,
+ TrainSetBottleneckCachedValuesFilePath = trainSetBottleneckCachedValuesFilePath,
+ ValidationSetBottleneckCachedValuesFilePath = validationSetBottleneckCachedValuesFilePath,
+ ReuseTrainSetBottleneckCachedValues = reuseTrainSetBottleneckCachedValues,
+ ReuseValidationSetBottleneckCachedValues = reuseValidationSetBottleneckCachedValues
};
if (!File.Exists(options.ModelLocation))
@@ -158,7 +180,7 @@ public static DnnEstimator ImageClassification(
}
var env = CatalogUtils.GetEnvironment(catalog);
- return new DnnEstimator(env, options, DnnUtils.LoadDnnModel(env, options.ModelLocation, true));
+ return new ImageClassificationEstimator(env, options, DnnUtils.LoadDnnModel(env, options.ModelLocation, true));
}
}
}
diff --git a/src/Microsoft.ML.Dnn/DnnModel.cs b/src/Microsoft.ML.Dnn/DnnModel.cs
index 6f8c54edb7..a5324e9e39 100644
--- a/src/Microsoft.ML.Dnn/DnnModel.cs
+++ b/src/Microsoft.ML.Dnn/DnnModel.cs
@@ -4,14 +4,14 @@
using Microsoft.ML.Runtime;
using Tensorflow;
-using static Microsoft.ML.Transforms.DnnEstimator;
+using static Microsoft.ML.Transforms.DnnRetrainEstimator;
namespace Microsoft.ML.Transforms
{
///
/// This class holds the information related to TensorFlow model and session.
/// It provides some convenient methods to query model schema as well as
- /// creation of object.
+ /// creation of object.
///
public sealed class DnnModel
{
diff --git a/src/Microsoft.ML.Dnn/DnnTransform.cs b/src/Microsoft.ML.Dnn/DnnRetrainTransform.cs
similarity index 64%
rename from src/Microsoft.ML.Dnn/DnnTransform.cs
rename to src/Microsoft.ML.Dnn/DnnRetrainTransform.cs
index 56b3e13fd1..99cbd1afa5 100644
--- a/src/Microsoft.ML.Dnn/DnnTransform.cs
+++ b/src/Microsoft.ML.Dnn/DnnRetrainTransform.cs
@@ -8,7 +8,6 @@
using System.IO;
using System.Linq;
using System.Text;
-using Google.Protobuf;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
@@ -18,33 +17,29 @@
using Microsoft.ML.Transforms.Dnn;
using NumSharp;
using Tensorflow;
-using Tensorflow.Summaries;
using static Microsoft.ML.Transforms.Dnn.DnnUtils;
-using static Microsoft.ML.Transforms.DnnEstimator;
-using static Tensorflow.Python;
-[assembly: LoadableClass(DnnTransformer.Summary, typeof(IDataTransform), typeof(DnnTransformer),
- typeof(DnnEstimator.Options), typeof(SignatureDataTransform), DnnTransformer.UserName, DnnTransformer.ShortName)]
+[assembly: LoadableClass(DnnRetrainTransformer.Summary, typeof(IDataTransform), typeof(DnnRetrainTransformer),
+ typeof(DnnRetrainEstimator.Options), typeof(SignatureDataTransform), DnnRetrainTransformer.UserName, DnnRetrainTransformer.ShortName)]
-[assembly: LoadableClass(DnnTransformer.Summary, typeof(IDataTransform), typeof(DnnTransformer), null, typeof(SignatureLoadDataTransform),
- DnnTransformer.UserName, DnnTransformer.LoaderSignature)]
+[assembly: LoadableClass(DnnRetrainTransformer.Summary, typeof(IDataTransform), typeof(DnnRetrainTransformer), null, typeof(SignatureLoadDataTransform),
+ DnnRetrainTransformer.UserName, DnnRetrainTransformer.LoaderSignature)]
-[assembly: LoadableClass(typeof(DnnTransformer), null, typeof(SignatureLoadModel),
- DnnTransformer.UserName, DnnTransformer.LoaderSignature)]
+[assembly: LoadableClass(typeof(DnnRetrainTransformer), null, typeof(SignatureLoadModel),
+ DnnRetrainTransformer.UserName, DnnRetrainTransformer.LoaderSignature)]
-[assembly: LoadableClass(typeof(IRowMapper), typeof(DnnTransformer), null, typeof(SignatureLoadRowMapper),
- DnnTransformer.UserName, DnnTransformer.LoaderSignature)]
+[assembly: LoadableClass(typeof(IRowMapper), typeof(DnnRetrainTransformer), null, typeof(SignatureLoadRowMapper),
+ DnnRetrainTransformer.UserName, DnnRetrainTransformer.LoaderSignature)]
namespace Microsoft.ML.Transforms
{
///
- /// for the .
+ /// for the .
///
- public sealed class DnnTransformer : RowToRowTransformerBase
+ public sealed class DnnRetrainTransformer : RowToRowTransformerBase
{
private readonly IHostEnvironment _env;
private readonly string _modelLocation;
- private readonly bool _transferLearning;
private readonly bool _isTemporarySavedModel;
private readonly bool _addBatchDimensionInput;
private Session _session;
@@ -56,33 +51,15 @@ public sealed class DnnTransformer : RowToRowTransformerBase
private readonly (Operation, int)[] _tfOutputOperations;
private TF_Output[] _tfInputNodes;
private readonly TF_Output[] _tfOutputNodes;
- private Tensor _bottleneckTensor;
- private Operation _trainStep;
- private Tensor _softMaxTensor;
- private Tensor _crossEntropy;
- private Tensor _labelTensor;
- private Tensor _evaluationStep;
- private Tensor _prediction;
- private readonly int _classCount;
- private readonly string _checkpointPath;
- private readonly string _bottleneckOperationName;
private Graph Graph => _session.graph;
private readonly Dictionary _idvToTfMapping;
private readonly string[] _inputs;
private readonly string[] _outputs;
- private readonly string _labelColumnName;
- private readonly string _checkpointName;
- private readonly Architecture _arch;
- private readonly string _scoreColumnName;
- private readonly string _predictedLabelColumnName;
- private readonly float _learningRate;
- private readonly string _softmaxTensorName;
- private readonly string _predictionTensorName;
-
- internal const string Summary = "Trains Dnn models.";
- internal const string UserName = "DnnTransform";
- internal const string ShortName = "DnnTransform";
- internal const string LoaderSignature = "DnnTransform";
+
+ internal const string Summary = "Re-Trains Dnn models.";
+ internal const string UserName = "DnnRtTransform";
+ internal const string ShortName = "DnnRtTransform";
+ internal const string LoaderSignature = "DnnRtTransform";
internal static class DefaultModelFileNames
{
@@ -102,11 +79,11 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00000001,
verWeCanReadBack: 0x00000001,
loaderSignature: LoaderSignature,
- loaderAssemblyName: typeof(DnnTransformer).Assembly.FullName);
+ loaderAssemblyName: typeof(DnnRetrainTransformer).Assembly.FullName);
}
// Factory method for SignatureLoadModel.
- private static DnnTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
+ private static DnnRetrainTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
@@ -123,9 +100,7 @@ private static DnnTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
// int: id of output column name
// stream: tensorFlow model.
- GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput,
- out bool transferLearning, out string labelColumn, out string checkpointName, out Architecture arch, out string scoreColumnName,
- out string predictedColumnName, out float learningRate, out int classCount, out string predictionTensorName, out string softMaxTensorName);
+ GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput);
if (isFrozen)
{
@@ -133,12 +108,11 @@ private static DnnTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray()))
throw env.ExceptDecode();
- return new DnnTransformer(env, DnnUtils.LoadTFSession(env, modelBytes), outputs, inputs,
- null, false, addBatchDimensionInput, 1, transferLearning, labelColumn, checkpointName, arch,
- scoreColumnName, predictedColumnName, learningRate, null, classCount, true, predictionTensorName, softMaxTensorName);
+ return new DnnRetrainTransformer(env, DnnUtils.LoadTFSession(env, modelBytes), outputs, inputs,
+ null, false, addBatchDimensionInput, 1);
}
- var tempDirPath = Path.GetFullPath(Path.Combine(Path.GetTempPath(), nameof(DnnTransformer) + "_" + Guid.NewGuid()));
+ var tempDirPath = Path.GetFullPath(Path.Combine(Path.GetTempPath(), nameof(DnnRetrainTransformer) + "_" + Guid.NewGuid()));
DnnUtils.CreateFolderWithAclIfNotExists(env, tempDirPath);
try
{
@@ -164,9 +138,8 @@ private static DnnTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
}
});
- return new DnnTransformer(env, DnnUtils.GetSession(env, tempDirPath), outputs, inputs, tempDirPath, true,
- addBatchDimensionInput, 1, transferLearning, labelColumn, checkpointName, arch,
- scoreColumnName, predictedColumnName, learningRate, null, classCount, true, predictionTensorName, softMaxTensorName);
+ return new DnnRetrainTransformer(env, DnnUtils.GetSession(env, tempDirPath), outputs, inputs, tempDirPath, true,
+ addBatchDimensionInput, 1);
}
catch (Exception)
{
@@ -176,7 +149,7 @@ private static DnnTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
}
// Factory method for SignatureDataTransform.
- internal static IDataTransform Create(IHostEnvironment env, DnnEstimator.Options options, IDataView input)
+ internal static IDataTransform Create(IHostEnvironment env, DnnRetrainEstimator.Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(options, nameof(options));
@@ -184,33 +157,30 @@ internal static IDataTransform Create(IHostEnvironment env, DnnEstimator.Options
env.CheckValue(options.InputColumns, nameof(options.InputColumns));
env.CheckValue(options.OutputColumns, nameof(options.OutputColumns));
- return new DnnTransformer(env, options, input).MakeDataTransform(input);
+ return new DnnRetrainTransformer(env, options, input).MakeDataTransform(input);
}
- internal DnnTransformer(IHostEnvironment env, DnnEstimator.Options options, IDataView input)
+ internal DnnRetrainTransformer(IHostEnvironment env, DnnRetrainEstimator.Options options, IDataView input)
: this(env, options, DnnUtils.LoadDnnModel(env, options.ModelLocation), input)
{
}
- internal DnnTransformer(IHostEnvironment env, DnnEstimator.Options options, DnnModel tensorFlowModel, IDataView input, IDataView validationSet = null)
+ internal DnnRetrainTransformer(IHostEnvironment env, DnnRetrainEstimator.Options options, DnnModel tensorFlowModel, IDataView input, IDataView validationSet = null)
: this(env, tensorFlowModel.Session, options.OutputColumns, options.InputColumns,
- options.ModelLocation, false, options.AddBatchDimensionInputs, options.BatchSize, options.TransferLearning,
- options.LabelColumn, options.CheckpointName, options.Arch, options.ScoreColumnName,
- options.PredictedLabelColumnName, options.LearningRate, input.Schema)
+ options.ModelLocation, false, options.AddBatchDimensionInputs, options.BatchSize)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(options, nameof(options));
env.CheckValue(input, nameof(input));
- if (options.ReTrain)
- CheckTrainingParameters(options);
+ CheckTrainingParameters(options);
- if (options.ReTrain && !DnnUtils.IsSavedModel(env, options.ModelLocation))
+ if (!DnnUtils.IsSavedModel(env, options.ModelLocation))
throw env.ExceptNotSupp("TensorFlowTransform: Re-Training of TensorFlow model is only supported for un-frozen model.");
TrainCore(options, input, validationSet);
}
- private void CheckTrainingParameters(DnnEstimator.Options options)
+ private void CheckTrainingParameters(DnnRetrainEstimator.Options options)
{
Host.CheckNonWhiteSpace(options.LabelColumn, nameof(options.LabelColumn));
Host.CheckNonWhiteSpace(options.OptimizationOperation, nameof(options.OptimizationOperation));
@@ -296,7 +266,7 @@ private void CheckTrainingParameters(DnnEstimator.Options options)
return (inputColIndex, isInputVector, tfInputType, tfInputShape);
}
- private void TrainCore(DnnEstimator.Options options, IDataView input, IDataView validationSet)
+ private void TrainCore(DnnRetrainEstimator.Options options, IDataView input, IDataView validationSet)
{
var inputsForTraining = new string[_inputs.Length + 1];
var inputColIndices = new int[inputsForTraining.Length];
@@ -313,10 +283,7 @@ private void TrainCore(DnnEstimator.Options options, IDataView input, IDataView
GetTrainingInputInfo(inputSchema, _inputs[i], inputsForTraining[i], options.BatchSize);
var index = inputsForTraining.Length - 1;
- if (options.TransferLearning)
- inputsForTraining[index] = _labelTensor.name.Split(':').First();
- else
- inputsForTraining[index] = options.TensorFlowLabel;
+ inputsForTraining[index] = options.TensorFlowLabel;
(inputColIndices[index], isInputVector[index], tfInputTypes[index], tfInputShapes[index]) =
GetTrainingInputInfo(inputSchema, options.LabelColumn, inputsForTraining[index], options.BatchSize);
@@ -324,14 +291,9 @@ private void TrainCore(DnnEstimator.Options options, IDataView input, IDataView
// Create graph inputs.
Operation labelOp;
int labelOpIdx;
- if (options.ReTrain)
- (labelOp, labelOpIdx) = GetOperationFromName(options.TensorFlowLabel, _session);
- else
- (labelOp, labelOpIdx) = GetOperationFromName(_labelTensor.name, _session);
-
+ (labelOp, labelOpIdx) = GetOperationFromName(options.TensorFlowLabel, _session);
TF_Output[] tfInputs;
-
- if (options.ReTrain && !string.IsNullOrEmpty(options.LearningRateOperation))
+ if (!string.IsNullOrEmpty(options.LearningRateOperation))
tfInputs = new TF_Output[_tfInputNodes.Length + 2]; //Inputs + Label + Learning Rate.
else
tfInputs = new TF_Output[_tfInputNodes.Length + 1]; //Inputs + Label.
@@ -339,32 +301,18 @@ private void TrainCore(DnnEstimator.Options options, IDataView input, IDataView
Array.Copy(_tfInputNodes, tfInputs, _tfInputNodes.Length);
tfInputs[_tfInputNodes.Length] = new TF_Output(labelOp, labelOpIdx);
-
- if (options.ReTrain)
- {
- var lr = GetOperationFromName(options.LearningRateOperation, _session);
- tfInputs[_tfInputNodes.Length + 1] = new TF_Output(lr.Item1, lr.Item2);
- }
+ // tfInputs is allocated with a learning-rate slot only when LearningRateOperation is non-empty,
+ // so the slot may only be filled under that same condition; otherwise the index is out of range.
+ if (!string.IsNullOrEmpty(options.LearningRateOperation))
+ {
+ var lr = GetOperationFromName(options.LearningRateOperation, _session);
+ tfInputs[_tfInputNodes.Length + 1] = new TF_Output(lr.Item1, lr.Item2);
+ }
// Create graph operations.
IntPtr[] ops = null;
- if (options.ReTrain && options.OptimizationOperation != null)
+ if (options.OptimizationOperation != null)
ops = new[] { c_api.TF_GraphOperationByName(Graph, options.OptimizationOperation) };
- else
- ops = new[] { (IntPtr)_trainStep };
-
- Saver trainSaver = null;
- FileWriter trainWriter = null;
- Tensor merged = null;
- Runner testSetRunner = null;
- Runner validationSetRunner = null;
- if (options.TransferLearning)
- {
- merged = tf.summary.merge_all();
- trainWriter = tf.summary.FileWriter(Path.Combine(Directory.GetCurrentDirectory(), "train"), _session.graph);
- trainSaver = tf.train.Saver();
- trainSaver.save(_session, _checkpointPath);
- }
// Instantiate the graph.
Runner runner;
@@ -379,190 +322,53 @@ private void TrainCore(DnnEstimator.Options options, IDataView input, IDataView
using (var ch = Host.Start("Training TensorFlow model..."))
using (var pch = Host.StartProgressChannel("TensorFlow training progress..."))
{
- if (options.ReTrain)
- {
- float loss = 0;
- float metric = 0;
- pch.SetHeader(new ProgressHeader(new[] { "Loss", "Metric" }, new[] { "Epoch" }), (e) => e.SetProgress(0, epoch, options.Epoch));
+ float loss = 0;
+ float metric = 0;
+ pch.SetHeader(new ProgressHeader(new[] { "Loss", "Metric" }, new[] { "Epoch" }), (e) => e.SetProgress(0, epoch, options.Epoch));
- while (cursor.MoveNext())
- {
- for (int i = 0; i < inputsForTraining.Length; i++)
- {
- isDataLeft = true;
- srcTensorGetters[i].BufferTrainingData();
- }
-
- if (((cursor.Position + 1) % options.BatchSize) == 0)
- {
- isDataLeft = false;
- runner = new Runner(_session);
-
- // Add Learning Rate.
- if (!string.IsNullOrEmpty(options.LearningRateOperation))
- runner.AddInput(options.LearningRateOperation, new Tensor(options.LearningRate));
-
- // Add operations.
- if (!string.IsNullOrEmpty(options.OptimizationOperation))
- runner.AddOperation(options.OptimizationOperation);
-
- // Add outputs.
- if (options.LossOperation != null)
- runner.AddOutputs(options.LossOperation);
- if (options.MetricOperation != null)
- runner.AddOutputs(options.MetricOperation);
-
- var (l, m) = ExecuteGraphAndRetrieveMetrics(inputsForTraining, srcTensorGetters, runner);
- loss += l;
- metric += m;
- }
- }
- if (isDataLeft)
- {
- isDataLeft = false;
- ch.Warning("Not training on the last batch. The batch size is less than {0}.", options.BatchSize);
- }
- pch.Checkpoint(new double?[] { loss, metric });
- }
- else
+ while (cursor.MoveNext())
{
- pch.SetHeader(new ProgressHeader(null, new[] { "Epoch" }), (e) => e.SetProgress(0, epoch, options.Epoch));
-
- while (cursor.MoveNext())
+ for (int i = 0; i < inputsForTraining.Length; i++)
{
- for (int i = 0; i < inputsForTraining.Length; i++)
- {
- isDataLeft = true;
- srcTensorGetters[i].BufferTrainingData();
- }
-
- if (((cursor.Position + 1) % options.BatchSize) == 0)
- {
- isDataLeft = false;
- runner = new Runner(_session);
-
- // Add operations.
- runner.AddOperation(_trainStep);
-
- // Feed inputs.
- for (int i = 0; i < inputsForTraining.Length; i++)
- runner.AddInput(inputsForTraining[i], srcTensorGetters[i].GetBufferedBatchTensor());
-
- // Execute the graph.
- var t = runner.Run();
- }
+ isDataLeft = true;
+ srcTensorGetters[i].BufferTrainingData();
}
- if (isDataLeft)
+ if (((cursor.Position + 1) % options.BatchSize) == 0)
{
isDataLeft = false;
- ch.Warning("Not training on the last batch. The batch size is less than {0}.", options.BatchSize);
- }
- }
- }
- }
+ runner = new Runner(_session);
- // Measure accuracy of the model.
- if (options.TransferLearning && options.MeasureTrainAccuracy)
- {
- // Test on the training set to get accuracy.
- using (var cursor = input.GetRowCursor(cols))
- {
- var srcTensorGetters = GetTensorValueGetters(cursor, inputColIndices, isInputVector, tfInputTypes, tfInputShapes);
-
- float accuracy = 0;
- float crossEntropy = 0;
- bool isDataLeft = false;
- int batch = 0;
- using (var ch = Host.Start("Test TensorFlow model..."))
- using (var pch = Host.StartProgressChannel("TensorFlow testing progress..."))
- {
- pch.SetHeader(new ProgressHeader(new[] { "Accuracy", "Cross Entropy" }, new[] { "Epoch" }), (e) => e.SetProgress(0, epoch, options.Epoch));
+ // Add Learning Rate.
+ if (!string.IsNullOrEmpty(options.LearningRateOperation))
+ runner.AddInput(options.LearningRateOperation, new Tensor(options.LearningRate));
- while (cursor.MoveNext())
- {
- for (int i = 0; i < inputColIndices.Length; i++)
- {
- isDataLeft = true;
- srcTensorGetters[i].BufferTrainingData();
- }
-
- if (((cursor.Position + 1) % options.BatchSize) == 0)
- {
- isDataLeft = false;
- testSetRunner = new Runner(_session);
- testSetRunner.AddOutputs(_evaluationStep.name);
- testSetRunner.AddOutputs(_crossEntropy.name);
- testSetRunner.AddOutputs(_bottleneckTensor.name);
- var (acc, ce) = ExecuteGraphAndRetrieveMetrics(inputsForTraining, srcTensorGetters, testSetRunner);
- accuracy += acc;
- crossEntropy += ce;
- batch++;
- }
- }
+ // Add operations.
+ if (!string.IsNullOrEmpty(options.OptimizationOperation))
+ runner.AddOperation(options.OptimizationOperation);
- if (isDataLeft)
- {
- isDataLeft = false;
- ch.Warning("Not training on the last batch. The batch size is less than {0}.", options.BatchSize);
+ // Add outputs.
+ if (options.LossOperation != null)
+ runner.AddOutputs(options.LossOperation);
+ if (options.MetricOperation != null)
+ runner.AddOutputs(options.MetricOperation);
+
+ var (l, m) = ExecuteGraphAndRetrieveMetrics(inputsForTraining, srcTensorGetters, runner);
+ loss += l;
+ metric += m;
}
- pch.Checkpoint(new double?[] { accuracy / batch, crossEntropy / batch });
- ch.Info(MessageSensitivity.None, $"Accuracy: {accuracy / batch}, Cross-Entropy: {crossEntropy / batch}");
}
- }
-
- // Test on the validation set.
- if (validationSet != null)
- {
- using (var cursor = validationSet.GetRowCursor(cols))
+ if (isDataLeft)
{
- var srcTensorGetters = GetTensorValueGetters(cursor, inputColIndices, isInputVector, tfInputTypes, tfInputShapes);
-
- float accuracy = 0;
- bool isDataLeft = false;
- int batch = 0;
- using (var ch = Host.Start("Test TensorFlow model with validation set..."))
- using (var pch = Host.StartProgressChannel("TensorFlow validation progress..."))
- {
- pch.SetHeader(new ProgressHeader(new[] { "Accuracy" }, new[] { "Epoch" }), (e) => e.SetProgress(0, epoch, options.Epoch));
-
- while (cursor.MoveNext())
- {
- for (int i = 0; i < inputColIndices.Length; i++)
- {
- isDataLeft = true;
- srcTensorGetters[i].BufferTrainingData();
- }
-
- if (((cursor.Position + 1) % options.BatchSize) == 0)
- {
- isDataLeft = false;
- validationSetRunner = new Runner(_session);
- validationSetRunner.AddOutputs(_evaluationStep.name);
- var (acc, _) = ExecuteGraphAndRetrieveMetrics(inputsForTraining, srcTensorGetters, validationSetRunner);
- accuracy += acc;
- batch++;
- }
- }
- if (isDataLeft)
- {
- isDataLeft = false;
- ch.Warning("Not training on the last batch. The batch size is less than {0}.", options.BatchSize);
- }
- pch.Checkpoint(new double?[] { accuracy / batch });
- }
+ isDataLeft = false;
+ ch.Warning("Not training on the last batch. The batch size is less than {0}.", options.BatchSize);
}
+ pch.Checkpoint(new double?[] { loss, metric });
}
}
}
- if (options.ReTrain)
- UpdateModelOnDisk(options.ModelLocation, options);
- else
- {
- trainSaver.save(_session, _checkpointPath);
- UpdateTransferLearningModelOnDisk(options, _classCount);
- }
+ UpdateModelOnDisk(options.ModelLocation, options);
}
private (float loss, float metric) ExecuteGraphAndRetrieveMetrics(
@@ -588,7 +394,7 @@ private void TrainCore(DnnEstimator.Options options, IDataView input, IDataView
/// After retraining Session and Graphs are both up-to-date
/// However model on disk is not which is used to serialzed to ML.Net stream
///
- private void UpdateModelOnDisk(string modelDir, DnnEstimator.Options options)
+ private void UpdateModelOnDisk(string modelDir, DnnRetrainEstimator.Options options)
{
try
{
@@ -648,150 +454,6 @@ private void UpdateModelOnDisk(string modelDir, DnnEstimator.Options options)
}
}
- private (Session, Tensor, Tensor, Tensor) BuildEvaluationSession(DnnEstimator.Options options, int classCount)
- {
- var evalGraph = DnnUtils.LoadMetaGraph(options.ModelLocation);
- var evalSess = tf.Session(graph: evalGraph);
- Tensor evaluationStep = null;
- Tensor prediction = null;
- Tensor bottleneckTensor = evalGraph.OperationByName(_bottleneckOperationName);
-
- tf_with(evalGraph.as_default(), graph =>
- {
- var (_, _, groundTruthInput, finalTensor) = AddFinalRetrainOps(classCount, options.LabelColumn,
- options.ScoreColumnName, options.LearningRate, bottleneckTensor, false);
-
- tf.train.Saver().restore(evalSess, Path.Combine(Directory.GetCurrentDirectory(), _checkpointPath));
- (evaluationStep, prediction) = AddEvaluationStep(finalTensor, groundTruthInput);
- });
-
- return (evalSess, _labelTensor, evaluationStep, prediction);
- }
-
- private (Tensor, Tensor) AddEvaluationStep(Tensor resultTensor, Tensor groundTruthTensor)
- {
- Tensor evaluationStep = null;
- Tensor correctPrediction = null;
-
- tf_with(tf.name_scope("accuracy"), scope =>
- {
- tf_with(tf.name_scope("correct_prediction"), delegate
- {
- _prediction = tf.argmax(resultTensor, 1);
- correctPrediction = tf.equal(_prediction, groundTruthTensor);
- });
-
- tf_with(tf.name_scope("accuracy"), delegate
- {
- evaluationStep = tf.reduce_mean(tf.cast(correctPrediction, tf.float32));
- });
- });
-
- tf.summary.scalar("accuracy", evaluationStep);
- return (evaluationStep, _prediction);
- }
-
- private void UpdateTransferLearningModelOnDisk(DnnEstimator.Options options, int classCount)
- {
- var (sess, _, _, _) = BuildEvaluationSession(options, classCount);
- var graph = sess.graph;
- var outputGraphDef = tf.graph_util.convert_variables_to_constants(
- sess, graph.as_graph_def(), new string[] { _softMaxTensor.name.Split(':')[0], _prediction.name.Split(':')[0] });
-
- string frozenModelPath = _checkpointPath + ".pb";
- File.WriteAllBytes(_checkpointPath + ".pb", outputGraphDef.ToByteArray());
- _session = LoadTFSessionByModelFilePath(_env, frozenModelPath, false);
- }
-
- private void VariableSummaries(RefVariable var)
- {
- tf_with(tf.name_scope("summaries"), delegate
- {
- var mean = tf.reduce_mean(var);
- tf.summary.scalar("mean", mean);
- Tensor stddev = null;
- tf_with(tf.name_scope("stddev"), delegate
- {
- stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)));
- });
- tf.summary.scalar("stddev", stddev);
- tf.summary.scalar("max", tf.reduce_max(var));
- tf.summary.scalar("min", tf.reduce_min(var));
- tf.summary.histogram("histogram", var);
- });
- }
-
- private (Operation, Tensor, Tensor, Tensor) AddFinalRetrainOps(int classCount, string labelColumn,
- string scoreColumnName, float learningRate, Tensor bottleneckTensor, bool isTraining)
- {
- var (batch_size, bottleneck_tensor_size) = (bottleneckTensor.TensorShape.Dimensions[0], bottleneckTensor.TensorShape.Dimensions[1]);
- tf_with(tf.name_scope("input"), scope =>
- {
- _labelTensor = tf.placeholder(tf.int64, new TensorShape(batch_size), name: labelColumn);
- });
-
- string layerName = "final_retrain_ops";
- Tensor logits = null;
- tf_with(tf.name_scope(layerName), scope =>
- {
- RefVariable layerWeights = null;
- tf_with(tf.name_scope("weights"), delegate
- {
- var initialValue = tf.truncated_normal(new int[] { bottleneck_tensor_size, classCount }, stddev: 0.001f);
- layerWeights = tf.Variable(initialValue, name: "final_weights");
- VariableSummaries(layerWeights);
- });
-
- RefVariable layerBiases = null;
- tf_with(tf.name_scope("biases"), delegate
- {
- layerBiases = tf.Variable(tf.zeros(classCount), name: "final_biases");
- VariableSummaries(layerBiases);
- });
-
- tf_with(tf.name_scope("Wx_plus_b"), delegate
- {
- var matmul = tf.matmul(bottleneckTensor, layerWeights);
- logits = matmul + layerBiases;
- tf.summary.histogram("pre_activations", logits);
- });
- });
-
- _softMaxTensor = tf.nn.softmax(logits, name: scoreColumnName);
-
- tf.summary.histogram("activations", _softMaxTensor);
- if (!isTraining)
- return (null, null, _labelTensor, _softMaxTensor);
-
- Tensor crossEntropyMean = null;
- tf_with(tf.name_scope("cross_entropy"), delegate
- {
- crossEntropyMean = tf.losses.sparse_softmax_cross_entropy(
- labels: _labelTensor, logits: logits);
- });
-
- tf.summary.scalar("cross_entropy", crossEntropyMean);
-
- tf_with(tf.name_scope("train"), delegate
- {
- var optimizer = tf.train.GradientDescentOptimizer(learningRate);
- _trainStep = optimizer.minimize(crossEntropyMean);
- });
-
- return (_trainStep, crossEntropyMean, _labelTensor, _softMaxTensor);
- }
-
- private void AddTransferLearningLayer(string labelColumn,
- string scoreColumnName, float learningRate, int classCount)
- {
- _bottleneckTensor = Graph.OperationByName(_bottleneckOperationName);
- tf_with(Graph.as_default(), delegate
- {
- (_trainStep, _crossEntropy, _labelTensor, _softMaxTensor) =
- AddFinalRetrainOps(classCount, labelColumn, scoreColumnName, learningRate, _bottleneckTensor, true);
- });
- }
-
private static ITensorValueGetter CreateTensorValueGetter(DataViewRow input, bool isVector, int colIndex, TensorShape tfShape, bool keyType = false)
{
if (isVector)
@@ -833,9 +495,7 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Dat
=> Create(env, ctx).MakeRowMapper(inputSchema);
private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out string[] inputs,
- out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput, out bool transferLearning,
- out string labelColumn, out string checkpointName, out Architecture arch,
- out string scoreColumnName, out string predictedColumnName, out float learningRate, out int classCount, out string predictionTensorName, out string softMaxTensorName)
+ out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput)
{
isFrozen = ctx.Reader.ReadBoolByte();
addBatchDimensionInput = ctx.Reader.ReadBoolByte();
@@ -851,26 +511,12 @@ private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out
outputs = new string[numOutputs];
for (int j = 0; j < outputs.Length; j++)
outputs[j] = ctx.LoadNonEmptyString();
-
- transferLearning = ctx.Reader.ReadBoolean();
- labelColumn = ctx.Reader.ReadString();
- checkpointName = ctx.Reader.ReadString();
- arch = (Architecture)ctx.Reader.ReadInt32();
- scoreColumnName = ctx.Reader.ReadString();
- predictedColumnName = ctx.Reader.ReadString();
- learningRate = ctx.Reader.ReadFloat();
- classCount = ctx.Reader.ReadInt32();
- predictionTensorName = ctx.Reader.ReadString();
- softMaxTensorName = ctx.Reader.ReadString();
-
}
- internal DnnTransformer(IHostEnvironment env, Session session, string[] outputColumnNames,
+ internal DnnRetrainTransformer(IHostEnvironment env, Session session, string[] outputColumnNames,
string[] inputColumnNames, string modelLocation, bool isTemporarySavedModel,
- bool addBatchDimensionInput, int batchSize, bool transferLearning, string labelColumnName, string checkpointName, Architecture arch,
- string scoreColumnName, string predictedLabelColumnName, float learningRate, DataViewSchema inputSchema, int? classCount = null, bool loadModel = false,
- string predictionTensorName = null, string softMaxTensorName = null)
- : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(DnnTransformer)))
+ bool addBatchDimensionInput, int batchSize)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(DnnRetrainTransformer)))
{
Host.CheckValue(session, nameof(session));
@@ -885,76 +531,15 @@ internal DnnTransformer(IHostEnvironment env, Session session, string[] outputCo
_inputs = inputColumnNames;
_outputs = outputColumnNames;
_idvToTfMapping = new Dictionary();
- _transferLearning = transferLearning;
- _labelColumnName = labelColumnName;
- _checkpointName = checkpointName;
- _arch = arch;
- _scoreColumnName = scoreColumnName;
- _predictedLabelColumnName = predictedLabelColumnName;
- _learningRate = learningRate;
- _softmaxTensorName = softMaxTensorName;
- _predictionTensorName = predictionTensorName;
- if (transferLearning)
- {
- if (classCount == null)
- {
- var labelColumn = inputSchema.GetColumnOrNull(labelColumnName).Value;
- var labelType = labelColumn.Type;
- var labelCount = labelType.GetKeyCount();
- if (labelCount <= 0)
- throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", (string)labelColumn.Name, "Key", (string)labelType.ToString());
-
- _classCount = labelCount == 1 ? 2 : (int)labelCount;
- }
- else
- _classCount = classCount.Value;
- _checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), modelLocation + checkpointName);
+ foreach (var x in _inputs)
+ _idvToTfMapping[x] = x;
- // Configure bottleneck tensor based on the model.
- if (arch == DnnEstimator.Architecture.ResnetV2101)
- _bottleneckOperationName = "resnet_v2_101/SpatialSqueeze";
- else if(arch == DnnEstimator.Architecture.InceptionV3)
- _bottleneckOperationName = "module_apply_default/hub_output/feature_vector/SpatialSqueeze";
+ foreach (var x in _outputs)
+ _idvToTfMapping[x] = x;
- if (arch == DnnEstimator.Architecture.ResnetV2101)
- _idvToTfMapping[_inputs[0]] = "input";
- else if (arch == DnnEstimator.Architecture.InceptionV3)
- _idvToTfMapping[_inputs[0]] = "Placeholder";
+ (_tfOutputTypes, _outputTypes, _tfOutputOperations) = GetOutputInfo(Host, _session, _outputs);
- _outputs = new[] { scoreColumnName, predictedLabelColumnName };
-
- if (loadModel == false)
- {
- // Add transfer learning layer.
- AddTransferLearningLayer(labelColumnName, scoreColumnName, learningRate, _classCount);
-
- // Initialize the variables.
- new Runner(_session).AddOperation(tf.global_variables_initializer()).Run();
-
- // Add evaluation layer.
- (_evaluationStep, _) = AddEvaluationStep(_softMaxTensor, _labelTensor);
- _softmaxTensorName = _softMaxTensor.name;
- _predictionTensorName = _prediction.name;
- }
-
- _idvToTfMapping[scoreColumnName] = _softmaxTensorName;
- _idvToTfMapping[predictedLabelColumnName] = _predictionTensorName;
-
- (_tfOutputTypes, _outputTypes, _tfOutputOperations) = GetOutputInfo(Host, _session, new[] { _softmaxTensorName, _predictionTensorName });
- _transferLearning = true;
- }
- else
- {
- foreach (var x in _inputs)
- _idvToTfMapping[x] = x;
-
- foreach (var x in _outputs)
- _idvToTfMapping[x] = x;
-
- (_tfOutputTypes, _outputTypes, _tfOutputOperations) = GetOutputInfo(Host, _session, _outputs);
-
- }
(_tfInputTypes, _tfInputShapes, _tfInputOperations) = GetInputInfo(Host, _session, _inputs.Select(x => _idvToTfMapping[x]).ToArray(), batchSize);
_tfInputNodes = new TF_Output[_inputs.Length];
@@ -1093,7 +678,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
// for each output column
// int: id of output column name
// stream: tensorFlow model.
- var isFrozen = _transferLearning || DnnUtils.IsSavedModel(_env, _modelLocation);
+ var isFrozen = DnnUtils.IsSavedModel(_env, _modelLocation);
ctx.Writer.WriteBoolByte(isFrozen);
ctx.Writer.WriteBoolByte(_addBatchDimensionInput);
@@ -1107,58 +692,35 @@ private protected override void SaveModel(ModelSaveContext ctx)
foreach (var colName in _outputs)
ctx.SaveNonEmptyString(colName);
- ctx.Writer.Write(_transferLearning);
- ctx.Writer.Write(_labelColumnName);
- ctx.Writer.Write(_checkpointName);
- ctx.Writer.Write((int)_arch);
- ctx.Writer.Write(_scoreColumnName);
- ctx.Writer.Write(_predictedLabelColumnName);
- ctx.Writer.Write(_learningRate);
- ctx.Writer.Write(_classCount);
- ctx.Writer.Write(_predictionTensorName);
- ctx.Writer.Write(_softmaxTensorName);
-
- if (isFrozen || _transferLearning)
- {
- Status status = new Status();
- var buffer = _session.graph.ToGraphDef(status);
- ctx.SaveBinaryStream("TFModel", w =>
- {
- w.WriteByteArray(buffer.Data);
- });
- }
- else
+ ctx.SaveBinaryStream("TFSavedModel", w =>
{
- ctx.SaveBinaryStream("TFSavedModel", w =>
+ // only these files need to be saved.
+ string[] modelFilePaths =
{
- // only these files need to be saved.
- string[] modelFilePaths =
- {
- Path.Combine(_modelLocation, DefaultModelFileNames.Graph),
- Path.Combine(_modelLocation, DefaultModelFileNames.VariablesFolder, DefaultModelFileNames.Data),
- Path.Combine(_modelLocation, DefaultModelFileNames.VariablesFolder, DefaultModelFileNames.Index),
- };
+ Path.Combine(_modelLocation, DefaultModelFileNames.Graph),
+ Path.Combine(_modelLocation, DefaultModelFileNames.VariablesFolder, DefaultModelFileNames.Data),
+ Path.Combine(_modelLocation, DefaultModelFileNames.VariablesFolder, DefaultModelFileNames.Index),
+ };
- w.Write(modelFilePaths.Length);
+ w.Write(modelFilePaths.Length);
- foreach (var fullPath in modelFilePaths)
- {
- var relativePath = fullPath.Substring(_modelLocation.Length + 1);
- w.Write(relativePath);
+ foreach (var fullPath in modelFilePaths)
+ {
+ var relativePath = fullPath.Substring(_modelLocation.Length + 1);
+ w.Write(relativePath);
- using (var fs = new FileStream(fullPath, FileMode.Open))
- {
- long fileLength = fs.Length;
- w.Write(fileLength);
- long actualWritten = fs.CopyRange(w.BaseStream, fileLength);
- Host.Assert(actualWritten == fileLength);
- }
+ using (var fs = new FileStream(fullPath, FileMode.Open))
+ {
+ long fileLength = fs.Length;
+ w.Write(fileLength);
+ long actualWritten = fs.CopyRange(w.BaseStream, fileLength);
+ Host.Assert(actualWritten == fileLength);
}
- });
- }
+ }
+ });
}
- ~DnnTransformer()
+ ~DnnRetrainTransformer()
{
Dispose(false);
}
@@ -1187,13 +749,13 @@ private void Dispose(bool disposing)
private sealed class Mapper : MapperBase
{
- private readonly DnnTransformer _parent;
+ private readonly DnnRetrainTransformer _parent;
private readonly int[] _inputColIndices;
private readonly bool[] _isInputVector;
private readonly TensorShape[] _fullySpecifiedShapes;
private readonly ConcurrentBag _runners;
- public Mapper(DnnTransformer parent, DataViewSchema inputSchema) :
+ public Mapper(DnnRetrainTransformer parent, DataViewSchema inputSchema) :
base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent)
{
Host.CheckValue(parent, nameof(parent));
@@ -1612,28 +1174,11 @@ public Tensor GetBufferedBatchTensor()
}
}
- ///
- public sealed class DnnEstimator : IEstimator
+ ///
+ public sealed class DnnRetrainEstimator : IEstimator
{
///
- /// Image classification model.
- ///
- public enum Architecture
- {
- ResnetV2101,
- InceptionV3
- };
-
- ///
- /// Backend DNN training framework.
- ///
- public enum DnnFramework
- {
- Tensorflow
- };
-
- ///
- /// The options for the .
+ /// The options for the <see cref="DnnRetrainEstimator"/>.
///
internal sealed class Options : TransformInputBase
{
@@ -1729,12 +1274,6 @@ internal sealed class Options : TransformInputBase
[Argument(ArgumentType.AtMostOnce, HelpText = "Name of the input in TensorFlow graph that specifiy the location for saving/restoring models from disk.", SortOrder = 14)]
public string SaveOperation = "save/control_dependency";
- ///
- /// Needed for command line to specify if retraining is requested.
- ///
- [Argument(ArgumentType.AtMostOnce, HelpText = "Retrain TensorFlow model.", SortOrder = 15)]
- public bool ReTrain = false;
-
///
/// Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].
///
@@ -1744,42 +1283,6 @@ internal sealed class Options : TransformInputBase
///
[Argument(ArgumentType.AtMostOnce, HelpText = "Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].", SortOrder = 16)]
public bool AddBatchDimensionInputs = false;
-
- ///
- /// Indicates if transfer learning is requested.
- ///
- [Argument(ArgumentType.AtMostOnce, HelpText = "Transfer learning on a model.", SortOrder = 15)]
- public bool TransferLearning = false;
-
- ///
- /// Specifies the model architecture to be used in the case of image classification training using transfer learning.
- ///
- [Argument(ArgumentType.AtMostOnce, HelpText = "Model architecture to be used in transfer learning for image classification.", SortOrder = 15)]
- public Architecture Arch = Architecture.ResnetV2101;
-
- ///
- /// Name of the tensor that will contain the output scores of the last layer when transfer learning is done.
- ///
- [Argument(ArgumentType.AtMostOnce, HelpText = "Softmax tensor of the last layer in transfer learning.", SortOrder = 15)]
- public string ScoreColumnName = "Scores";
-
- ///
- /// Name of the tensor that will contain the predicted label from output scores of the last layer when transfer learning is done.
- ///
- [Argument(ArgumentType.AtMostOnce, HelpText = "Argmax tensor of the last layer in transfer learning.", SortOrder = 15)]
- public string PredictedLabelColumnName = "PredictedLabel";
-
- ///
- /// Checkpoint folder to store graph files in the event of transfer learning.
- ///
- [Argument(ArgumentType.AtMostOnce, HelpText = "Checkpoint folder to store graph files in the event of transfer learning.", SortOrder = 15)]
- public string CheckpointName = "_retrain_checkpoint";
-
- ///
- /// Use train set to measure model accuracy between each epoch.
- ///
- [Argument(ArgumentType.AtMostOnce, HelpText = "Use train set to measure model accuracy between each epoch.", SortOrder = 15)]
- public bool MeasureTrainAccuracy = false;
}
private readonly IHost _host;
@@ -1787,25 +1290,16 @@ internal sealed class Options : TransformInputBase
private readonly DnnModel _tensorFlowModel;
private readonly TF_DataType[] _tfInputTypes;
private readonly DataViewType[] _outputTypes;
- private DnnTransformer _transformer;
+ private DnnRetrainTransformer _transformer;
- internal DnnEstimator(IHostEnvironment env, Options options, DnnModel tensorFlowModel)
+ internal DnnRetrainEstimator(IHostEnvironment env, Options options, DnnModel tensorFlowModel)
{
- _host = Contracts.CheckRef(env, nameof(env)).Register(nameof(DnnEstimator));
+ _host = Contracts.CheckRef(env, nameof(env)).Register(nameof(DnnRetrainEstimator));
_options = options;
_tensorFlowModel = tensorFlowModel;
-
- if (options.TransferLearning)
- _tfInputTypes = new[] { TF_DataType.TF_FLOAT };
- else
- {
- var inputTuple = DnnTransformer.GetInputInfo(_host, tensorFlowModel.Session, options.InputColumns);
- _tfInputTypes = inputTuple.tfInputTypes;
- }
- if (options.TransferLearning)
- _outputTypes = new[] { new VectorDataViewType(NumberDataViewType.Single), new VectorDataViewType(NumberDataViewType.Single, 1) };
- else
- _outputTypes = DnnTransformer.GetOutputInfo(_host, tensorFlowModel.Session, options.OutputColumns).outputTypes;
+ var inputTuple = DnnRetrainTransformer.GetInputInfo(_host, tensorFlowModel.Session, options.InputColumns);
+ _tfInputTypes = inputTuple.tfInputTypes;
+ _outputTypes = DnnRetrainTransformer.GetOutputInfo(_host, tensorFlowModel.Session, options.OutputColumns).outputTypes;
}
private static Options CreateArguments(DnnModel tensorFlowModel, string[] outputColumnNames, string[] inputColumnName, bool addBatchDimensionInput)
@@ -1814,7 +1308,6 @@ private static Options CreateArguments(DnnModel tensorFlowModel, string[] output
options.ModelLocation = tensorFlowModel.ModelPath;
options.InputColumns = inputColumnName;
options.OutputColumns = outputColumnNames;
- options.ReTrain = false;
options.AddBatchDimensionInputs = addBatchDimensionInput;
return options;
}
@@ -1849,13 +1342,13 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
}
///
- /// Trains and returns a .
+ /// Trains and returns a <see cref="DnnRetrainTransformer"/>.
///
- public DnnTransformer Fit(IDataView input)
+ public DnnRetrainTransformer Fit(IDataView input)
{
_host.CheckValue(input, nameof(input));
if (_transformer == null)
- _transformer = new DnnTransformer(_host, _options, _tensorFlowModel, input);
+ _transformer = new DnnRetrainTransformer(_host, _options, _tensorFlowModel, input);
// Validate input schema.
_transformer.GetOutputSchema(input.Schema);
diff --git a/src/Microsoft.ML.Dnn/DnnUtils.cs b/src/Microsoft.ML.Dnn/DnnUtils.cs
index 3aa727f41d..623e103997 100644
--- a/src/Microsoft.ML.Dnn/DnnUtils.cs
+++ b/src/Microsoft.ML.Dnn/DnnUtils.cs
@@ -381,6 +381,45 @@ public Runner AddInput(string input, Tensor value)
return this;
}
+ /// <summary>
+ /// Registers a graph input by name without supplying a value for it.
+ /// </summary>
+ /// <param name="input">Input name; it is parsed with <c>ParseOutput</c> and appended to the input list.</param>
+ /// <returns>This runner, so calls can be chained.</returns>
+ public Runner AddInput(string input)
+ {
+ _inputs.Add(ParseOutput(input));
+ return this;
+ }
+
+ /// <summary>
+ /// Stores <paramref name="value"/> as the input value at position <paramref name="index"/>.
+ /// The value is appended when <paramref name="index"/> is at or past the end of the value list;
+ /// otherwise it replaces the value at that position and the replaced tensor is disposed.
+ /// </summary>
+ /// <param name="value">Tensor to store.</param>
+ /// <param name="index">Position in the value list.</param>
+ /// <returns>This runner, so calls can be chained.</returns>
+ public Runner AddInput(Tensor value, int index)
+ {
+ if (_inputValues.Count <= index)
+ {
+ _inputValues.Add(value);
+ return this;
+ }
+
+ // Passing back the tensor that already occupies the slot must not dispose it,
+ // otherwise the runner would keep a disposed tensor as its input value.
+ var previous = _inputValues[index];
+ if (!ReferenceEquals(previous, value))
+ {
+ previous?.Dispose();
+ _inputValues[index] = value;
+ }
+
+ return this;
+ }
+
public Runner AddOutputs(string output)
{
_outputs.Add(ParseOutput(output));
@@ -444,10 +463,6 @@ public Tensor[] Run()
return result;
}
- public Runner CloneRunner()
- {
- return new Runner(_session);
- }
}
}
diff --git a/src/Microsoft.ML.Dnn/ImageClassificationTransform.cs b/src/Microsoft.ML.Dnn/ImageClassificationTransform.cs
new file mode 100644
index 0000000000..4729731972
--- /dev/null
+++ b/src/Microsoft.ML.Dnn/ImageClassificationTransform.cs
@@ -0,0 +1,1222 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Runtime.InteropServices;
+using Google.Protobuf;
+using Microsoft.ML;
+using Microsoft.ML.CommandLine;
+using Microsoft.ML.Data;
+using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Transforms;
+using Microsoft.ML.Transforms.Dnn;
+using Tensorflow;
+using Tensorflow.Summaries;
+using static Microsoft.ML.Data.TextLoader;
+using static Microsoft.ML.Transforms.Dnn.DnnUtils;
+using static Microsoft.ML.Transforms.ImageClassificationEstimator;
+using static Tensorflow.Python;
+using Architecture = Microsoft.ML.Transforms.ImageClassificationEstimator.Architecture;
+
+[assembly: LoadableClass(ImageClassificationTransformer.Summary, typeof(IDataTransform), typeof(ImageClassificationTransformer),
+ typeof(ImageClassificationEstimator.Options), typeof(SignatureDataTransform), ImageClassificationTransformer.UserName, ImageClassificationTransformer.ShortName)]
+
+[assembly: LoadableClass(ImageClassificationTransformer.Summary, typeof(IDataTransform), typeof(ImageClassificationTransformer), null, typeof(SignatureLoadDataTransform),
+ ImageClassificationTransformer.UserName, ImageClassificationTransformer.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(ImageClassificationTransformer), null, typeof(SignatureLoadModel),
+ ImageClassificationTransformer.UserName, ImageClassificationTransformer.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(IRowMapper), typeof(ImageClassificationTransformer), null, typeof(SignatureLoadRowMapper),
+ ImageClassificationTransformer.UserName, ImageClassificationTransformer.LoaderSignature)]
+
+namespace Microsoft.ML.Transforms
+{
+    /// <summary>
+    /// <see cref="ITransformer" /> for the <see cref="ImageClassificationEstimator"/>.
+    /// </summary>
+ public sealed class ImageClassificationTransformer : RowToRowTransformerBase
+ {
+ private readonly IHostEnvironment _env;
+ private readonly bool _addBatchDimensionInput;
+ private Session _session;
+ private Tensor _bottleneckTensor;
+ private Operation _trainStep;
+ private Tensor _softMaxTensor;
+ private Tensor _crossEntropy;
+ private Tensor _labelTensor;
+ private Tensor _evaluationStep;
+ private Tensor _prediction;
+ private Tensor _bottleneckInput;
+ private Tensor _jpegData;
+ private Tensor _resizedImage;
+ private string _jpegDataTensorName;
+ private string _resizedImageTensorName;
+ private string _inputTensorName;
+ private readonly int _classCount;
+ private readonly string _checkpointPath;
+ private readonly string _bottleneckOperationName;
+ private Graph Graph => _session.graph;
+ private readonly string[] _inputs;
+ private readonly string[] _outputs;
+ private readonly string _labelColumnName;
+ private readonly string _finalModelPrefix;
+ private readonly Architecture _arch;
+ private readonly string _scoreColumnName;
+ private readonly string _predictedLabelColumnName;
+ private readonly float _learningRate;
+ private readonly string _softmaxTensorName;
+ private readonly string _predictionTensorName;
+ internal const string Summary = "Trains Dnn models.";
+ internal const string UserName = "ImageClassificationTransform";
+ internal const string ShortName = "ImgClsTrans";
+ internal const string LoaderSignature = "ImageClassificationTrans";
+
+ private static VersionInfo GetVersionInfo()
+ {
+ return new VersionInfo(
+ modelSignature: "IMGTRANS",
+ // NOTE(review): an alternative written version, 0x00010001 ("Initial"), was left commented out here.
+ verWrittenCur: 0x00000001,
+ verReadableCur: 0x00000001,
+ verWeCanReadBack: 0x00000001,
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(ImageClassificationTransformer).Assembly.FullName);
+ }
+
+ // Factory method for SignatureLoadModel.
+ private static ImageClassificationTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel(GetVersionInfo());
+
+ // *** Binary format *** (fields are read by GetModelInfo, followed by the model stream)
+ // byte: indicator for adding batch dimension in input
+ // int: number of input columns, then for each input column: id of input column name
+ // int: number of output columns, then for each output column: id of output column name
+ // string: label column name, string: checkpoint name (the final model prefix)
+ // int: architecture, string: score column name, string: predicted label column name
+ // float: learning rate, int: class count
+ // string: prediction tensor name, string: softmax tensor name
+ // string: JPEG data tensor name, string: resized image tensor name
+ // stream: tensorFlow model.
+
+ GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool addBatchDimensionInput,
+ out string labelColumn, out string checkpointName, out Architecture arch, out string scoreColumnName,
+ out string predictedColumnName, out float learningRate, out int classCount, out string predictionTensorName, out string softMaxTensorName,
+ out string jpegDataTensorName, out string resizeTensorName);
+
+ byte[] modelBytes = null;
+ if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray()))
+ throw env.ExceptDecode();
+
+ return new ImageClassificationTransformer(env, DnnUtils.LoadTFSession(env, modelBytes), outputs, inputs,
+ null, addBatchDimensionInput, 1, labelColumn, checkpointName, arch,
+ scoreColumnName, predictedColumnName, learningRate, null, classCount, true, predictionTensorName,
+ softMaxTensorName, jpegDataTensorName, resizeTensorName);
+
+ }
+
+ // Factory method for SignatureDataTransform.
+ internal static IDataTransform Create(IHostEnvironment env, ImageClassificationEstimator.Options options, IDataView input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(options, nameof(options));
+ env.CheckValue(input, nameof(input));
+ env.CheckValue(options.InputColumns, nameof(options.InputColumns));
+ env.CheckValue(options.OutputColumns, nameof(options.OutputColumns));
+
+ return new ImageClassificationTransformer(env, options, input).MakeDataTransform(input); // The constructor performs the training; the result is wrapped as a data transform over the same input.
+ }
+
+ internal ImageClassificationTransformer(IHostEnvironment env, ImageClassificationEstimator.Options options, IDataView input)
+ : this(env, options, DnnUtils.LoadDnnModel(env, options.ModelLocation), input) // Load the model from options.ModelLocation, then delegate to the training constructor.
+ {
+ }
+
+ internal ImageClassificationTransformer(IHostEnvironment env, ImageClassificationEstimator.Options options, DnnModel tensorFlowModel, IDataView input)
+ : this(env, tensorFlowModel.Session, options.OutputColumns, options.InputColumns,
+ options.ModelLocation, null, options.BatchSize,
+ options.LabelColumn, options.FinalModelPrefix, options.Arch, options.ScoreColumnName,
+ options.PredictedLabelColumnName, options.LearningRate, input.Schema)
+ {
+ Contracts.CheckValue(env, nameof(env)); // NOTE(review): these checks run after options/input were already dereferenced in the this(...) call above.
+ env.CheckValue(options, nameof(options));
+ env.CheckValue(input, nameof(input));
+ CheckTrainingParameters(options);
+ var imageProcessor = new ImageProcessor(this);
+ if (!options.ReuseTrainSetBottleneckCachedValues || !File.Exists(options.TrainSetBottleneckCachedValuesFilePath)) // Featurize the training images unless reuse is requested and the cache file exists.
+ CacheFeaturizedImagesToDisk(input, options.LabelColumn, options.InputColumns[0], imageProcessor,
+ _inputTensorName, _bottleneckTensor.name, options.TrainSetBottleneckCachedValuesFilePath,
+ ImageClassificationMetrics.Dataset.Train, options.MetricsCallback);
+
+ if (options.ValidationSet != null &&
+ (!options.ReuseTrainSetBottleneckCachedValues || !File.Exists(options.ValidationSetBottleneckCachedValuesFilePath))) // NOTE(review): the validation cache is gated on the *train set* reuse flag — confirm this is intended.
+ CacheFeaturizedImagesToDisk(options.ValidationSet, options.LabelColumn, options.InputColumns[0],
+ imageProcessor, _inputTensorName, _bottleneckTensor.name, options.ValidationSetBottleneckCachedValuesFilePath,
+ ImageClassificationMetrics.Dataset.Validation, options.MetricsCallback);
+
+ TrainAndEvaluateClassificationLayer(options.TrainSetBottleneckCachedValuesFilePath, options, options.ValidationSetBottleneckCachedValuesFilePath); // Train the final layer from the cached bottleneck files.
+ }
+
+        private void CheckTrainingParameters(ImageClassificationEstimator.Options options)
+        {
+            // Both names must be non-blank, and the TensorFlow label operation has to exist in the session graph.
+            var tensorFlowLabel = options.TensorFlowLabel;
+            Host.CheckNonWhiteSpace(options.LabelColumn, nameof(options.LabelColumn));
+            Host.CheckNonWhiteSpace(tensorFlowLabel, nameof(options.TensorFlowLabel));
+            if (Graph.OperationByName(tensorFlowLabel) == null) throw Host.ExceptParam(nameof(options.TensorFlowLabel), $"'{tensorFlowLabel}' does not exist in the model");
+        }
+
+ private (Tensor, Tensor) AddJpegDecoding(int height, int width, int depth)
+ {
+ // Target dimensions as (height, width, depth); depth is used as the number of decoded channels.
+ var inputDim = (height, width, depth);
+ var jpegData = tf.placeholder(tf.@string, name: "DecodeJPGInput"); // String placeholder that receives the encoded image bytes.
+ var decodedImage = tf.image.decode_jpeg(jpegData, channels: inputDim.Item3);
+ // Convert from full range of uint8 to range [0,1] of float32.
+ var decodedImageAsFloat = tf.image.convert_image_dtype(decodedImage, tf.float32);
+ var decodedImage4d = tf.expand_dims(decodedImageAsFloat, 0); // Insert a new axis at position 0.
+ var resizeShape = tf.stack(new int[] { inputDim.Item1, inputDim.Item2 }); // (height, width) target size.
+ var resizeShapeAsInt = tf.cast(resizeShape, dtype: tf.int32);
+ var resizedImage = tf.image.resize_bilinear(decodedImage4d, resizeShapeAsInt, false, "ResizeTensor");
+ return (jpegData, resizedImage); // (input placeholder, resized image tensor)
+ }
+
+ private sealed class ImageProcessor
+ {
+ private Runner _imagePreprocessingRunner; // Runs the transformer's session from the JPEG data tensor to the resized image tensor.
+
+ public ImageProcessor(ImageClassificationTransformer transformer)
+ {
+ _imagePreprocessingRunner = new Runner(transformer._session);
+ _imagePreprocessingRunner.AddInput(transformer._jpegDataTensorName);
+ _imagePreprocessingRunner.AddOutputs(transformer._resizedImageTensorName);
+ }
+
+ public Tensor ProcessImage(string path)
+ {
+ var imageTensor = new Tensor(File.ReadAllBytes(path), TF_DataType.TF_STRING); // Raw file bytes wrapped as a string tensor.
+ var processedTensor = _imagePreprocessingRunner.AddInput(imageTensor, 0).Run()[0];
+ imageTensor.Dispose(); // NOTE(review): Runner.AddInput(Tensor, int) also disposes the tensor previously held in the slot — confirm a repeated Dispose is harmless.
+ return processedTensor; // The callers in this file dispose the returned tensor.
+ }
+ }
+
+        // Featurizes every image of <input> and writes one "label,feature,feature,..." line per row to <cacheFilePath>.
+        private void CacheFeaturizedImagesToDisk(IDataView input, string labelColumnName, string imagepathColumnName, ImageProcessor imageProcessor,
+            string inputTensorName, string outputTensorName, string cacheFilePath, ImageClassificationMetrics.Dataset dataset, ImageClassificationMetricsCallback metricsCallback)
+        {
+            var labelColumn = input.Schema[labelColumnName];
+            // Labels must be UInt32 values (the raw type checked below); anything else is a schema mismatch.
+            if (labelColumn.Type.RawType != typeof(UInt32))
+                throw Host.ExceptSchemaMismatch(nameof(labelColumn), "Label",
+                    labelColumnName, typeof(uint).ToString(),
+                    labelColumn.Type.RawType.ToString());
+
+            var imagePathColumn = input.Schema[imagepathColumnName];
+            Runner runner = new Runner(_session);
+            runner.AddOutputs(outputTensorName);
+            // Only the label and image path columns are activated on the cursor.
+            using (TextWriter writer = File.CreateText(cacheFilePath))
+            using (var cursor = input.GetRowCursor(input.Schema.Where(c => c.Index == labelColumn.Index || c.Index == imagePathColumn.Index)))
+            {
+                var labelGetter = cursor.GetGetter<UInt32>(labelColumn);
+                var imagePathGetter = cursor.GetGetter<ReadOnlyMemory<char>>(imagePathColumn);
+                UInt32 label = UInt32.MaxValue;
+                ReadOnlyMemory<char> imagePath = default;
+                runner.AddInput(inputTensorName);
+                ImageClassificationMetrics metrics = new ImageClassificationMetrics();
+                metrics.Bottleneck = new BottleneckMetrics();
+                metrics.Bottleneck.DatasetUsed = dataset;
+                while (cursor.MoveNext())
+                {
+                    labelGetter(ref label);
+                    imagePathGetter(ref imagePath);
+                    var imagePathStr = imagePath.ToString();
+                    var imageTensor = imageProcessor.ProcessImage(imagePathStr);
+                    runner.AddInput(imageTensor, 0);
+                    var featurizedImage = runner.Run()[0]; // Reuse memory?
+                    writer.WriteLine(label - 1 + "," + string.Join(",", featurizedImage.Data<float>().Select(v => v.ToString(System.Globalization.CultureInfo.InvariantCulture)))); // Invariant culture: a decimal comma would corrupt the comma-separated cache.
+                    featurizedImage.Dispose();
+                    imageTensor.Dispose();
+                    metrics.Bottleneck.Index++;
+                    metrics.Bottleneck.Name = imagePathStr;
+                    metricsCallback?.Invoke(metrics); // Report progress after each image, if a callback was supplied.
+                }
+            }
+        }
+
+        private IDataView GetShuffledData(string path)
+        {
+            // The cache file is comma separated: column 0 holds the label and
+            // every remaining column is loaded into the "Features" vector.
+            var loaderOptions = new TextLoader.Options
+            {
+                Separators = new[] { ',' },
+                Columns = new[]
+                {
+                    new Column("Label", DataKind.Int64, 0),
+                    new Column("Features", DataKind.Single, new [] { new Range(1, null) }),
+                },
+            };
+
+            var shuffleOptions = new RowShufflingTransformer.Options
+            {
+                ForceShuffle = true,
+                ForceShuffleSource = true
+            };
+
+            var cachedRows = new TextLoader(_env, loaderOptions, new MultiFileSource(path)).Load(new MultiFileSource(path));
+            return new RowShufflingTransformer(_env, shuffleOptions, cachedRows);
+        }
+
+        // Trains the final layer on the cached bottleneck features and, when a validation set is given, evaluates it after each epoch.
+        private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath, ImageClassificationEstimator.Options options,
+            string validationSetBottleneckFilePath)
+        {
+            int batchSize = options.BatchSize;
+            int epochs = options.Epoch;
+            ImageClassificationMetricsCallback statisticsCallback = options.MetricsCallback;
+            var trainingSet = GetShuffledData(trainBottleneckFilePath);
+            IDataView validationSet = null;
+            if (options.ValidationSet != null && !string.IsNullOrEmpty(validationSetBottleneckFilePath))
+                validationSet = GetShuffledData(validationSetBottleneckFilePath);
+
+            long label = long.MaxValue;
+            VBuffer<float> features = default;
+            ReadOnlySpan<float> featureValues = default;
+            var featureColumn = trainingSet.Schema[1];
+            int featureLength = featureColumn.Type.GetVectorSize();
+            float[] featureBatch = new float[featureLength * batchSize];
+            var featureBatchHandle = GCHandle.Alloc(featureBatch, GCHandleType.Pinned); // Pinned so tensors can wrap the array directly; freed at the end.
+            IntPtr featureBatchPtr = featureBatchHandle.AddrOfPinnedObject();
+            int featureBatchSizeInBytes = sizeof(float) * featureBatch.Length;
+            long[] labelBatch = new long[batchSize];
+            var labelBatchHandle = GCHandle.Alloc(labelBatch, GCHandleType.Pinned);
+            IntPtr labelBatchPtr = labelBatchHandle.AddrOfPinnedObject();
+            int labelBatchSizeInBytes = sizeof(long) * labelBatch.Length;
+            var labelTensorShape = _labelTensor.TensorShape.dims.Select(x => (long)x).ToArray();
+            labelTensorShape[0] = batchSize;
+            int batchIndex = 0;
+            var runner = new Runner(_session);
+            var testEvalRunner = new Runner(_session);
+            testEvalRunner.AddOutputs(_evaluationStep.name);
+            testEvalRunner.AddOutputs(_crossEntropy.name);
+
+            Runner validationEvalRunner = null;
+            if (validationSet != null)
+            {
+                validationEvalRunner = new Runner(_session);
+                validationEvalRunner.AddOutputs(_evaluationStep.name);
+                validationEvalRunner.AddInput(_bottleneckInput.name).AddInput(_labelTensor.name);
+            }
+
+            runner.AddOperation(_trainStep);
+            var featureTensorShape = _bottleneckInput.TensorShape.dims.Select(x => (long)x).ToArray();
+            featureTensorShape[0] = batchSize;
+
+            Tensor merged = tf.summary.merge_all();
+            FileWriter trainWriter = tf.summary.FileWriter(Path.Combine(Directory.GetCurrentDirectory(), "train"), _session.graph);
+            Saver trainSaver = tf.train.Saver();
+            trainSaver.save(_session, _checkpointPath);
+
+            runner.AddInput(_bottleneckInput.name).AddInput(_labelTensor.name);
+            testEvalRunner.AddInput(_bottleneckInput.name).AddInput(_labelTensor.name);
+            Dictionary<long, int> classStatsTrain = new Dictionary<long, int>();
+            Dictionary<long, int> classStatsValidate = new Dictionary<long, int>();
+            for (int index = 0; index < _classCount; index += 1)
+                classStatsTrain[index] = classStatsValidate[index] = 0;
+
+            ImageClassificationMetrics metrics = new ImageClassificationMetrics();
+            metrics.Train = new TrainMetrics();
+            for (int epoch = 0; epoch < epochs; epoch += 1)
+            {
+                metrics.Train.Accuracy = 0;
+                metrics.Train.CrossEntropy = 0;
+                metrics.Train.BatchProcessedCount = 0;
+                using (var cursor = trainingSet.GetRowCursor(trainingSet.Schema.ToArray(), new Random()))
+                {
+                    var labelGetter = cursor.GetGetter<long>(trainingSet.Schema[0]);
+                    var featuresGetter = cursor.GetGetter<VBuffer<float>>(featureColumn);
+                    while (cursor.MoveNext())
+                    {
+                        labelGetter(ref label);
+                        featuresGetter(ref features);
+                        classStatsTrain[label]++;
+
+                        // Re-read the span for every row so it always views the buffer the getter just filled.
+                        featureValues = features.GetValues();
+
+                        // Buffer the values.
+                        for (int index = 0; index < featureLength; index += 1)
+                            featureBatch[batchIndex * featureLength + index] = featureValues[index];
+
+                        labelBatch[batchIndex] = label;
+                        batchIndex += 1;
+                        // Train.
+                        if (batchIndex == batchSize)
+                        {
+                            runner.AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, featureBatchSizeInBytes), 0)
+                                .AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, labelBatchSizeInBytes), 1)
+                                .Run();
+
+                            metrics.Train.BatchProcessedCount += 1;
+
+                            if (options.TestOnTrainSet && statisticsCallback != null)
+                            {
+                                var outputTensors = testEvalRunner
+                                    .AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, featureBatchSizeInBytes), 0)
+                                    .AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, labelBatchSizeInBytes), 1)
+                                    .Run();
+
+                                metrics.Train.Accuracy += outputTensors[0].Data<float>()[0];
+                                metrics.Train.CrossEntropy += outputTensors[1].Data<float>()[0];
+
+                                outputTensors[0].Dispose();
+                                outputTensors[1].Dispose();
+                            }
+
+                            batchIndex = 0;
+                        }
+                    }
+
+                    if (options.TestOnTrainSet && statisticsCallback != null)
+                    {
+                        metrics.Train.Epoch = epoch;
+                        metrics.Train.Accuracy /= metrics.Train.BatchProcessedCount;
+                        metrics.Train.CrossEntropy /= metrics.Train.BatchProcessedCount;
+                        metrics.Train.DatasetUsed = ImageClassificationMetrics.Dataset.Train;
+                        statisticsCallback(metrics);
+                    }
+                }
+
+                if (validationSet == null)
+                    continue;
+
+                batchIndex = 0;
+                metrics.Train.BatchProcessedCount = 0;
+                metrics.Train.Accuracy = 0;
+                metrics.Train.CrossEntropy = 0;
+                using (var cursor = validationSet.GetRowCursor(validationSet.Schema.ToArray(), new Random()))
+                {
+                    var labelGetter = cursor.GetGetter<long>(validationSet.Schema[0]);
+                    var featuresGetter = cursor.GetGetter<VBuffer<float>>(featureColumn);
+                    while (cursor.MoveNext())
+                    {
+                        labelGetter(ref label);
+                        featuresGetter(ref features);
+                        classStatsValidate[label]++;
+                        featureValues = features.GetValues(); // Refresh the view of the current row, then buffer the values.
+                        for (int index = 0; index < featureLength; index += 1)
+                            featureBatch[batchIndex * featureLength + index] = featureValues[index];
+
+                        labelBatch[batchIndex] = label;
+                        batchIndex += 1;
+                        // Evaluate.
+                        if (batchIndex == batchSize)
+                        {
+                            var outputTensors = validationEvalRunner
+                                .AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, featureBatchSizeInBytes), 0)
+                                .AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, labelBatchSizeInBytes), 1)
+                                .Run();
+
+                            metrics.Train.Accuracy += outputTensors[0].Data<float>()[0];
+                            metrics.Train.BatchProcessedCount += 1;
+                            batchIndex = 0;
+
+                            outputTensors[0].Dispose();
+                        }
+                    }
+
+                    if (statisticsCallback != null)
+                    {
+                        metrics.Train.Epoch = epoch;
+                        metrics.Train.Accuracy /= metrics.Train.BatchProcessedCount;
+                        metrics.Train.DatasetUsed = ImageClassificationMetrics.Dataset.Validation;
+                        statisticsCallback(metrics);
+                    }
+                }
+            }
+
+            trainSaver.save(_session, _checkpointPath);
+            UpdateTransferLearningModelOnDisk(options, _classCount);
+            featureBatchHandle.Free(); // Release the pinned batch buffers; no tensor is fed from them past this point.
+            labelBatchHandle.Free();
+        }
+
+ private (Session, Tensor, Tensor, Tensor) BuildEvaluationSession(ImageClassificationEstimator.Options options, int classCount)
+ {
+ var evalGraph = DnnUtils.LoadMetaGraph(options.ModelLocation); // A second graph loaded from the model location, separate from the current session's graph.
+ var evalSess = tf.Session(graph: evalGraph);
+ Tensor evaluationStep = null;
+ Tensor prediction = null;
+ Tensor bottleneckTensor = evalGraph.OperationByName(_bottleneckOperationName);
+
+ tf_with(evalGraph.as_default(), graph =>
+ {
+ var (_, _, groundTruthInput, finalTensor) = AddFinalRetrainOps(classCount, options.LabelColumn,
+ options.ScoreColumnName, options.LearningRate, bottleneckTensor, false); // isTraining: false, so no loss or train step is added.
+
+ tf.train.Saver().restore(evalSess, _checkpointPath); // Load the values saved at _checkpointPath into the evaluation session.
+ (evaluationStep, prediction) = AddEvaluationStep(finalTensor, groundTruthInput);
+ (_jpegData, _resizedImage) = AddJpegDecoding(299, 299, 3); // NOTE(review): 299x299x3 is hard-coded here and in the constructor.
+ });
+
+ return (evalSess, _labelTensor, evaluationStep, prediction); // _labelTensor was reassigned by the AddFinalRetrainOps call above.
+ }
+
+ private (Tensor, Tensor) AddEvaluationStep(Tensor resultTensor, Tensor groundTruthTensor)
+ {
+ Tensor evaluationStep = null;
+ Tensor correctPrediction = null;
+
+ tf_with(tf.name_scope("accuracy"), scope =>
+ {
+ tf_with(tf.name_scope("correct_prediction"), delegate
+ {
+ _prediction = tf.argmax(resultTensor, 1); // Index of the largest value along axis 1; also stored in the _prediction field.
+ correctPrediction = tf.equal(_prediction, groundTruthTensor);
+ });
+
+ tf_with(tf.name_scope("accuracy"), delegate
+ {
+ evaluationStep = tf.reduce_mean(tf.cast(correctPrediction, tf.float32)); // Mean of the matches cast to float, i.e. the fraction of correct predictions.
+ });
+ });
+
+ tf.summary.scalar("accuracy", evaluationStep);
+ return (evaluationStep, _prediction); // (accuracy tensor, predicted class index tensor)
+ }
+
+        private void UpdateTransferLearningModelOnDisk(ImageClassificationEstimator.Options options, int classCount)
+        {
+            // Rebuild an evaluation graph from the checkpoint, freeze it to "<checkpoint>.pb" and point the live session at the frozen model.
+            var (evalSession, _, _, _) = BuildEvaluationSession(options, classCount);
+            var outputNodes = new string[] { _softMaxTensor.name.Split(':')[0], _prediction.name.Split(':')[0], _jpegData.name.Split(':')[0], _resizedImage.name.Split(':')[0] };
+            var frozenGraphDef = tf.graph_util.convert_variables_to_constants(evalSession, evalSession.graph.as_graph_def(), outputNodes);
+
+            string frozenModelPath = _checkpointPath + ".pb";
+            File.WriteAllBytes(frozenModelPath, frozenGraphDef.ToByteArray());
+            _session.graph.Dispose();
+            _session.Dispose();
+            _session = LoadTFSessionByModelFilePath(_env, frozenModelPath, false);
+        }
+
+ private void VariableSummaries(RefVariable var)
+ {
+ tf_with(tf.name_scope("summaries"), delegate // Attach mean/stddev/max/min scalars and a histogram summary for the variable.
+ {
+ var mean = tf.reduce_mean(var);
+ tf.summary.scalar("mean", mean);
+ Tensor stddev = null;
+ tf_with(tf.name_scope("stddev"), delegate
+ {
+ stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))); // Square root of the mean squared deviation from the mean.
+ });
+ tf.summary.scalar("stddev", stddev);
+ tf.summary.scalar("max", tf.reduce_max(var));
+ tf.summary.scalar("min", tf.reduce_min(var));
+ tf.summary.histogram("histogram", var);
+ });
+ }
+
+ private (Operation, Tensor, Tensor, Tensor) AddFinalRetrainOps(int classCount, string labelColumn,
+ string scoreColumnName, float learningRate, Tensor bottleneckTensor, bool isTraining)
+ {
+ var (batch_size, bottleneck_tensor_size) = (bottleneckTensor.TensorShape.Dimensions[0], bottleneckTensor.TensorShape.Dimensions[1]); // First two dimensions of the bottleneck tensor's shape.
+ tf_with(tf.name_scope("input"), scope =>
+ {
+ if (isTraining)
+ {
+ _bottleneckInput = tf.placeholder_with_default(
+ bottleneckTensor,
+ shape: bottleneckTensor.TensorShape.Dimensions,
+ name: "BottleneckInputPlaceholder"); // Defaults to the bottleneck tensor; training feeds cached values through it instead.
+ }
+
+ _labelTensor = tf.placeholder(tf.int64, new TensorShape(batch_size), name: labelColumn); // Label placeholder, named after the label column.
+ });
+
+ string layerName = "final_retrain_ops";
+ Tensor logits = null;
+ tf_with(tf.name_scope(layerName), scope =>
+ {
+ RefVariable layerWeights = null;
+ tf_with(tf.name_scope("weights"), delegate
+ {
+ var initialValue = tf.truncated_normal(new int[] { bottleneck_tensor_size, classCount }, stddev: 0.001f);
+ layerWeights = tf.Variable(initialValue, name: "final_weights");
+ VariableSummaries(layerWeights);
+ });
+
+ RefVariable layerBiases = null;
+ tf_with(tf.name_scope("biases"), delegate
+ {
+ layerBiases = tf.Variable(tf.zeros(classCount), name: "final_biases");
+ VariableSummaries(layerBiases);
+ });
+
+ tf_with(tf.name_scope("Wx_plus_b"), delegate
+ {
+ var matmul = tf.matmul(isTraining ? _bottleneckInput : bottleneckTensor, layerWeights); // Training reads the placeholder; otherwise the bottleneck tensor is used directly.
+ logits = matmul + layerBiases;
+ tf.summary.histogram("pre_activations", logits);
+ });
+ });
+
+ _softMaxTensor = tf.nn.softmax(logits, name: scoreColumnName); // Softmax output, named after the score column.
+
+ tf.summary.histogram("activations", _softMaxTensor);
+ if (!isTraining)
+ return (null, null, _labelTensor, _softMaxTensor); // No train step or loss outside training.
+
+ Tensor crossEntropyMean = null;
+ tf_with(tf.name_scope("cross_entropy"), delegate
+ {
+ crossEntropyMean = tf.losses.sparse_softmax_cross_entropy(
+ labels: _labelTensor, logits: logits);
+ });
+
+ tf.summary.scalar("cross_entropy", crossEntropyMean);
+
+ tf_with(tf.name_scope("train"), delegate
+ {
+ var optimizer = tf.train.GradientDescentOptimizer(learningRate);
+ _trainStep = optimizer.minimize(crossEntropyMean);
+ });
+
+ return (_trainStep, crossEntropyMean, _labelTensor, _softMaxTensor); // (train op, loss, label placeholder, softmax output)
+ }
+
+ private void AddTransferLearningLayer(string labelColumn,
+ string scoreColumnName, float learningRate, int classCount)
+ {
+ _bottleneckTensor = Graph.OperationByName(_bottleneckOperationName); // Look up the bottleneck operation in the session graph by name.
+ tf_with(Graph.as_default(), delegate
+ {
+ (_trainStep, _crossEntropy, _labelTensor, _softMaxTensor) =
+ AddFinalRetrainOps(classCount, labelColumn, scoreColumnName, learningRate, _bottleneckTensor, true); // isTraining: true, so the loss and train step are created as well.
+ });
+ }
+
+ // Factory method for SignatureLoadDataTransform.
+ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ => Create(env, ctx).MakeDataTransform(input); // Deserialize the transformer, then wrap it as a data transform over the given input.
+
+ // Factory method for SignatureLoadRowMapper.
+ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
+ => Create(env, ctx).MakeRowMapper(inputSchema); // Deserialize the transformer, then build its row mapper for the given schema.
+
+        private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out string[] inputs,
+            out string[] outputs, out bool addBatchDimensionInput,
+            out string labelColumn, out string checkpointName, out Architecture arch,
+            out string scoreColumnName, out string predictedColumnName, out float learningRate, out int classCount, out string predictionTensorName, out string softMaxTensorName,
+            out string jpegDataTensorName, out string resizeTensorName)
+        {
+            // Reads a count-prefixed list of non-empty column names; the count must be positive.
+            string[] ReadColumnNames()
+            {
+                var count = ctx.Reader.ReadInt32();
+                env.CheckDecode(count > 0);
+                var names = new string[count];
+                for (int i = 0; i < names.Length; i++)
+                    names[i] = ctx.LoadNonEmptyString();
+                return names;
+            }
+
+            addBatchDimensionInput = ctx.Reader.ReadBoolByte();
+            inputs = ReadColumnNames();
+            outputs = ReadColumnNames();
+            labelColumn = ctx.Reader.ReadString();
+            checkpointName = ctx.Reader.ReadString();
+            arch = (Architecture)ctx.Reader.ReadInt32();
+            scoreColumnName = ctx.Reader.ReadString();
+            predictedColumnName = ctx.Reader.ReadString();
+            learningRate = ctx.Reader.ReadFloat();
+            classCount = ctx.Reader.ReadInt32();
+            predictionTensorName = ctx.Reader.ReadString();
+            softMaxTensorName = ctx.Reader.ReadString();
+            jpegDataTensorName = ctx.Reader.ReadString();
+            resizeTensorName = ctx.Reader.ReadString();
+        }
+
+        // Shared constructor: builds the training graph when loadModel is false, or only restores state when it is true.
+        internal ImageClassificationTransformer(IHostEnvironment env, Session session, string[] outputColumnNames,
+            string[] inputColumnNames, string modelLocation,
+            bool? addBatchDimensionInput, int batchSize, string labelColumnName, string finalModelPrefix, Architecture arch,
+            string scoreColumnName, string predictedLabelColumnName, float learningRate, DataViewSchema inputSchema, int? classCount = null, bool loadModel = false,
+            string predictionTensorName = null, string softMaxTensorName = null, string jpegDataTensorName = null, string resizeTensorName = null)
+            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageClassificationTransformer)))
+        {
+            Host.CheckValue(session, nameof(session));
+            Host.CheckNonEmpty(inputColumnNames, nameof(inputColumnNames));
+            Host.CheckNonEmpty(outputColumnNames, nameof(outputColumnNames));
+
+            _env = env;
+            _session = session;
+            _addBatchDimensionInput = addBatchDimensionInput ?? arch == Architecture.ResnetV2101;
+            _inputs = inputColumnNames;
+            _labelColumnName = labelColumnName;
+            _finalModelPrefix = finalModelPrefix;
+            _arch = arch;
+            _scoreColumnName = scoreColumnName;
+            _predictedLabelColumnName = predictedLabelColumnName;
+            _learningRate = learningRate;
+            _softmaxTensorName = softMaxTensorName;
+            _predictionTensorName = predictionTensorName;
+            _jpegDataTensorName = jpegDataTensorName;
+            _resizedImageTensorName = resizeTensorName;
+
+            if (classCount == null)
+            {
+                // No explicit class count: derive it from the key count of the label column.
+                var labelColumn = inputSchema.GetColumnOrNull(labelColumnName);
+                if (labelColumn == null)
+                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", labelColumnName);
+                var labelType = labelColumn.Value.Type;
+                var labelCount = labelType.GetKeyCount();
+                if (labelCount <= 0)
+                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", (string)labelColumn.Value.Name, "Key", (string)labelType.ToString());
+
+                _classCount = labelCount == 1 ? 2 : (int)labelCount;
+            }
+            else
+                _classCount = classCount.Value;
+
+            _checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), finalModelPrefix + modelLocation);
+
+            // Configure bottleneck tensor based on the model.
+            if (arch == ImageClassificationEstimator.Architecture.ResnetV2101)
+            {
+                _bottleneckOperationName = "resnet_v2_101/SpatialSqueeze";
+                _inputTensorName = "input";
+            }
+            else if (arch == ImageClassificationEstimator.Architecture.InceptionV3)
+            {
+                _bottleneckOperationName = "module_apply_default/hub_output/feature_vector/SpatialSqueeze";
+                _inputTensorName = "Placeholder";
+            }
+
+            // The transformer always exposes exactly two outputs: the score vector and the predicted label.
+            _outputs = new[] { scoreColumnName, predictedLabelColumnName };
+            if (loadModel == false)
+            {
+                (_jpegData, _resizedImage) = AddJpegDecoding(299, 299, 3);
+                _jpegDataTensorName = _jpegData.name;
+                _resizedImageTensorName = _resizedImage.name;
+
+                // Add transfer learning layer.
+                AddTransferLearningLayer(labelColumnName, scoreColumnName, learningRate, _classCount);
+
+                // Initialize the variables.
+                new Runner(_session).AddOperation(tf.global_variables_initializer()).Run();
+
+                // Add evaluation layer.
+                (_evaluationStep, _) = AddEvaluationStep(_softMaxTensor, _labelTensor);
+                _softmaxTensorName = _softMaxTensor.name;
+                _predictionTensorName = _prediction.name;
+            }
+        }
+
+ private protected override IRowMapper MakeRowMapper(DataViewSchema inputSchema) => new Mapper(this, inputSchema); // Each call creates a new Mapper bound to this transformer and the given schema.
+
+        private protected override void SaveModel(ModelSaveContext ctx)
+        {
+            Host.AssertValue(ctx);
+            ctx.CheckAtModel();
+            ctx.SetVersionInfo(GetVersionInfo());
+            ctx.Writer.WriteBoolByte(_addBatchDimensionInput);
+
+            Host.AssertNonEmpty(_inputs);
+            ctx.Writer.Write(_inputs.Length);
+            foreach (var colName in _inputs)
+                ctx.SaveNonEmptyString(colName);
+
+            Host.AssertNonEmpty(_outputs);
+            ctx.Writer.Write(_outputs.Length);
+            foreach (var colName in _outputs)
+                ctx.SaveNonEmptyString(colName);
+
+            // The order of these writes must match the reads in GetModelInfo.
+            ctx.Writer.Write(_labelColumnName);
+            ctx.Writer.Write(_finalModelPrefix);
+            ctx.Writer.Write((int)_arch);
+            ctx.Writer.Write(_scoreColumnName);
+            ctx.Writer.Write(_predictedLabelColumnName);
+            ctx.Writer.Write(_learningRate);
+            ctx.Writer.Write(_classCount);
+            ctx.Writer.Write(_predictionTensorName);
+            ctx.Writer.Write(_softmaxTensorName);
+            ctx.Writer.Write(_jpegDataTensorName);
+            ctx.Writer.Write(_resizedImageTensorName);
+            Status status = new Status();
+            var buffer = _session.graph.ToGraphDef(status);
+            status.Check(true); // Verify the graph export succeeded before its bytes are written to the model stream.
+            ctx.SaveBinaryStream("TFModel", w => w.WriteByteArray(buffer.Data));
+        }
+
+ ~ImageClassificationTransformer()
+ {
+ Dispose(false); // Finalizer path: closes the session if it is still open.
+ }
+
+ private void Dispose(bool disposing)
+ {
+ // Ensure that the Session is not null and its handle is not Zero, as it may have already been disposed/finalized.
+ // Technically we shouldn't be calling this if disposing == false, since we're running in finalizer
+ // and the GC doesn't guarantee ordering of finalization of managed objects, but we have to make sure
+ // that the Session is closed. NOTE(review): the 'disposing' flag is not consulted anywhere in this method.
+ if (_session != null && _session != IntPtr.Zero)
+ {
+ _session.close();
+ }
+ }
+
+        private sealed class Mapper : MapperBase
+        {
+            private readonly ImageClassificationTransformer _parent;
+            private readonly int[] _inputColIndices;
+
+            public Mapper(ImageClassificationTransformer parent, DataViewSchema inputSchema) :
+                   base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent)
+            {
+                Host.CheckValue(parent, nameof(parent));
+                _parent = parent;
+                _inputColIndices = new int[1];
+                if (!inputSchema.TryGetColumnIndex(_parent._inputs[0], out _inputColIndices[0]))
+                    throw Host.ExceptSchemaMismatch(nameof(InputSchema), "source", _parent._inputs[0]);
+            }
+
+            private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);
+
+            private class OutputCache
+            {
+                public long Position;
+                private ValueGetter<ReadOnlyMemory<char>> _imagePathGetter;
+                private ReadOnlyMemory<char> _imagePath;
+                private Runner _runner;
+                private ImageProcessor _imageProcessor;
+                public UInt32 PredictedLabel { get; set; }
+                public float[] ClassProbabilities { get; set; }
+                private DataViewRow _inputRow;
+
+                public OutputCache(DataViewRow input, ImageClassificationTransformer transformer)
+                {
+                    _imagePath = default;
+                    _imagePathGetter = input.GetGetter<ReadOnlyMemory<char>>(input.Schema[transformer._inputs[0]]);
+                    _runner = new Runner(transformer._session);
+                    _runner.AddInput(transformer._inputTensorName);
+                    _runner.AddOutputs(transformer._softmaxTensorName);
+                    _runner.AddOutputs(transformer._predictionTensorName);
+                    _imageProcessor = new ImageProcessor(transformer);
+                    _inputRow = input;
+                    Position = -1;
+                }
+
+                public void UpdateCacheIfNeeded()
+                {
+                    lock (this)
+                    {
+                        // Score the image once per row position; both output getters share the cached result.
+                        if (_inputRow.Position != Position)
+                        {
+                            Position = _inputRow.Position;
+                            _imagePathGetter(ref _imagePath);
+                            var processedTensor = _imageProcessor.ProcessImage(_imagePath.ToString());
+                            var outputTensor = _runner.AddInput(processedTensor, 0).Run();
+                            ClassProbabilities = outputTensor[0].Data<float>();
+                            PredictedLabel = (UInt32)outputTensor[1].Data<long>()[0];
+                            outputTensor[0].Dispose();
+                            outputTensor[1].Dispose();
+                            processedTensor.Dispose();
+                        }
+                    }
+                }
+            }
+            // iinfo 0 yields the class-probability vector; any other index yields the predicted label.
+            protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
+            {
+                disposer = null;
+                Host.AssertValue(input);
+                var cache = new OutputCache(input, _parent);
+
+                if (iinfo == 0)
+                {
+                    ValueGetter<VBuffer<float>> valuegetter = (ref VBuffer<float> dst) =>
+                    {
+                        cache.UpdateCacheIfNeeded();
+                        var editor = VBufferEditor.Create(ref dst, cache.ClassProbabilities.Length);
+                        new Span<float>(cache.ClassProbabilities, 0, cache.ClassProbabilities.Length).CopyTo(editor.Values);
+                        dst = editor.Commit();
+                    };
+                    return valuegetter;
+                }
+                else
+                {
+                    ValueGetter<UInt32> valuegetter = (ref UInt32 dst) =>
+                    {
+                        cache.UpdateCacheIfNeeded();
+                        dst = cache.PredictedLabel;
+                    };
+
+                    return valuegetter;
+                }
+            }
+
+            private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
+            {
+                return col => Enumerable.Range(0, _parent._outputs.Length).Any(i => activeOutput(i)) && _inputColIndices.Any(i => i == col);
+            }
+
+            protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
+            {
+                var info = new DataViewSchema.DetachedColumn[_parent._outputs.Length];
+                info[0] = new DataViewSchema.DetachedColumn(_parent._outputs[0], new VectorDataViewType(NumberDataViewType.Single, _parent._classCount), null);
+                info[1] = new DataViewSchema.DetachedColumn(_parent._outputs[1], NumberDataViewType.UInt32, null);
+                return info;
+            }
+        }
+ }
+
+ ///
+ public sealed class ImageClassificationEstimator : IEstimator
+ {
+ /// <summary>
+ /// Image classification model architecture used for transfer learning.
+ /// </summary>
+ public enum Architecture
+ {
+ ResnetV2101,
+ InceptionV3
+ };
+
+ /// <summary>
+ /// Backend DNN training framework.
+ /// </summary>
+ public enum DnnFramework
+ {
+ Tensorflow
+ };
+
+ /// <summary>
+ /// Callback that receives DNN statistics (as <see cref="ImageClassificationMetrics"/>) during the training phase.
+ /// </summary>
+ public delegate void ImageClassificationMetricsCallback(ImageClassificationMetrics metrics);
+
+        /// <summary>
+        /// Metrics reported while the DNN is being trained.
+        /// </summary>
+        public sealed class TrainMetrics
+        {
+            /// <summary>
+            /// The dataset (train or validation) these metrics were computed on.
+            /// </summary>
+            /// <seealso cref="ImageClassificationMetrics.Dataset"/>
+            public ImageClassificationMetrics.Dataset DatasetUsed { get; set; }
+
+            /// <summary>
+            /// Number of batches processed in the epoch.
+            /// </summary>
+            public int BatchProcessedCount { get; set; }
+
+            /// <summary>
+            /// Index of the training epoch this metric belongs to.
+            /// </summary>
+            public int Epoch { get; set; }
+
+            /// <summary>
+            /// Accuracy of the batch on this <see cref="Epoch"/>. Higher is better.
+            /// </summary>
+            public float Accuracy { get; set; }
+
+            /// <summary>
+            /// Cross-entropy (loss) of the batch on this <see cref="Epoch"/>. Lower
+            /// is better. Only printed by <see cref="ToString"/> for the train dataset.
+            /// </summary>
+            public float CrossEntropy { get; set; }
+
+            /// <summary>
+            /// Formats the metrics as a single human-readable line.
+            /// </summary>
+            public override string ToString()
+            {
+                // Both datasets share this prefix; only the train dataset appends its cross-entropy.
+                string summary = $"Phase: Training, Dataset used: {DatasetUsed.ToString(),10}, Batch Processed Count: {BatchProcessedCount,3}, " +
+                    $"Epoch: {Epoch,3}, Accuracy: {Accuracy,10}";
+                bool isTrainSet = DatasetUsed == ImageClassificationMetrics.Dataset.Train;
+
+                return isTrainSet ? summary + $", Cross-Entropy: {CrossEntropy,10}" : summary;
+            }
+        }
+
+ /// <summary>
+ /// Metrics for image featurization values. The input image is passed through
+ /// the network and features are extracted from second or last layer to
+ /// train a custom fully connected layer that serves as classifier.
+ /// </summary>
+ public sealed class BottleneckMetrics
+ {
+ /// <summary>
+ /// Indicates the dataset on which metrics are being reported.
+ /// </summary>
+ /// <seealso cref="ImageClassificationMetrics.Dataset"/>
+ public ImageClassificationMetrics.Dataset DatasetUsed { get; set; }
+
+ /// <summary>
+ /// Name of the input image.
+ /// </summary>
+ public string Name { get; set; }
+
+ /// <summary>
+ /// Index of the input image.
+ /// </summary>
+ public int Index { get; set; }
+
+ /// <summary>
+ /// String representation of the metrics.
+ /// </summary>
+ public override string ToString() => $"Phase: Bottleneck Computation, Dataset used: {DatasetUsed.ToString(),10}, Image Index: {Index,3}, Image Name: {Name}";
+ }
+
+        /// <summary>
+        /// Metrics for image classification training.
+        /// </summary>
+        public sealed class ImageClassificationMetrics
+        {
+            /// <summary>
+            /// Indicates the kind of dataset for which a metric is reported.
+            /// </summary>
+            public enum Dataset
+            {
+                Train,
+                Validation
+            };
+
+            /// <summary>
+            /// Contains train time metrics.
+            /// </summary>
+            public TrainMetrics Train { get; set; }
+
+            /// <summary>
+            /// Contains pre-train time metrics. This contains metrics on image
+            /// featurization.
+            /// </summary>
+            public BottleneckMetrics Bottleneck { get; set; }
+
+            /// <summary>
+            /// String form of <see cref="Train"/> when set, otherwise of <see cref="Bottleneck"/>; empty (never throws) if neither is set.
+            /// </summary>
+            public override string ToString() => Train?.ToString() ?? Bottleneck?.ToString() ?? string.Empty;
+        }
+
+ /// <summary>
+ /// The options for the <see cref="ImageClassificationTransformer"/>.
+ /// </summary>
+ internal sealed class Options : TransformInputBase
+ {
+ /// <summary>
+ /// Location of the TensorFlow model.
+ /// </summary>
+ [Argument(ArgumentType.Required, HelpText = "TensorFlow model used by the transform. Please see https://www.tensorflow.org/mobile/prepare_models for more details.", SortOrder = 0)]
+ public string ModelLocation;
+
+ /// <summary>
+ /// The names of the model inputs.
+ /// </summary>
+ [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The names of the model inputs", ShortName = "inputs", SortOrder = 1)]
+ public string[] InputColumns;
+
+ /// <summary>
+ /// The names of the requested model outputs.
+ /// </summary>
+ [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The name of the outputs", ShortName = "outputs", SortOrder = 2)]
+ public string[] OutputColumns;
+
+ /// <summary>
+ /// The name of the label column in <see cref="IDataView"/> that will be mapped to label node in TensorFlow model.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Training labels.", ShortName = "label", SortOrder = 4)]
+ public string LabelColumn;
+
+ /// <summary>
+ /// The name of the label in TensorFlow model.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "TensorFlow label node.", ShortName = "TFLabel", SortOrder = 5)]
+ public string TensorFlowLabel;
+
+ /// <summary>
+ /// Number of samples to use for mini-batch training.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Number of samples to use for mini-batch training.", SortOrder = 9)]
+ public int BatchSize = 64;
+
+ /// <summary>
+ /// Number of training iterations.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Number of training iterations.", SortOrder = 10)]
+ public int Epoch = 5;
+
+ /// <summary>
+ /// Learning rate to use during optimization.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate to use during optimization.", SortOrder = 12)]
+ public float LearningRate = 0.01f;
+
+ /// <summary>
+ /// Specifies the model architecture to be used in the case of image classification training using transfer learning.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Model architecture to be used in transfer learning for image classification.", SortOrder = 15)]
+ public Architecture Arch = Architecture.InceptionV3;
+
+ /// <summary>
+ /// Name of the tensor that will contain the output scores of the last layer when transfer learning is done.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Softmax tensor of the last layer in transfer learning.", SortOrder = 15)]
+ public string ScoreColumnName = "Scores";
+
+ /// <summary>
+ /// Name of the tensor that will contain the predicted label from output scores of the last layer when transfer learning is done.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Argmax tensor of the last layer in transfer learning.", SortOrder = 15)]
+ public string PredictedLabelColumnName = "PredictedLabel";
+
+ /// <summary>
+ /// Final model and checkpoint files/folder prefix for storing graph files.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Final model and checkpoint files/folder prefix for storing graph files.", SortOrder = 15)]
+ public string FinalModelPrefix = "custom_retrained_model_based_on_";
+
+ /// <summary>
+ /// Callback to report statistics on accuracy/cross entropy during training phase.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Callback to report metrics during training and validation phase.", SortOrder = 15)]
+ public ImageClassificationMetricsCallback MetricsCallback = null;
+
+ /// <summary>
+ /// Frequency of epochs at which statistics on training phase should be reported.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Frequency of epochs at which statistics on training/validation phase should be reported.", SortOrder = 15)]
+ public int StatisticsFrequency = 1;
+
+ /// <summary>
+ /// Indicates the choice of DNN training framework. Currently only TensorFlow is supported.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates the choice DNN training framework. Currently only TensorFlow is supported.", SortOrder = 15)]
+ public DnnFramework Framework = DnnFramework.Tensorflow;
+
+ /// <summary>
+ /// Indicates the path where the newly retrained model should be saved.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates the path where the newly retrained model should be saved.", SortOrder = 15)]
+ public string ModelSavePath = null;
+
+ /// <summary>
+ /// Indicates whether to evaluate the model on the train set after every epoch.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates to evaluate the model on train set after every epoch.", SortOrder = 15)]
+ public bool TestOnTrainSet;
+
+ /// <summary>
+ /// Indicates to not re-compute the cached train set bottleneck values if already available in the bin folder.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates to not re-compute trained cached bottleneck values if already available in the bin folder.", SortOrder = 15)]
+ public bool ReuseTrainSetBottleneckCachedValues;
+
+ /// <summary>
+ /// Indicates to not re-compute the cached validation set bottleneck values if already available in the bin folder.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates to not re-compute validataionset cached bottleneck validationset values if already available in the bin folder.", SortOrder = 15)]
+ public bool ReuseValidationSetBottleneckCachedValues;
+
+ /// <summary>
+ /// Validation set.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Validation set.", SortOrder = 15)]
+ public IDataView ValidationSet;
+
+ /// <summary>
+ /// Indicates the file path to store train set bottleneck values for caching.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates the file path to store trainset bottleneck values for caching.", SortOrder = 15)]
+ public string TrainSetBottleneckCachedValuesFilePath;
+
+ /// <summary>
+ /// Indicates the file path to store validation set bottleneck values for caching.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates the file path to store validationset bottleneck values for caching.", SortOrder = 15)]
+ public string ValidationSetBottleneckCachedValuesFilePath;
+ }
+
+ private readonly IHost _host;
+ private readonly Options _options;
+ private readonly DnnModel _dnnModel;
+ private readonly TF_DataType[] _tfInputTypes;
+ private readonly DataViewType[] _outputTypes;
+ private ImageClassificationTransformer _transformer;
+
+        internal ImageClassificationEstimator(IHostEnvironment env, Options options, DnnModel dnnModel)
+        {
+            _host = Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageClassificationEstimator));
+            _options = Contracts.CheckRef(options, nameof(options)); // fail fast: GetOutputSchema dereferences the options
+            _dnnModel = dnnModel;
+            _tfInputTypes = new[] { TF_DataType.TF_STRING }; // expected type of input column 0 (a string)
+            _outputTypes = new[] { new VectorDataViewType(NumberDataViewType.Single), NumberDataViewType.UInt32.GetItemType() }; // output 0: float vector, output 1: UInt32
+        }
+
+        private static Options CreateArguments(DnnModel tensorFlowModel, string[] outputColumnNames, string[] inputColumnName, bool addBatchDimensionInput)
+        {
+            // Only the model path and the column names are copied; addBatchDimensionInput is not read here.
+            return new Options
+            {
+                ModelLocation = tensorFlowModel.ModelPath, InputColumns = inputColumnName, OutputColumns = outputColumnNames
+            };
+        }
+
+        /// <summary>
+        /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
+        /// Used for schema propagation and verification in a pipeline.
+        /// </summary>
+        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+        {
+            _host.CheckValue(inputSchema, nameof(inputSchema));
+            // Start from the input columns; the output columns are added (or replace same-named inputs) below.
+            var resultDic = inputSchema.ToDictionary(x => x.Name);
+            for (var i = 0; i < _options.InputColumns.Length; i++)
+            {
+                var input = _options.InputColumns[i];
+                if (!inputSchema.TryFindColumn(input, out var col))
+                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
+                var expectedType = DnnUtils.Tf2MlNetType(_tfInputTypes[i]);
+                if (col.ItemType != expectedType)
+                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
+            }
+            for (var i = 0; i < _options.OutputColumns.Length; i++)
+            {
+                resultDic[_options.OutputColumns[i]] = new SchemaShape.Column(_options.OutputColumns[i],
+                    _outputTypes[i].IsKnownSizeVector() ? SchemaShape.Column.VectorKind.Vector
+                    : SchemaShape.Column.VectorKind.VariableVector, _outputTypes[i].GetItemType(), false);
+            }
+            return new SchemaShape(resultDic.Values);
+        }
+
+        /// <summary>
+        /// Trains and returns a <see cref="ImageClassificationTransformer"/>; the instance created on the first call is reused afterwards.
+        /// </summary>
+        public ImageClassificationTransformer Fit(IDataView input)
+        {
+            _host.CheckValue(input, nameof(input));
+            // Create the transformer lazily; later calls keep the already-created instance.
+            _transformer = _transformer ?? new ImageClassificationTransformer(_host, _options, _dnnModel, input);
+
+            // Validate the input schema; the computed output schema itself is discarded.
+            _transformer.GetOutputSchema(input.Schema);
+            return _transformer;
+        }
+ }
+}
diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
index 7a8d7063dd..e1b8ef452e 100644
--- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
@@ -639,7 +639,10 @@ private void UpdateCacheIfNeeded(long position, ITensorValueGetter[] srcTensorGe
// Feed inputs to the graph.
for (int i = 0; i < _parent.Inputs.Length; i++)
- runner.AddInput(_parent.Inputs[i], srcTensorGetters[i].GetTensor());
+ {
+ var tensor = srcTensorGetters[i].GetTensor();
+ runner.AddInput(_parent.Inputs[i], tensor);
+ }
// Add outputs.
for (int i = 0; i < _parent.Outputs.Length; i++)
@@ -651,8 +654,12 @@ private void UpdateCacheIfNeeded(long position, ITensorValueGetter[] srcTensorGe
Contracts.Assert(tensors.Length > 0);
for (int j = 0; j < activeOutputColNames.Length; j++)
- outputCache.Outputs[activeOutputColNames[j]] = tensors[j];
+ {
+ if (outputCache.Outputs.TryGetValue(activeOutputColNames[j], out Tensor outTensor))
+ outTensor.Dispose();
+ outputCache.Outputs[activeOutputColNames[j]] = tensors[j];
+ }
outputCache.Position = position;
}
}
@@ -704,7 +711,6 @@ private class TensorValueGetter : ITensorValueGetter
private readonly T[] _bufferedData;
private readonly TensorShape _tfShape;
private int _position;
- private readonly List _tensors;
public TensorValueGetter(DataViewRow input, int colIndex, TensorShape tfShape)
{
@@ -719,7 +725,6 @@ public TensorValueGetter(DataViewRow input, int colIndex, TensorShape tfShape)
size *= dim;
}
_bufferedData = new T[size];
- _tensors = new List();
}
public Tensor GetTensor()
@@ -728,7 +733,6 @@ public Tensor GetTensor()
_srcgetter(ref scalar);
var tensor = new Tensor(new[] { scalar });
tensor.SetShape(_tfShape);
- _tensors.Add(tensor);
return tensor;
}
@@ -743,7 +747,6 @@ public Tensor GetBufferedBatchTensor()
{
var tensor = new Tensor(new NDArray(_bufferedData, _tfShape));
_position = 0;
- _tensors.Add(tensor);
return tensor;
}
}
@@ -757,7 +760,6 @@ private class TensorValueGetterVec : ITensorValueGetter
private T[] _bufferedData;
private int _position;
private long[] _dims;
- private readonly List _tensors;
private readonly long _bufferedDataSize;
public TensorValueGetterVec(DataViewRow input, int colIndex, TensorShape tfShape)
@@ -778,7 +780,6 @@ public TensorValueGetterVec(DataViewRow input, int colIndex, TensorShape tfShape
_bufferedData = new T[size];
if (_tfShape.Dimensions != null)
_dims = _tfShape.Dimensions.Select(x => (long)x).ToArray();
- _tensors = new List();
_bufferedDataSize = size;
}
@@ -792,7 +793,6 @@ public Tensor GetTensor()
_denseData = new T[_vBuffer.Length];
_vBuffer.CopyTo(_denseData);
var tensor = CastDataAndReturnAsTensor(_denseData);
- _tensors.Add(tensor);
return tensor;
}
@@ -845,7 +845,6 @@ public Tensor GetBufferedBatchTensor()
{
_position = 0;
var tensor = CastDataAndReturnAsTensor(_bufferedData);
- _tensors.Add(tensor);
_bufferedData = new T[_bufferedDataSize];
return tensor;
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
index d5bf30ab96..e9727d13a7 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
@@ -3760,21 +3760,6 @@ public void EntryPointWordEmbeddings()
}
}
- [TensorFlowFact]
- public void EntryPointTensorFlowTransform()
- {
- Env.ComponentCatalog.RegisterAssembly(typeof(TensorFlowTransformer).Assembly);
-
- TestEntryPointPipelineRoutine(GetDataPath("Train-Tiny-28x28.txt"), "col=Label:R4:0 col=Placeholder:R4:1-784",
- new[] { "Transforms.TensorFlowScorer" },
- new[]
- {
- @"'InputColumns': [ 'Placeholder' ],
- 'ModelLocation': 'mnist_model/frozen_saved_model.pb',
- 'OutputColumns': [ 'Softmax' ]"
- });
- }
-
[Fact(Skip = "Needs real time series dataset. https://github.com/dotnet/machinelearning/issues/1120")]
public void EntryPointSsaChangePoint()
{
@@ -5637,6 +5622,21 @@ public void TestOvaMacroWithUncalibratedLearner()
}
}
+ [TensorFlowFact]
+ public void EntryPointTensorFlowTransform()
+ {
+ Env.ComponentCatalog.RegisterAssembly(typeof(TensorFlowTransformer).Assembly);
+ // Run the "Transforms.TensorFlowScorer" entry point over the tiny 28x28 dataset with the frozen MNIST model.
+ TestEntryPointPipelineRoutine(GetDataPath("Train-Tiny-28x28.txt"), "col=Label:R4:0 col=Placeholder:R4:1-784",
+ new[] { "Transforms.TensorFlowScorer" },
+ new[]
+ {
+ @"'InputColumns': [ 'Placeholder' ],
+ 'ModelLocation': 'mnist_model/frozen_saved_model.pb',
+ 'OutputColumns': [ 'Softmax' ]"
+ });
+ }
+
[TensorFlowFact]
public void TestTensorFlowEntryPoint()
{
diff --git a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs
deleted file mode 100644
index ef62dcae77..0000000000
--- a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs
+++ /dev/null
@@ -1,99 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System.IO;
-using Microsoft.ML.Data;
-using Microsoft.ML.Transforms.Image;
-using Microsoft.ML.TestFramework.Attributes;
-using Microsoft.ML.Transforms;
-using Xunit;
-
-namespace Microsoft.ML.Scenarios
-{
- public partial class ScenariosTests
- {
- [TensorFlowFact]
- public void TensorFlowTransforCifarEndToEndTest()
- {
- var imageHeight = 32;
- var imageWidth = 32;
- var model_location = "cifar_model/frozen_model.pb";
- var dataFile = GetDataPath("images/images.tsv");
- var imageFolder = Path.GetDirectoryName(dataFile);
-
- var mlContext = new MLContext(seed: 1);
- var data = TextLoader.Create(mlContext, new TextLoader.Options()
- {
- Columns = new[]
- {
- new TextLoader.Column("ImagePath", DataKind.String, 0),
- new TextLoader.Column("Label", DataKind.String, 1),
- }
- }, new MultiFileSource(dataFile));
-
- var pipeEstimator = new ImageLoadingEstimator(mlContext, imageFolder, ("ImageReal", "ImagePath"))
- .Append(new ImageResizingEstimator(mlContext, "ImageCropped", imageHeight, imageWidth, "ImageReal"))
- .Append(new ImagePixelExtractingEstimator(mlContext, "Input", "ImageCropped", interleavePixelColors: true))
- .Append(mlContext.Model.LoadTensorFlowModel(model_location).ScoreTensorFlowModel("Output", "Input"))
- .Append(new ColumnConcatenatingEstimator(mlContext, "Features", "Output"))
- .Append(new ValueToKeyMappingEstimator(mlContext, "Label"))
- .AppendCacheCheckpoint(mlContext)
- .Append(mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy());
-
-
- var transformer = pipeEstimator.Fit(data);
- var predictions = transformer.Transform(data);
-
- var metrics = mlContext.MulticlassClassification.Evaluate(predictions);
- Assert.Equal(1, metrics.MicroAccuracy, 2);
-
- var predictFunction = mlContext.Model.CreatePredictionEngine(transformer);
- var prediction = predictFunction.Predict(new CifarData()
- {
- ImagePath = GetDataPath("images/banana.jpg")
- });
- Assert.Equal(0, prediction.PredictedScores[0], 2);
- Assert.Equal(1, prediction.PredictedScores[1], 2);
- Assert.Equal(0, prediction.PredictedScores[2], 2);
-
- prediction = predictFunction.Predict(new CifarData()
- {
- ImagePath = GetDataPath("images/hotdog.jpg")
- });
- Assert.Equal(0, prediction.PredictedScores[0], 2);
- Assert.Equal(0, prediction.PredictedScores[1], 2);
- Assert.Equal(1, prediction.PredictedScores[2], 2);
- }
- }
-
- public class CifarData
- {
- [LoadColumn(0)]
- public string ImagePath;
-
- [LoadColumn(1)]
- public string Label;
- }
-
- public class CifarPrediction
- {
- [ColumnName("Score")]
- public float[] PredictedScores;
- }
-
- public class ImageNetData
- {
- [LoadColumn(0)]
- public string ImagePath;
-
- [LoadColumn(1)]
- public string Label;
- }
-
- public class ImageNetPrediction
- {
- [ColumnName("Score")]
- public float[] PredictedLabels;
- }
-}
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
index 810a308d67..a48ecc6550 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
@@ -5,21 +5,30 @@
using System;
using System.Collections.Generic;
using System.IO;
+using System.IO.Compression;
using System.Linq;
+using System.Net;
using System.Runtime.InteropServices;
using Microsoft.ML.Data;
using Microsoft.ML.RunTests;
+using Microsoft.ML.TestFramework;
using Microsoft.ML.TestFramework.Attributes;
using Microsoft.ML.Transforms;
using Microsoft.ML.Transforms.Image;
using Microsoft.ML.Transforms.TensorFlow;
-using Tensorflow;
using Xunit;
+using Xunit.Abstractions;
+using static Microsoft.ML.DataOperationsCatalog;
namespace Microsoft.ML.Scenarios
{
- public partial class ScenariosTests
+ [Collection("NoParallelization")]
+ public sealed class TensorFlowScenariosTests : BaseTestClass
{
+ public TensorFlowScenariosTests(ITestOutputHelper output) : base(output)
+ {
+ }
+
private class TestData
{
[VectorType(4)]
@@ -28,6 +37,89 @@ private class TestData
public float[] b;
}
+ public class CifarData
+ {
+ [LoadColumn(0)]
+ public string ImagePath;
+
+ [LoadColumn(1)]
+ public string Label;
+ }
+
+ public class CifarPrediction
+ {
+ [ColumnName("Score")]
+ public float[] PredictedScores;
+ }
+
+ public class ImageNetData
+ {
+ [LoadColumn(0)]
+ public string ImagePath;
+
+ [LoadColumn(1)]
+ public string Label;
+ }
+
+ public class ImageNetPrediction
+ {
+ [ColumnName("Score")]
+ public float[] PredictedLabels;
+ }
+
+        [TensorFlowFact]
+        public void TensorFlowTransforCifarEndToEndTest2()
+        {
+            var imageHeight = 32;
+            var imageWidth = 32;
+            var model_location = "cifar_model/frozen_model.pb";
+            var dataFile = GetDataPath("images/images.tsv");
+            var imageFolder = Path.GetDirectoryName(dataFile);
+
+            var mlContext = new MLContext(seed: 1);
+            var data = TextLoader.Create(mlContext, new TextLoader.Options()
+            {
+                Columns = new[]
+                    {
+                        new TextLoader.Column("ImagePath", DataKind.String, 0),
+                        new TextLoader.Column("Label", DataKind.String, 1),
+                    }
+            }, new MultiFileSource(dataFile));
+
+            // Load, resize and extract pixels, score with the frozen CIFAR graph, then train SDCA on its output.
+            var pipeEstimator = new ImageLoadingEstimator(mlContext, imageFolder, ("ImageReal", "ImagePath"))
+                .Append(new ImageResizingEstimator(mlContext, "ImageCropped", imageHeight, imageWidth, "ImageReal"))
+                .Append(new ImagePixelExtractingEstimator(mlContext, "Input", "ImageCropped", interleavePixelColors: true))
+                .Append(mlContext.Model.LoadTensorFlowModel(model_location).ScoreTensorFlowModel("Output", "Input"))
+                .Append(new ColumnConcatenatingEstimator(mlContext, "Features", "Output"))
+                .Append(new ValueToKeyMappingEstimator(mlContext, "Label"))
+                .AppendCacheCheckpoint(mlContext)
+                .Append(mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy());
+
+            var transformer = pipeEstimator.Fit(data);
+            var predictions = transformer.Transform(data);
+
+            var metrics = mlContext.MulticlassClassification.Evaluate(predictions);
+            Assert.Equal(1, metrics.MicroAccuracy, 2);
+
+            var predictFunction = mlContext.Model.CreatePredictionEngine<CifarData, CifarPrediction>(transformer);
+            var prediction = predictFunction.Predict(new CifarData()
+            {
+                ImagePath = GetDataPath("images/banana.jpg")
+            });
+            Assert.Equal(0, prediction.PredictedScores[0], 2);
+            Assert.Equal(1, prediction.PredictedScores[1], 2);
+            Assert.Equal(0, prediction.PredictedScores[2], 2);
+
+            prediction = predictFunction.Predict(new CifarData()
+            {
+                ImagePath = GetDataPath("images/hotdog.jpg")
+            });
+            Assert.Equal(0, prediction.PredictedScores[0], 2);
+            Assert.Equal(0, prediction.PredictedScores[1], 2);
+            Assert.Equal(1, prediction.PredictedScores[2], 2);
+        }
+
[TensorFlowFact]
public void TensorFlowTransformMatrixMultiplicationTest()
{
@@ -1113,5 +1205,183 @@ public void TensorFlowStringTest()
Assert.Equal(input.A[i], textOutput.AOut[i]);
Assert.Equal(string.Join(" ", input.B).Replace("/", " "), textOutput.BOut[0]);
}
+
+        [TensorFlowFact]
+        public void TensorFlowImageClassification()
+        {
+            string assetsRelativePath = @"assets";
+            string assetsPath = GetAbsolutePath(assetsRelativePath);
+            string imagesDownloadFolderPath = Path.Combine(assetsPath, "inputs",
+                "images");
+
+            // Download the image set and unzip it.
+            string finalImagesFolderName = DownloadImageSet(
+                imagesDownloadFolderPath);
+
+            string fullImagesetFolderPath = Path.Combine(
+                imagesDownloadFolderPath, finalImagesFolderName);
+
+            MLContext mlContext = new MLContext(seed: 1);
+
+            // Load all the original images info, labelled by their parent folder name.
+            IEnumerable<ImageData> images = LoadImagesFromDirectory(
+                folder: fullImagesetFolderPath, useFolderNameAsLabel: true);
+
+            IDataView shuffledFullImagesDataset = mlContext.Data.ShuffleRows(
+                mlContext.Data.LoadFromEnumerable(images), seed: 1);
+
+            shuffledFullImagesDataset = mlContext.Transforms.Conversion
+                .MapValueToKey("Label")
+                .Fit(shuffledFullImagesDataset)
+                .Transform(shuffledFullImagesDataset);
+
+            // Split the data 80:20 into train and test sets, train and evaluate.
+            TrainTestData trainTestData = mlContext.Data.TrainTestSplit(
+                shuffledFullImagesDataset, testFraction: 0.2, seed: 1);
+
+            IDataView trainDataset = trainTestData.TrainSet;
+            IDataView testDataset = trainTestData.TestSet;
+
+            var pipeline = mlContext.Model.ImageClassification(
+                "ImagePath", "Label",
+                arch: ImageClassificationEstimator.Architecture.ResnetV2101,
+                epoch: 5,
+                batchSize: 5,
+                learningRate: 0.01f,
+                testOnTrainSet: false);
+
+            var trainedModel = pipeline.Fit(trainDataset);
+
+            mlContext.Model.Save(trainedModel, shuffledFullImagesDataset.Schema,
+                "model.zip");
+
+            ITransformer loadedModel;
+            DataViewSchema schema;
+            using (var file = File.OpenRead("model.zip"))
+                loadedModel = mlContext.Model.Load(file, out schema);
+            // NOTE(review): loadedModel is only loaded, never scored; predictions below use trainedModel - confirm intent.
+            IDataView predictions = trainedModel.Transform(testDataset);
+            var metrics = mlContext.MulticlassClassification.Evaluate(predictions);
+
+            // On Ubuntu the results seem to vary quite a bit but they can probably be
+            // controlled by training more epochs, however that will slow the
+            // build down. Accuracy values seen were 0.33, 0.66, 0.70+. The model
+            // seems to be unstable, there could be many reasons, will need to
+            // investigate this further.
+            if (!(RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ||
+                (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))))
+            {
+                Assert.InRange(metrics.MicroAccuracy, 0.3, 1);
+                Assert.InRange(metrics.MacroAccuracy, 0.3, 1);
+            }
+            else
+            {
+                Assert.Equal(1, metrics.MicroAccuracy);
+                Assert.Equal(1, metrics.MacroAccuracy);
+            }
+        }
+
+        // Lazily enumerates every ".jpg" file under the folder (recursively) and pairs it with a label.
+        public static IEnumerable<ImageData> LoadImagesFromDirectory(string folder,
+            bool useFolderNameAsLabel = true)
+        {
+            var files = Directory.GetFiles(folder, "*",
+                searchOption: SearchOption.AllDirectories);
+
+            foreach (var file in files)
+            {
+                // Only ".jpg" files (exact, case-sensitive match) are treated as images.
+                if (Path.GetExtension(file) != ".jpg")
+                    continue;
+
+                string label;
+                if (useFolderNameAsLabel)
+                {
+                    // The label is the name of the directory that contains the image.
+                    label = Directory.GetParent(file).Name;
+                }
+                else
+                {
+                    // The label is the leading run of letters in the file name,
+                    // e.g. "rose12.jpg" yields "rose".
+                    var fileName = Path.GetFileName(file);
+                    label = new string(fileName.TakeWhile(c => char.IsLetter(c)).ToArray());
+                }
+
+                yield return new ImageData()
+                {
+                    ImagePath = file,
+                    Label = label
+                };
+            }
+        }
+
+        public static string DownloadImageSet(string imagesDownloadFolder)
+        {
+            const string archiveName = "flower_photos_tiny_set_for_unit_tests.zip";
+            const string archiveUrl = "https://mlnetfilestorage.file.core.windows.net/imagesets" +
+                "/flower_images/flower_photos_tiny_set_for_unit_tests.zip?st=2019" +
+                "-08-29T00%3A07%3A21Z&se=2030-08-30T00%3A07%3A00Z&sp=rl&sv=2018" +
+                "-03-28&sr=f&sig=N8HbLziTcT61kstprNLmn%2BDC0JoMrNwo6yRWb3hLLag%3D";
+
+            // Fetch the archive, unpack it into the same folder, and return the archive name without its extension.
+            Download(archiveUrl, imagesDownloadFolder, archiveName);
+            UnZip(Path.Combine(imagesDownloadFolder, archiveName), imagesDownloadFolder);
+            return Path.GetFileNameWithoutExtension(archiveName);
+        }
+
+        private static bool Download(string url, string destDir, string destFileName)
+        {
+            // URLs always use '/', so take the default name from the URL path (no query), not the OS path separator.
+            if (destFileName == null)
+                destFileName = Path.GetFileName(new Uri(url).AbsolutePath);
+
+            Directory.CreateDirectory(destDir);
+            string filePath = Path.Combine(destDir, destFileName);
+
+            if (File.Exists(filePath))
+                return false; // already present from an earlier run: nothing downloaded
+            using (var client = new WebClient())
+                client.DownloadFile(url, filePath);
+            return true;
+        }
+
+        private static void UnZip(String gzArchiveName, String destFolder)
+        {
+            // A "<archive-name>.bin" marker file in destFolder records that extraction already completed.
+            var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last()
+                .Split('.').First() + ".bin";
+            var flagPath = Path.Combine(destFolder, flag);
+            if (File.Exists(flagPath))
+                return;
+
+            ZipFile.ExtractToDirectory(gzArchiveName, destFolder);
+            // Dispose the stream returned by File.Create so the marker's file handle is not leaked.
+            File.Create(flagPath).Dispose();
+        }
+
+        public static string GetAbsolutePath(string relativePath) => Path.Combine(
+            new FileInfo(typeof(TensorFlowScenariosTests).Assembly.Location).Directory.FullName,
+            relativePath);
+
+
+ public class ImageData
+ {
+ [LoadColumn(0)]
+ public string ImagePath;
+
+ [LoadColumn(1)]
+ public string Label;
+ }
+
+ public class ImagePrediction
+ {
+ [ColumnName("Score")]
+ public float[] Score;
+
+ [ColumnName("PredictedLabel")]
+ public UInt32 PredictedLabel;
+ }
+
}
}
diff --git a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs
index b3838a791e..2a78910617 100644
--- a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs
+++ b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs
@@ -17,6 +17,10 @@
namespace Microsoft.ML.Tests
{
+ [CollectionDefinition("NoParallelization", DisableParallelization = true)]
+ public class NoParallelizationCollection { }
+
+ [Collection("NoParallelization")]
public class TensorFlowEstimatorTests : TestDataPipeBase
{
private class TestData