diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/InceptionV3TransferLearning.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/InceptionV3TransferLearning.cs deleted file mode 100644 index 0e1ec80973..0000000000 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/InceptionV3TransferLearning.cs +++ /dev/null @@ -1,109 +0,0 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using Microsoft.ML; -using Microsoft.ML.Data; -using Microsoft.ML.Transforms; - -namespace Samples.Dynamic -{ - public static class InceptionV3TransferLearning - { - /// - /// Example use of Image classification API in a ML.NET pipeline. - /// - public static void Example() - { - var mlContext = new MLContext(seed: 1); - - var imagesDataFile = Path.GetDirectoryName( - Microsoft.ML.SamplesUtils.DatasetUtils.DownloadImages()); - - var data = mlContext.Data.LoadFromEnumerable( - ImageNetData.LoadImagesFromDirectory(imagesDataFile, 4)); - - data = mlContext.Data.ShuffleRows(data, 5); - var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label") - .Append(mlContext.Transforms.LoadImages("ImageObject", null, - "ImagePath")) - .Append(mlContext.Transforms.ResizeImages("Image", - inputColumnName: "ImageObject", imageWidth: 299, - imageHeight: 299)) - .Append(mlContext.Transforms.ExtractPixels("Image", - interleavePixelColors: true)) - .Append(mlContext.Model.ImageClassification("Image", - "Label", arch: DnnEstimator.Architecture.InceptionV3, epoch: 4, - batchSize: 4)); - - var trainedModel = pipeline.Fit(data); - var predicted = trainedModel.Transform(data); - var metrics = mlContext.MulticlassClassification.Evaluate(predicted); - - Console.WriteLine($"Micro-accuracy: {metrics.MicroAccuracy}," + - $"macro-accuracy = {metrics.MacroAccuracy}"); - - // Create prediction function and test prediction - var predictFunction = mlContext.Model - .CreatePredictionEngine(trainedModel); - - var prediction = predictFunction - .Predict(ImageNetData.LoadImagesFromDirectory(imagesDataFile, 4) - .First()); - - Console.WriteLine($"Scores : [{string.Join(",", prediction.Score)}], " + - $"Predicted Label : {prediction.PredictedLabel}"); - - } - } - - public class ImageNetData - { - [LoadColumn(0)] - public string ImagePath; - - [LoadColumn(1)] - public string Label; - - public static IEnumerable LoadImagesFromDirectory( - string folder, int repeat = 1, bool useFolderNameasLabel = false) - { - var files = Directory.GetFiles(folder, "*", - searchOption: SearchOption.AllDirectories); - - foreach (var file in files) - { - if (Path.GetExtension(file) != ".jpg") - continue; - - var label = Path.GetFileName(file); - if (useFolderNameasLabel) - label = Directory.GetParent(file).Name; - else - { - for (int index = 0; index < label.Length; index++) - { - if (!char.IsLetter(label[index])) - { - label = label.Substring(0, index); - break; - } - } - } - - for (int index = 0; index < repeat; index++) - yield return new ImageNetData() { - ImagePath = file,Label = label }; - } - } - } - - public class ImagePrediction - { - [ColumnName("Score")] - public float[] Score; - - [ColumnName("PredictedLabel")] - public Int64 PredictedLabel; - } -} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearning.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearning.cs deleted file mode 100644 index 9d3136b01b..0000000000 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearning.cs +++ /dev/null @@ -1,123 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Linq; -using Microsoft.ML; -using Microsoft.ML.Data; -using Microsoft.ML.Transforms; - -namespace Samples.Dynamic -{ - public static class ResnetV2101TransferLearning - { - /// - /// Example use of Image classification API in a ML.NET pipeline. - /// - public static void Example() - { - var mlContext = new MLContext(seed: 1); - - var imagesDataFile = Path.GetDirectoryName( - Microsoft.ML.SamplesUtils.DatasetUtils.DownloadImages()); - - var data = mlContext.Data.LoadFromEnumerable( - ImageNetData.LoadImagesFromDirectory(imagesDataFile, 4)); - - data = mlContext.Data.ShuffleRows(data, 5); - var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label") - .Append(mlContext.Transforms.LoadImages("ImageObject", null, - "ImagePath")) - .Append(mlContext.Transforms.ResizeImages("Image", - inputColumnName: "ImageObject", imageWidth: 299, - imageHeight: 299)) - .Append(mlContext.Transforms.ExtractPixels("Image", - interleavePixelColors: true)) - .Append(mlContext.Model.ImageClassification("Image", - "Label", arch: DnnEstimator.Architecture.ResnetV2101, epoch: 4, - batchSize: 4)); - - var trainedModel = pipeline.Fit(data); - var predicted = trainedModel.Transform(data); - var metrics = mlContext.MulticlassClassification.Evaluate(predicted); - - Console.WriteLine($"Micro-accuracy: {metrics.MicroAccuracy}," + - $"macro-accuracy = {metrics.MacroAccuracy}"); - - mlContext.Model.Save(trainedModel, data.Schema, "model.zip"); - - ITransformer loadedModel; - using (var file = File.OpenRead("model.zip")) - loadedModel = mlContext.Model.Load(file, out DataViewSchema schema); - - // Create prediction function and test prediction - var predictFunction = mlContext.Model - .CreatePredictionEngine(loadedModel); - - var prediction = predictFunction - .Predict(ImageNetData.LoadImagesFromDirectory(imagesDataFile, 4) - .First()); - - Console.WriteLine($"Scores : [{string.Join(",", prediction.Score)}], " + - $"Predicted Label : {prediction.PredictedLabel}"); - } - - private const int imageHeight = 224; - private const int imageWidth = 224; - private const int numChannels = 3; - private const int inputSize = imageHeight * imageWidth * numChannels; - - public class ImageNetData - { - [LoadColumn(0)] - public string ImagePath; - - [LoadColumn(1)] - public string Label; - - public static IEnumerable LoadImagesFromDirectory( - string folder, int repeat = 1, bool useFolderNameasLabel = false) - { - var files = Directory.GetFiles(folder, "*", - searchOption: SearchOption.AllDirectories); - - foreach (var file in files) - { - if (Path.GetExtension(file) != ".jpg") - continue; - - var label = Path.GetFileName(file); - if (useFolderNameasLabel) - label = Directory.GetParent(file).Name; - else - { - for (int index = 0; index < label.Length; index++) - { - if (!char.IsLetter(label[index])) - { - label = label.Substring(0, index); - break; - } - } - } - - for (int index = 0; index < repeat; index++) - yield return new ImageNetData() - { - ImagePath = file, - Label = label - }; - } - } - } - - public class ImagePrediction - { - [ColumnName("Score")] - public float[] Score; - - [ColumnName("PredictedLabel")] - public Int64 PredictedLabel; - } - } -} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearningTrainTestSplit.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearningTrainTestSplit.cs new file mode 100644 index 0000000000..ed3f4da5a4 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearningTrainTestSplit.cs @@ -0,0 +1,307 @@ + +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading.Tasks; +using Microsoft.ML; +using Microsoft.ML.Transforms; +using static Microsoft.ML.DataOperationsCatalog; +using System.Linq; +using Microsoft.ML.Data; +using System.IO.Compression; +using System.Threading; +using System.Net; + +namespace Samples.Dynamic +{ + public class ResnetV2101TransferLearningTrainTestSplit + { + public static void Example() + { + string assetsRelativePath = @"../../../assets"; + string assetsPath = GetAbsolutePath(assetsRelativePath); + + var outputMlNetModelFilePath = Path.Combine(assetsPath, "outputs", + "imageClassifier.zip"); + + string imagesDownloadFolderPath = Path.Combine(assetsPath, "inputs", + "images"); + + //Download the image set and unzip + string finalImagesFolderName = DownloadImageSet( + imagesDownloadFolderPath); + + string fullImagesetFolderPath = Path.Combine( + imagesDownloadFolderPath, finalImagesFolderName); + + try + { + + MLContext mlContext = new MLContext(seed: 1); + + //Load all the original images info + IEnumerable images = LoadImagesFromDirectory( + folder: fullImagesetFolderPath, useFolderNameasLabel: true); + + IDataView shuffledFullImagesDataset = mlContext.Data.ShuffleRows( + mlContext.Data.LoadFromEnumerable(images)); + + shuffledFullImagesDataset = mlContext.Transforms.Conversion + .MapValueToKey("Label") + .Fit(shuffledFullImagesDataset) + .Transform(shuffledFullImagesDataset); + + // Split the data 90:10 into train and test sets, train and evaluate. + TrainTestData trainTestData = mlContext.Data.TrainTestSplit( + shuffledFullImagesDataset, testFraction: 0.1, seed: 1); + + IDataView trainDataset = trainTestData.TrainSet; + IDataView testDataset = trainTestData.TestSet; + + var pipeline = mlContext.Model.ImageClassification( + "ImagePath", "Label", + // Just by changing/selecting InceptionV3 here instead of + // ResnetV2101 you can try a different architecture/pre-trained + // model. + arch: ImageClassificationEstimator.Architecture.ResnetV2101, + epoch: 50, + batchSize: 10, + learningRate: 0.01f, + metricsCallback: (metrics) => Console.WriteLine(metrics), + validationSet: testDataset); + + + Console.WriteLine("*** Training the image classification model with " + + "DNN Transfer Learning on top of the selected pre-trained " + + "model/architecture ***"); + + // Measuring training time + var watch = System.Diagnostics.Stopwatch.StartNew(); + + var trainedModel = pipeline.Fit(trainDataset); + + watch.Stop(); + long elapsedMs = watch.ElapsedMilliseconds; + + Console.WriteLine("Training with transfer learning took: " + + (elapsedMs / 1000).ToString() + " seconds"); + + mlContext.Model.Save(trainedModel, shuffledFullImagesDataset.Schema, + "model.zip"); + + ITransformer loadedModel; + DataViewSchema schema; + using (var file = File.OpenRead("model.zip")) + loadedModel = mlContext.Model.Load(file, out schema); + + EvaluateModel(mlContext, testDataset, loadedModel); + + VBuffer> keys = default; + loadedModel.GetOutputSchema(schema)["Label"].GetKeyValues(ref keys); + + watch = System.Diagnostics.Stopwatch.StartNew(); + TrySinglePrediction(fullImagesetFolderPath, mlContext, loadedModel, + keys.DenseValues().ToArray()); + + watch.Stop(); + elapsedMs = watch.ElapsedMilliseconds; + + Console.WriteLine("Prediction engine took: " + + (elapsedMs / 1000).ToString() + " seconds"); + } + catch (Exception ex) + { + Console.WriteLine(ex.ToString()); + } + + Console.WriteLine("Press any key to finish"); + Console.ReadKey(); + } + + private static void TrySinglePrediction(string imagesForPredictions, + MLContext mlContext, ITransformer trainedModel, + ReadOnlyMemory[] originalLabels) + { + // Create prediction function to try one prediction + var predictionEngine = mlContext.Model + .CreatePredictionEngine(trainedModel); + + IEnumerable testImages = LoadImagesFromDirectory( + imagesForPredictions, false); + + ImageData imageToPredict = new ImageData + { + ImagePath = testImages.First().ImagePath + }; + + var prediction = predictionEngine.Predict(imageToPredict); + var index = prediction.PredictedLabel; + + Console.WriteLine($"ImageFile : " + + $"[{Path.GetFileName(imageToPredict.ImagePath)}], " + + $"Scores : [{string.Join(",", prediction.Score)}], " + + $"Predicted Label : {originalLabels[index]}"); + } + + + private static void EvaluateModel(MLContext mlContext, + IDataView testDataset, ITransformer trainedModel) + { + Console.WriteLine("Making bulk predictions and evaluating model's " + + "quality..."); + + // Measuring time + var watch2 = System.Diagnostics.Stopwatch.StartNew(); + + IDataView predictions = trainedModel.Transform(testDataset); + var metrics = mlContext.MulticlassClassification.Evaluate(predictions); + + Console.WriteLine($"Micro-accuracy: {metrics.MicroAccuracy}," + + $"macro-accuracy = {metrics.MacroAccuracy}"); + + watch2.Stop(); + long elapsed2Ms = watch2.ElapsedMilliseconds; + + Console.WriteLine("Predicting and Evaluation took: " + + (elapsed2Ms / 1000).ToString() + " seconds"); + } + + public static IEnumerable LoadImagesFromDirectory(string folder, + bool useFolderNameasLabel = true) + { + var files = Directory.GetFiles(folder, "*", + searchOption: SearchOption.AllDirectories); + + foreach (var file in files) + { + if (Path.GetExtension(file) != ".jpg") + continue; + + var label = Path.GetFileName(file); + if (useFolderNameasLabel) + label = Directory.GetParent(file).Name; + else + { + for (int index = 0; index < label.Length; index++) + { + if (!char.IsLetter(label[index])) + { + label = label.Substring(0, index); + break; + } + } + } + + yield return new ImageData() + { + ImagePath = file, + Label = label + }; + + } + } + + public static string DownloadImageSet(string imagesDownloadFolder) + { + // get a set of images to teach the network about the new classes + + //SINGLE SMALL FLOWERS IMAGESET (200 files) + string fileName = "flower_photos_small_set.zip"; + string url = $"https://mlnetfilestorage.file.core.windows.net/" + + $"imagesets/flower_images/flower_photos_small_set.zip?st=2019-08-" + + $"07T21%3A27%3A44Z&se=2030-08-08T21%3A27%3A00Z&sp=rl&sv=2018-03-" + + $"28&sr=f&sig=SZ0UBX47pXD0F1rmrOM%2BfcwbPVob8hlgFtIlN89micM%3D"; + + Download(url, imagesDownloadFolder, fileName); + UnZip(Path.Combine(imagesDownloadFolder, fileName), imagesDownloadFolder); + + return Path.GetFileNameWithoutExtension(fileName); + } + + public static bool Download(string url, string destDir, string destFileName) + { + if (destFileName == null) + destFileName = url.Split(Path.DirectorySeparatorChar).Last(); + + Directory.CreateDirectory(destDir); + + string relativeFilePath = Path.Combine(destDir, destFileName); + + if (File.Exists(relativeFilePath)) + { + Console.WriteLine($"{relativeFilePath} already exists."); + return false; + } + + var wc = new WebClient(); + Console.WriteLine($"Downloading {relativeFilePath}"); + var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath)); + while (!download.IsCompleted) + { + Thread.Sleep(1000); + Console.Write("."); + } + Console.WriteLine(""); + Console.WriteLine($"Downloaded {relativeFilePath}"); + + return true; + } + + public static void UnZip(String gzArchiveName, String destFolder) + { + var flag = gzArchiveName.Split(Path.DirectorySeparatorChar) + .Last() + .Split('.') + .First() + ".bin"; + + if (File.Exists(Path.Combine(destFolder, flag))) return; + + Console.WriteLine($"Extracting."); + var task = Task.Run(() => + { + ZipFile.ExtractToDirectory(gzArchiveName, destFolder); + }); + + while (!task.IsCompleted) + { + Thread.Sleep(200); + Console.Write("."); + } + + File.Create(Path.Combine(destFolder, flag)); + Console.WriteLine(""); + Console.WriteLine("Extracting is completed."); + } + + public static string GetAbsolutePath(string relativePath) + { + FileInfo _dataRoot = new FileInfo(typeof( + ResnetV2101TransferLearningTrainTestSplit).Assembly.Location); + + string assemblyFolderPath = _dataRoot.Directory.FullName; + + string fullPath = Path.Combine(assemblyFolderPath, relativePath); + + return fullPath; + } + + public class ImageData + { + [LoadColumn(0)] + public string ImagePath; + + [LoadColumn(1)] + public string Label; + } + + public class ImagePrediction + { + [ColumnName("Score")] + public float[] Score; + + [ColumnName("PredictedLabel")] + public UInt32 PredictedLabel; + } + } +} + diff --git a/src/Microsoft.ML.Dnn/DnnCatalog.cs b/src/Microsoft.ML.Dnn/DnnCatalog.cs index 3f5ee076ff..fea3735796 100644 --- a/src/Microsoft.ML.Dnn/DnnCatalog.cs +++ b/src/Microsoft.ML.Dnn/DnnCatalog.cs @@ -9,11 +9,12 @@ using Microsoft.ML.Data; using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.Dnn; -using static Microsoft.ML.Transforms.DnnEstimator; +using static Microsoft.ML.Transforms.ImageClassificationEstimator; +using Options = Microsoft.ML.Transforms.DnnRetrainEstimator.Options; namespace Microsoft.ML { - /// + /// public static class DnnCatalog { @@ -36,11 +37,10 @@ public static class DnnCatalog /// Learning rate to use during optimization (Optional). /// Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3]. /// This parameter is used to deal with models that have unknown shape but the internal operators in the model require data to have batch dimension as well. - /// /// /// The support for retraining is under preview. /// - public static DnnEstimator RetrainDnnModel( + public static DnnRetrainEstimator RetrainDnnModel( this ModelOperationsCatalog catalog, string[] outputColumnNames, string[] inputColumnNames, @@ -54,8 +54,7 @@ public static DnnEstimator RetrainDnnModel( string metricOperation = null, string learningRateOperation = null, float learningRate = 0.01f, - bool addBatchDimensionInput = false, - DnnFramework dnnFramework = DnnFramework.Tensorflow) + bool addBatchDimensionInput = false) { var options = new Options() { @@ -71,12 +70,11 @@ public static DnnEstimator RetrainDnnModel( LearningRateOperation = learningRateOperation, LearningRate = learningRate, BatchSize = batchSize, - AddBatchDimensionInputs = addBatchDimensionInput, - ReTrain = true + AddBatchDimensionInputs = addBatchDimensionInput }; var env = CatalogUtils.GetEnvironment(catalog); - return new DnnEstimator(env, options, DnnUtils.LoadDnnModel(env, modelPath, true)); + return new DnnRetrainEstimator(env, options, DnnUtils.LoadDnnModel(env, modelPath, true)); } /// @@ -85,33 +83,50 @@ public static DnnEstimator RetrainDnnModel( /// /// The name of the input features column. /// The name of the labels column. - /// Optional name of the path where a copy new graph should be saved. The graph will be saved as part of model. /// The name of the output score column. /// The name of the output predicted label columns. - /// The name of the prefix for checkpoint files. /// The architecture of the image recognition DNN model. - /// The backend DNN framework to use, currently only Tensorflow is supported. - /// Number of training epochs. + /// Number of training iterations. Each iteration/epoch refers to one pass over the dataset. /// The batch size for training. /// The learning rate for training. + /// Callback for reporting model statistics during training phase. + /// Indicates the frequency of epochs at which to report model statistics during training phase. + /// Indicates the choice of DNN training framework. Currently only tensorflow is supported. + /// Optional name of the path where a copy new graph should be saved. The graph will be saved as part of model. + /// The name of the prefix for the final mode and checkpoint files. + /// Validation set. + /// Indicates to evaluate the model on train set after every epoch. + /// Indicates to not re-compute cached trainset bottleneck values if already available in the bin folder. + /// Indicates to not re-compute validataionset cached bottleneck validationset values if already available in the bin folder. + /// Indicates the file path to store trainset bottleneck values for caching. + /// Indicates the file path to store validationset bottleneck values for caching. /// /// The support for image classification is under preview. /// - public static DnnEstimator ImageClassification( + public static ImageClassificationEstimator ImageClassification( this ModelOperationsCatalog catalog, string featuresColumnName, string labelColumnName, - string outputGraphPath = null, string scoreColumnName = "Score", string predictedLabelColumnName = "PredictedLabel", - string checkpointName = "_retrain_checkpoint", - Architecture arch = Architecture.ResnetV2101, - DnnFramework dnnFramework = DnnFramework.Tensorflow, - int epoch = 10, - int batchSize = 20, - float learningRate = 0.01f) + Architecture arch = Architecture.InceptionV3, + int epoch = 100, + int batchSize = 10, + float learningRate = 0.01f, + ImageClassificationMetricsCallback metricsCallback = null, + int statisticFrequency = 1, + DnnFramework framework = DnnFramework.Tensorflow, + string modelSavePath = null, + string finalModelPrefix = "custom_retrained_model_based_on_", + IDataView validationSet = null, + bool testOnTrainSet = true, + bool reuseTrainSetBottleneckCachedValues = false, + bool reuseValidationSetBottleneckCachedValues = false, + string trainSetBottleneckCachedValuesFilePath = "trainSetBottleneckFile.csv", + string validationSetBottleneckCachedValuesFilePath = "validationSetBottleneckFile.csv" + ) { - var options = new Options() + var options = new ImageClassificationEstimator.Options() { ModelLocation = arch == Architecture.ResnetV2101 ? @"resnet_v2_101_299.meta" : @"InceptionV3.meta", InputColumns = new[] { featuresColumnName }, @@ -121,13 +136,20 @@ public static DnnEstimator ImageClassification( Epoch = epoch, LearningRate = learningRate, BatchSize = batchSize, - AddBatchDimensionInputs = arch == Architecture.InceptionV3 ? false : true, - TransferLearning = true, ScoreColumnName = scoreColumnName, PredictedLabelColumnName = predictedLabelColumnName, - CheckpointName = checkpointName, + FinalModelPrefix = finalModelPrefix, Arch = arch, - MeasureTrainAccuracy = false + MetricsCallback = metricsCallback, + StatisticsFrequency = statisticFrequency, + Framework = framework, + ModelSavePath = modelSavePath, + ValidationSet = validationSet, + TestOnTrainSet = testOnTrainSet, + TrainSetBottleneckCachedValuesFilePath = trainSetBottleneckCachedValuesFilePath, + ValidationSetBottleneckCachedValuesFilePath = validationSetBottleneckCachedValuesFilePath, + ReuseTrainSetBottleneckCachedValues = reuseTrainSetBottleneckCachedValues, + ReuseValidationSetBottleneckCachedValues = reuseValidationSetBottleneckCachedValues }; if (!File.Exists(options.ModelLocation)) @@ -158,7 +180,7 @@ public static DnnEstimator ImageClassification( } var env = CatalogUtils.GetEnvironment(catalog); - return new DnnEstimator(env, options, DnnUtils.LoadDnnModel(env, options.ModelLocation, true)); + return new ImageClassificationEstimator(env, options, DnnUtils.LoadDnnModel(env, options.ModelLocation, true)); } } } diff --git a/src/Microsoft.ML.Dnn/DnnModel.cs b/src/Microsoft.ML.Dnn/DnnModel.cs index 6f8c54edb7..a5324e9e39 100644 --- a/src/Microsoft.ML.Dnn/DnnModel.cs +++ b/src/Microsoft.ML.Dnn/DnnModel.cs @@ -4,14 +4,14 @@ using Microsoft.ML.Runtime; using Tensorflow; -using static Microsoft.ML.Transforms.DnnEstimator; +using static Microsoft.ML.Transforms.DnnRetrainEstimator; namespace Microsoft.ML.Transforms { /// /// This class holds the information related to TensorFlow model and session. /// It provides some convenient methods to query model schema as well as - /// creation of object. + /// creation of object. /// public sealed class DnnModel { diff --git a/src/Microsoft.ML.Dnn/DnnTransform.cs b/src/Microsoft.ML.Dnn/DnnRetrainTransform.cs similarity index 64% rename from src/Microsoft.ML.Dnn/DnnTransform.cs rename to src/Microsoft.ML.Dnn/DnnRetrainTransform.cs index 56b3e13fd1..99cbd1afa5 100644 --- a/src/Microsoft.ML.Dnn/DnnTransform.cs +++ b/src/Microsoft.ML.Dnn/DnnRetrainTransform.cs @@ -8,7 +8,6 @@ using System.IO; using System.Linq; using System.Text; -using Google.Protobuf; using Microsoft.ML; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; @@ -18,33 +17,29 @@ using Microsoft.ML.Transforms.Dnn; using NumSharp; using Tensorflow; -using Tensorflow.Summaries; using static Microsoft.ML.Transforms.Dnn.DnnUtils; -using static Microsoft.ML.Transforms.DnnEstimator; -using static Tensorflow.Python; -[assembly: LoadableClass(DnnTransformer.Summary, typeof(IDataTransform), typeof(DnnTransformer), - typeof(DnnEstimator.Options), typeof(SignatureDataTransform), DnnTransformer.UserName, DnnTransformer.ShortName)] +[assembly: LoadableClass(DnnRetrainTransformer.Summary, typeof(IDataTransform), typeof(DnnRetrainTransformer), + typeof(DnnRetrainEstimator.Options), typeof(SignatureDataTransform), DnnRetrainTransformer.UserName, DnnRetrainTransformer.ShortName)] -[assembly: LoadableClass(DnnTransformer.Summary, typeof(IDataTransform), typeof(DnnTransformer), null, typeof(SignatureLoadDataTransform), - DnnTransformer.UserName, DnnTransformer.LoaderSignature)] +[assembly: LoadableClass(DnnRetrainTransformer.Summary, typeof(IDataTransform), typeof(DnnRetrainTransformer), null, typeof(SignatureLoadDataTransform), + DnnRetrainTransformer.UserName, DnnRetrainTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(DnnTransformer), null, typeof(SignatureLoadModel), - DnnTransformer.UserName, DnnTransformer.LoaderSignature)] +[assembly: LoadableClass(typeof(DnnRetrainTransformer), null, typeof(SignatureLoadModel), + DnnRetrainTransformer.UserName, DnnRetrainTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(IRowMapper), typeof(DnnTransformer), null, typeof(SignatureLoadRowMapper), - DnnTransformer.UserName, DnnTransformer.LoaderSignature)] +[assembly: LoadableClass(typeof(IRowMapper), typeof(DnnRetrainTransformer), null, typeof(SignatureLoadRowMapper), + DnnRetrainTransformer.UserName, DnnRetrainTransformer.LoaderSignature)] namespace Microsoft.ML.Transforms { /// - /// for the . + /// for the . /// - public sealed class DnnTransformer : RowToRowTransformerBase + public sealed class DnnRetrainTransformer : RowToRowTransformerBase { private readonly IHostEnvironment _env; private readonly string _modelLocation; - private readonly bool _transferLearning; private readonly bool _isTemporarySavedModel; private readonly bool _addBatchDimensionInput; private Session _session; @@ -56,33 +51,15 @@ public sealed class DnnTransformer : RowToRowTransformerBase private readonly (Operation, int)[] _tfOutputOperations; private TF_Output[] _tfInputNodes; private readonly TF_Output[] _tfOutputNodes; - private Tensor _bottleneckTensor; - private Operation _trainStep; - private Tensor _softMaxTensor; - private Tensor _crossEntropy; - private Tensor _labelTensor; - private Tensor _evaluationStep; - private Tensor _prediction; - private readonly int _classCount; - private readonly string _checkpointPath; - private readonly string _bottleneckOperationName; private Graph Graph => _session.graph; private readonly Dictionary _idvToTfMapping; private readonly string[] _inputs; private readonly string[] _outputs; - private readonly string _labelColumnName; - private readonly string _checkpointName; - private readonly Architecture _arch; - private readonly string _scoreColumnName; - private readonly string _predictedLabelColumnName; - private readonly float _learningRate; - private readonly string _softmaxTensorName; - private readonly string _predictionTensorName; - - internal const string Summary = "Trains Dnn models."; - internal const string UserName = "DnnTransform"; - internal const string ShortName = "DnnTransform"; - internal const string LoaderSignature = "DnnTransform"; + + internal const string Summary = "Re-Trains Dnn models."; + internal const string UserName = "DnnRtTransform"; + internal const string ShortName = "DnnRtTransform"; + internal const string LoaderSignature = "DnnRtTransform"; internal static class DefaultModelFileNames { @@ -102,11 +79,11 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00000001, verWeCanReadBack: 0x00000001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(DnnTransformer).Assembly.FullName); + loaderAssemblyName: typeof(DnnRetrainTransformer).Assembly.FullName); } // Factory method for SignatureLoadModel. - private static DnnTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + private static DnnRetrainTransformer Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); @@ -123,9 +100,7 @@ private static DnnTransformer Create(IHostEnvironment env, ModelLoadContext ctx) // int: id of output column name // stream: tensorFlow model. - GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput, - out bool transferLearning, out string labelColumn, out string checkpointName, out Architecture arch, out string scoreColumnName, - out string predictedColumnName, out float learningRate, out int classCount, out string predictionTensorName, out string softMaxTensorName); + GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput); if (isFrozen) { @@ -133,12 +108,11 @@ private static DnnTransformer Create(IHostEnvironment env, ModelLoadContext ctx) if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray())) throw env.ExceptDecode(); - return new DnnTransformer(env, DnnUtils.LoadTFSession(env, modelBytes), outputs, inputs, - null, false, addBatchDimensionInput, 1, transferLearning, labelColumn, checkpointName, arch, - scoreColumnName, predictedColumnName, learningRate, null, classCount, true, predictionTensorName, softMaxTensorName); + return new DnnRetrainTransformer(env, DnnUtils.LoadTFSession(env, modelBytes), outputs, inputs, + null, false, addBatchDimensionInput, 1); } - var tempDirPath = Path.GetFullPath(Path.Combine(Path.GetTempPath(), nameof(DnnTransformer) + "_" + Guid.NewGuid())); + var tempDirPath = Path.GetFullPath(Path.Combine(Path.GetTempPath(), nameof(DnnRetrainTransformer) + "_" + Guid.NewGuid())); DnnUtils.CreateFolderWithAclIfNotExists(env, tempDirPath); try { @@ -164,9 +138,8 @@ private static DnnTransformer Create(IHostEnvironment env, ModelLoadContext ctx) } }); - return new DnnTransformer(env, DnnUtils.GetSession(env, tempDirPath), outputs, inputs, tempDirPath, true, - addBatchDimensionInput, 1, transferLearning, labelColumn, checkpointName, arch, - scoreColumnName, predictedColumnName, learningRate, null, classCount, true, predictionTensorName, softMaxTensorName); + return new DnnRetrainTransformer(env, DnnUtils.GetSession(env, tempDirPath), outputs, inputs, tempDirPath, true, + addBatchDimensionInput, 1); } catch (Exception) { @@ -176,7 +149,7 @@ private static DnnTransformer Create(IHostEnvironment env, ModelLoadContext ctx) } // Factory method for SignatureDataTransform. - internal static IDataTransform Create(IHostEnvironment env, DnnEstimator.Options options, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, DnnRetrainEstimator.Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(options, nameof(options)); @@ -184,33 +157,30 @@ internal static IDataTransform Create(IHostEnvironment env, DnnEstimator.Options env.CheckValue(options.InputColumns, nameof(options.InputColumns)); env.CheckValue(options.OutputColumns, nameof(options.OutputColumns)); - return new DnnTransformer(env, options, input).MakeDataTransform(input); + return new DnnRetrainTransformer(env, options, input).MakeDataTransform(input); } - internal DnnTransformer(IHostEnvironment env, DnnEstimator.Options options, IDataView input) + internal DnnRetrainTransformer(IHostEnvironment env, DnnRetrainEstimator.Options options, IDataView input) : this(env, options, DnnUtils.LoadDnnModel(env, options.ModelLocation), input) { } - internal DnnTransformer(IHostEnvironment env, DnnEstimator.Options options, DnnModel tensorFlowModel, IDataView input, IDataView validationSet = null) + internal DnnRetrainTransformer(IHostEnvironment env, DnnRetrainEstimator.Options options, DnnModel tensorFlowModel, IDataView input, IDataView validationSet = null) : this(env, tensorFlowModel.Session, options.OutputColumns, options.InputColumns, - options.ModelLocation, false, options.AddBatchDimensionInputs, options.BatchSize, options.TransferLearning, - options.LabelColumn, options.CheckpointName, options.Arch, options.ScoreColumnName, - options.PredictedLabelColumnName, options.LearningRate, input.Schema) + options.ModelLocation, false, options.AddBatchDimensionInputs, options.BatchSize) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - if (options.ReTrain) - CheckTrainingParameters(options); + CheckTrainingParameters(options); - if (options.ReTrain && !DnnUtils.IsSavedModel(env, options.ModelLocation)) + if (!DnnUtils.IsSavedModel(env, options.ModelLocation)) throw env.ExceptNotSupp("TensorFlowTransform: Re-Training of TensorFlow model is only supported for un-frozen model."); TrainCore(options, input, validationSet); } - private void CheckTrainingParameters(DnnEstimator.Options options) + private void CheckTrainingParameters(DnnRetrainEstimator.Options options) { Host.CheckNonWhiteSpace(options.LabelColumn, nameof(options.LabelColumn)); Host.CheckNonWhiteSpace(options.OptimizationOperation, nameof(options.OptimizationOperation)); @@ -296,7 +266,7 @@ private void CheckTrainingParameters(DnnEstimator.Options options) return (inputColIndex, isInputVector, tfInputType, tfInputShape); } - private void TrainCore(DnnEstimator.Options options, IDataView input, IDataView validationSet) + private void TrainCore(DnnRetrainEstimator.Options options, IDataView input, IDataView validationSet) { var inputsForTraining = new string[_inputs.Length + 1]; var inputColIndices = new int[inputsForTraining.Length]; @@ -313,10 +283,7 @@ private void TrainCore(DnnEstimator.Options options, IDataView input, IDataView GetTrainingInputInfo(inputSchema, _inputs[i], inputsForTraining[i], options.BatchSize); var index = inputsForTraining.Length - 1; - if (options.TransferLearning) - inputsForTraining[index] = _labelTensor.name.Split(':').First(); - else - inputsForTraining[index] = options.TensorFlowLabel; + inputsForTraining[index] = options.TensorFlowLabel; (inputColIndices[index], isInputVector[index], tfInputTypes[index], tfInputShapes[index]) = GetTrainingInputInfo(inputSchema, options.LabelColumn, inputsForTraining[index], options.BatchSize); @@ -324,14 +291,9 @@ private void TrainCore(DnnEstimator.Options options, IDataView input, IDataView // Create graph inputs. Operation labelOp; int labelOpIdx; - if (options.ReTrain) - (labelOp, labelOpIdx) = GetOperationFromName(options.TensorFlowLabel, _session); - else - (labelOp, labelOpIdx) = GetOperationFromName(_labelTensor.name, _session); - + (labelOp, labelOpIdx) = GetOperationFromName(options.TensorFlowLabel, _session); TF_Output[] tfInputs; - - if (options.ReTrain && !string.IsNullOrEmpty(options.LearningRateOperation)) + if (!string.IsNullOrEmpty(options.LearningRateOperation)) tfInputs = new TF_Output[_tfInputNodes.Length + 2]; //Inputs + Label + Learning Rate. else tfInputs = new TF_Output[_tfInputNodes.Length + 1]; //Inputs + Label. @@ -339,32 +301,13 @@ private void TrainCore(DnnEstimator.Options options, IDataView input, IDataView Array.Copy(_tfInputNodes, tfInputs, _tfInputNodes.Length); tfInputs[_tfInputNodes.Length] = new TF_Output(labelOp, labelOpIdx); - - if (options.ReTrain) - { - var lr = GetOperationFromName(options.LearningRateOperation, _session); - tfInputs[_tfInputNodes.Length + 1] = new TF_Output(lr.Item1, lr.Item2); - } + var lr = GetOperationFromName(options.LearningRateOperation, _session); + tfInputs[_tfInputNodes.Length + 1] = new TF_Output(lr.Item1, lr.Item2); // Create graph operations. IntPtr[] ops = null; - if (options.ReTrain && options.OptimizationOperation != null) + if (options.OptimizationOperation != null) ops = new[] { c_api.TF_GraphOperationByName(Graph, options.OptimizationOperation) }; - else - ops = new[] { (IntPtr)_trainStep }; - - Saver trainSaver = null; - FileWriter trainWriter = null; - Tensor merged = null; - Runner testSetRunner = null; - Runner validationSetRunner = null; - if (options.TransferLearning) - { - merged = tf.summary.merge_all(); - trainWriter = tf.summary.FileWriter(Path.Combine(Directory.GetCurrentDirectory(), "train"), _session.graph); - trainSaver = tf.train.Saver(); - trainSaver.save(_session, _checkpointPath); - } // Instantiate the graph. Runner runner; @@ -379,190 +322,53 @@ private void TrainCore(DnnEstimator.Options options, IDataView input, IDataView using (var ch = Host.Start("Training TensorFlow model...")) using (var pch = Host.StartProgressChannel("TensorFlow training progress...")) { - if (options.ReTrain) - { - float loss = 0; - float metric = 0; - pch.SetHeader(new ProgressHeader(new[] { "Loss", "Metric" }, new[] { "Epoch" }), (e) => e.SetProgress(0, epoch, options.Epoch)); + float loss = 0; + float metric = 0; + pch.SetHeader(new ProgressHeader(new[] { "Loss", "Metric" }, new[] { "Epoch" }), (e) => e.SetProgress(0, epoch, options.Epoch)); - while (cursor.MoveNext()) - { - for (int i = 0; i < inputsForTraining.Length; i++) - { - isDataLeft = true; - srcTensorGetters[i].BufferTrainingData(); - } - - if (((cursor.Position + 1) % options.BatchSize) == 0) - { - isDataLeft = false; - runner = new Runner(_session); - - // Add Learning Rate. - if (!string.IsNullOrEmpty(options.LearningRateOperation)) - runner.AddInput(options.LearningRateOperation, new Tensor(options.LearningRate)); - - // Add operations. - if (!string.IsNullOrEmpty(options.OptimizationOperation)) - runner.AddOperation(options.OptimizationOperation); - - // Add outputs. - if (options.LossOperation != null) - runner.AddOutputs(options.LossOperation); - if (options.MetricOperation != null) - runner.AddOutputs(options.MetricOperation); - - var (l, m) = ExecuteGraphAndRetrieveMetrics(inputsForTraining, srcTensorGetters, runner); - loss += l; - metric += m; - } - } - if (isDataLeft) - { - isDataLeft = false; - ch.Warning("Not training on the last batch. The batch size is less than {0}.", options.BatchSize); - } - pch.Checkpoint(new double?[] { loss, metric }); - } - else + while (cursor.MoveNext()) { - pch.SetHeader(new ProgressHeader(null, new[] { "Epoch" }), (e) => e.SetProgress(0, epoch, options.Epoch)); - - while (cursor.MoveNext()) + for (int i = 0; i < inputsForTraining.Length; i++) { - for (int i = 0; i < inputsForTraining.Length; i++) - { - isDataLeft = true; - srcTensorGetters[i].BufferTrainingData(); - } - - if (((cursor.Position + 1) % options.BatchSize) == 0) - { - isDataLeft = false; - runner = new Runner(_session); - - // Add operations. - runner.AddOperation(_trainStep); - - // Feed inputs. - for (int i = 0; i < inputsForTraining.Length; i++) - runner.AddInput(inputsForTraining[i], srcTensorGetters[i].GetBufferedBatchTensor()); - - // Execute the graph. - var t = runner.Run(); - } + isDataLeft = true; + srcTensorGetters[i].BufferTrainingData(); } - if (isDataLeft) + if (((cursor.Position + 1) % options.BatchSize) == 0) { isDataLeft = false; - ch.Warning("Not training on the last batch. The batch size is less than {0}.", options.BatchSize); - } - } - } - } + runner = new Runner(_session); - // Measure accuracy of the model. - if (options.TransferLearning && options.MeasureTrainAccuracy) - { - // Test on the training set to get accuracy. - using (var cursor = input.GetRowCursor(cols)) - { - var srcTensorGetters = GetTensorValueGetters(cursor, inputColIndices, isInputVector, tfInputTypes, tfInputShapes); - - float accuracy = 0; - float crossEntropy = 0; - bool isDataLeft = false; - int batch = 0; - using (var ch = Host.Start("Test TensorFlow model...")) - using (var pch = Host.StartProgressChannel("TensorFlow testing progress...")) - { - pch.SetHeader(new ProgressHeader(new[] { "Accuracy", "Cross Entropy" }, new[] { "Epoch" }), (e) => e.SetProgress(0, epoch, options.Epoch)); + // Add Learning Rate. + if (!string.IsNullOrEmpty(options.LearningRateOperation)) + runner.AddInput(options.LearningRateOperation, new Tensor(options.LearningRate)); - while (cursor.MoveNext()) - { - for (int i = 0; i < inputColIndices.Length; i++) - { - isDataLeft = true; - srcTensorGetters[i].BufferTrainingData(); - } - - if (((cursor.Position + 1) % options.BatchSize) == 0) - { - isDataLeft = false; - testSetRunner = new Runner(_session); - testSetRunner.AddOutputs(_evaluationStep.name); - testSetRunner.AddOutputs(_crossEntropy.name); - testSetRunner.AddOutputs(_bottleneckTensor.name); - var (acc, ce) = ExecuteGraphAndRetrieveMetrics(inputsForTraining, srcTensorGetters, testSetRunner); - accuracy += acc; - crossEntropy += ce; - batch++; - } - } + // Add operations. + if (!string.IsNullOrEmpty(options.OptimizationOperation)) + runner.AddOperation(options.OptimizationOperation); - if (isDataLeft) - { - isDataLeft = false; - ch.Warning("Not training on the last batch. The batch size is less than {0}.", options.BatchSize); + // Add outputs. + if (options.LossOperation != null) + runner.AddOutputs(options.LossOperation); + if (options.MetricOperation != null) + runner.AddOutputs(options.MetricOperation); + + var (l, m) = ExecuteGraphAndRetrieveMetrics(inputsForTraining, srcTensorGetters, runner); + loss += l; + metric += m; } - pch.Checkpoint(new double?[] { accuracy / batch, crossEntropy / batch }); - ch.Info(MessageSensitivity.None, $"Accuracy: {accuracy / batch}, Cross-Entropy: {crossEntropy / batch}"); } - } - - // Test on the validation set. - if (validationSet != null) - { - using (var cursor = validationSet.GetRowCursor(cols)) + if (isDataLeft) { - var srcTensorGetters = GetTensorValueGetters(cursor, inputColIndices, isInputVector, tfInputTypes, tfInputShapes); - - float accuracy = 0; - bool isDataLeft = false; - int batch = 0; - using (var ch = Host.Start("Test TensorFlow model with validation set...")) - using (var pch = Host.StartProgressChannel("TensorFlow validation progress...")) - { - pch.SetHeader(new ProgressHeader(new[] { "Accuracy" }, new[] { "Epoch" }), (e) => e.SetProgress(0, epoch, options.Epoch)); - - while (cursor.MoveNext()) - { - for (int i = 0; i < inputColIndices.Length; i++) - { - isDataLeft = true; - srcTensorGetters[i].BufferTrainingData(); - } - - if (((cursor.Position + 1) % options.BatchSize) == 0) - { - isDataLeft = false; - validationSetRunner = new Runner(_session); - validationSetRunner.AddOutputs(_evaluationStep.name); - var (acc, _) = ExecuteGraphAndRetrieveMetrics(inputsForTraining, srcTensorGetters, validationSetRunner); - accuracy += acc; - batch++; - } - } - if (isDataLeft) - { - isDataLeft = false; - ch.Warning("Not training on the last batch. The batch size is less than {0}.", options.BatchSize); - } - pch.Checkpoint(new double?[] { accuracy / batch }); - } + isDataLeft = false; + ch.Warning("Not training on the last batch. The batch size is less than {0}.", options.BatchSize); } + pch.Checkpoint(new double?[] { loss, metric }); } } } - if (options.ReTrain) - UpdateModelOnDisk(options.ModelLocation, options); - else - { - trainSaver.save(_session, _checkpointPath); - UpdateTransferLearningModelOnDisk(options, _classCount); - } + UpdateModelOnDisk(options.ModelLocation, options); } private (float loss, float metric) ExecuteGraphAndRetrieveMetrics( @@ -588,7 +394,7 @@ private void TrainCore(DnnEstimator.Options options, IDataView input, IDataView /// After retraining Session and Graphs are both up-to-date /// However model on disk is not which is used to serialzed to ML.Net stream /// - private void UpdateModelOnDisk(string modelDir, DnnEstimator.Options options) + private void UpdateModelOnDisk(string modelDir, DnnRetrainEstimator.Options options) { try { @@ -648,150 +454,6 @@ private void UpdateModelOnDisk(string modelDir, DnnEstimator.Options options) } } - private (Session, Tensor, Tensor, Tensor) BuildEvaluationSession(DnnEstimator.Options options, int classCount) - { - var evalGraph = DnnUtils.LoadMetaGraph(options.ModelLocation); - var evalSess = tf.Session(graph: evalGraph); - Tensor evaluationStep = null; - Tensor prediction = null; - Tensor bottleneckTensor = evalGraph.OperationByName(_bottleneckOperationName); - - tf_with(evalGraph.as_default(), graph => - { - var (_, _, groundTruthInput, finalTensor) = AddFinalRetrainOps(classCount, options.LabelColumn, - options.ScoreColumnName, options.LearningRate, bottleneckTensor, false); - - tf.train.Saver().restore(evalSess, Path.Combine(Directory.GetCurrentDirectory(), _checkpointPath)); - (evaluationStep, prediction) = AddEvaluationStep(finalTensor, groundTruthInput); - }); - - return (evalSess, _labelTensor, evaluationStep, prediction); - } - - private (Tensor, Tensor) AddEvaluationStep(Tensor resultTensor, Tensor groundTruthTensor) - { - Tensor evaluationStep = null; - Tensor correctPrediction = null; - - tf_with(tf.name_scope("accuracy"), scope => - { - tf_with(tf.name_scope("correct_prediction"), delegate - { - _prediction = tf.argmax(resultTensor, 1); - correctPrediction = tf.equal(_prediction, groundTruthTensor); - }); - - tf_with(tf.name_scope("accuracy"), delegate - { - evaluationStep = tf.reduce_mean(tf.cast(correctPrediction, tf.float32)); - }); - }); - - tf.summary.scalar("accuracy", evaluationStep); - return (evaluationStep, _prediction); - } - - private void UpdateTransferLearningModelOnDisk(DnnEstimator.Options options, int classCount) - { - var (sess, _, _, _) = BuildEvaluationSession(options, classCount); - var graph = sess.graph; - var outputGraphDef = tf.graph_util.convert_variables_to_constants( - sess, graph.as_graph_def(), new string[] { _softMaxTensor.name.Split(':')[0], _prediction.name.Split(':')[0] }); - - string frozenModelPath = _checkpointPath + ".pb"; - File.WriteAllBytes(_checkpointPath + ".pb", outputGraphDef.ToByteArray()); - _session = LoadTFSessionByModelFilePath(_env, frozenModelPath, false); - } - - private void VariableSummaries(RefVariable var) - { - tf_with(tf.name_scope("summaries"), delegate - { - var mean = tf.reduce_mean(var); - tf.summary.scalar("mean", mean); - Tensor stddev = null; - tf_with(tf.name_scope("stddev"), delegate - { - stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))); - }); - tf.summary.scalar("stddev", stddev); - tf.summary.scalar("max", tf.reduce_max(var)); - tf.summary.scalar("min", tf.reduce_min(var)); - tf.summary.histogram("histogram", var); - }); - } - - private (Operation, Tensor, Tensor, Tensor) AddFinalRetrainOps(int classCount, string labelColumn, - string scoreColumnName, float learningRate, Tensor bottleneckTensor, bool isTraining) - { - var (batch_size, bottleneck_tensor_size) = (bottleneckTensor.TensorShape.Dimensions[0], bottleneckTensor.TensorShape.Dimensions[1]); - tf_with(tf.name_scope("input"), scope => - { - _labelTensor = tf.placeholder(tf.int64, new TensorShape(batch_size), name: labelColumn); - }); - - string layerName = "final_retrain_ops"; - Tensor logits = null; - tf_with(tf.name_scope(layerName), scope => - { - RefVariable layerWeights = null; - tf_with(tf.name_scope("weights"), delegate - { - var initialValue = tf.truncated_normal(new int[] { bottleneck_tensor_size, classCount }, stddev: 0.001f); - layerWeights = tf.Variable(initialValue, name: "final_weights"); - VariableSummaries(layerWeights); - }); - - RefVariable layerBiases = null; - tf_with(tf.name_scope("biases"), delegate - { - layerBiases = tf.Variable(tf.zeros(classCount), name: "final_biases"); - VariableSummaries(layerBiases); - }); - - tf_with(tf.name_scope("Wx_plus_b"), delegate - { - var matmul = tf.matmul(bottleneckTensor, layerWeights); - logits = matmul + layerBiases; - tf.summary.histogram("pre_activations", logits); - }); - }); - - _softMaxTensor = tf.nn.softmax(logits, name: scoreColumnName); - - tf.summary.histogram("activations", _softMaxTensor); - if (!isTraining) - return (null, null, _labelTensor, _softMaxTensor); - - Tensor crossEntropyMean = null; - tf_with(tf.name_scope("cross_entropy"), delegate - { - crossEntropyMean = tf.losses.sparse_softmax_cross_entropy( - labels: _labelTensor, logits: logits); - }); - - tf.summary.scalar("cross_entropy", crossEntropyMean); - - tf_with(tf.name_scope("train"), delegate - { - var optimizer = tf.train.GradientDescentOptimizer(learningRate); - _trainStep = optimizer.minimize(crossEntropyMean); - }); - - return (_trainStep, crossEntropyMean, _labelTensor, _softMaxTensor); - } - - private void AddTransferLearningLayer(string labelColumn, - string scoreColumnName, float learningRate, int classCount) - { - _bottleneckTensor = Graph.OperationByName(_bottleneckOperationName); - tf_with(Graph.as_default(), delegate - { - (_trainStep, _crossEntropy, _labelTensor, _softMaxTensor) = - AddFinalRetrainOps(classCount, labelColumn, scoreColumnName, learningRate, _bottleneckTensor, true); - }); - } - private static ITensorValueGetter CreateTensorValueGetter(DataViewRow input, bool isVector, int colIndex, TensorShape tfShape, bool keyType = false) { if (isVector) @@ -833,9 +495,7 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Dat => Create(env, ctx).MakeRowMapper(inputSchema); private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out string[] inputs, - out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput, out bool transferLearning, - out string labelColumn, out string checkpointName, out Architecture arch, - out string scoreColumnName, out string predictedColumnName, out float learningRate, out int classCount, out string predictionTensorName, out string softMaxTensorName) + out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput) { isFrozen = ctx.Reader.ReadBoolByte(); addBatchDimensionInput = ctx.Reader.ReadBoolByte(); @@ -851,26 +511,12 @@ private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out outputs = new string[numOutputs]; for (int j = 0; j < outputs.Length; j++) outputs[j] = ctx.LoadNonEmptyString(); - - transferLearning = ctx.Reader.ReadBoolean(); - labelColumn = ctx.Reader.ReadString(); - checkpointName = ctx.Reader.ReadString(); - arch = (Architecture)ctx.Reader.ReadInt32(); - scoreColumnName = ctx.Reader.ReadString(); - predictedColumnName = ctx.Reader.ReadString(); - learningRate = ctx.Reader.ReadFloat(); - classCount = ctx.Reader.ReadInt32(); - predictionTensorName = ctx.Reader.ReadString(); - softMaxTensorName = ctx.Reader.ReadString(); - } - internal DnnTransformer(IHostEnvironment env, Session session, string[] outputColumnNames, + internal DnnRetrainTransformer(IHostEnvironment env, Session session, string[] outputColumnNames, string[] inputColumnNames, string modelLocation, bool isTemporarySavedModel, - bool addBatchDimensionInput, int batchSize, bool transferLearning, string labelColumnName, string checkpointName, Architecture arch, - string scoreColumnName, string predictedLabelColumnName, float learningRate, DataViewSchema inputSchema, int? classCount = null, bool loadModel = false, - string predictionTensorName = null, string softMaxTensorName = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(DnnTransformer))) + bool addBatchDimensionInput, int batchSize) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(DnnRetrainTransformer))) { Host.CheckValue(session, nameof(session)); @@ -885,76 +531,15 @@ internal DnnTransformer(IHostEnvironment env, Session session, string[] outputCo _inputs = inputColumnNames; _outputs = outputColumnNames; _idvToTfMapping = new Dictionary(); - _transferLearning = transferLearning; - _labelColumnName = labelColumnName; - _checkpointName = checkpointName; - _arch = arch; - _scoreColumnName = scoreColumnName; - _predictedLabelColumnName = predictedLabelColumnName; - _learningRate = learningRate; - _softmaxTensorName = softMaxTensorName; - _predictionTensorName = predictionTensorName; - if (transferLearning) - { - if (classCount == null) - { - var labelColumn = inputSchema.GetColumnOrNull(labelColumnName).Value; - var labelType = labelColumn.Type; - var labelCount = labelType.GetKeyCount(); - if (labelCount <= 0) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", (string)labelColumn.Name, "Key", (string)labelType.ToString()); - - _classCount = labelCount == 1 ? 2 : (int)labelCount; - } - else - _classCount = classCount.Value; - _checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), modelLocation + checkpointName); + foreach (var x in _inputs) + _idvToTfMapping[x] = x; - // Configure bottleneck tensor based on the model. - if (arch == DnnEstimator.Architecture.ResnetV2101) - _bottleneckOperationName = "resnet_v2_101/SpatialSqueeze"; - else if(arch == DnnEstimator.Architecture.InceptionV3) - _bottleneckOperationName = "module_apply_default/hub_output/feature_vector/SpatialSqueeze"; + foreach (var x in _outputs) + _idvToTfMapping[x] = x; - if (arch == DnnEstimator.Architecture.ResnetV2101) - _idvToTfMapping[_inputs[0]] = "input"; - else if (arch == DnnEstimator.Architecture.InceptionV3) - _idvToTfMapping[_inputs[0]] = "Placeholder"; + (_tfOutputTypes, _outputTypes, _tfOutputOperations) = GetOutputInfo(Host, _session, _outputs); - _outputs = new[] { scoreColumnName, predictedLabelColumnName }; - - if (loadModel == false) - { - // Add transfer learning layer. - AddTransferLearningLayer(labelColumnName, scoreColumnName, learningRate, _classCount); - - // Initialize the variables. - new Runner(_session).AddOperation(tf.global_variables_initializer()).Run(); - - // Add evaluation layer. - (_evaluationStep, _) = AddEvaluationStep(_softMaxTensor, _labelTensor); - _softmaxTensorName = _softMaxTensor.name; - _predictionTensorName = _prediction.name; - } - - _idvToTfMapping[scoreColumnName] = _softmaxTensorName; - _idvToTfMapping[predictedLabelColumnName] = _predictionTensorName; - - (_tfOutputTypes, _outputTypes, _tfOutputOperations) = GetOutputInfo(Host, _session, new[] { _softmaxTensorName, _predictionTensorName }); - _transferLearning = true; - } - else - { - foreach (var x in _inputs) - _idvToTfMapping[x] = x; - - foreach (var x in _outputs) - _idvToTfMapping[x] = x; - - (_tfOutputTypes, _outputTypes, _tfOutputOperations) = GetOutputInfo(Host, _session, _outputs); - - } (_tfInputTypes, _tfInputShapes, _tfInputOperations) = GetInputInfo(Host, _session, _inputs.Select(x => _idvToTfMapping[x]).ToArray(), batchSize); _tfInputNodes = new TF_Output[_inputs.Length]; @@ -1093,7 +678,7 @@ private protected override void SaveModel(ModelSaveContext ctx) // for each output column // int: id of output column name // stream: tensorFlow model. - var isFrozen = _transferLearning || DnnUtils.IsSavedModel(_env, _modelLocation); + var isFrozen = DnnUtils.IsSavedModel(_env, _modelLocation); ctx.Writer.WriteBoolByte(isFrozen); ctx.Writer.WriteBoolByte(_addBatchDimensionInput); @@ -1107,58 +692,35 @@ private protected override void SaveModel(ModelSaveContext ctx) foreach (var colName in _outputs) ctx.SaveNonEmptyString(colName); - ctx.Writer.Write(_transferLearning); - ctx.Writer.Write(_labelColumnName); - ctx.Writer.Write(_checkpointName); - ctx.Writer.Write((int)_arch); - ctx.Writer.Write(_scoreColumnName); - ctx.Writer.Write(_predictedLabelColumnName); - ctx.Writer.Write(_learningRate); - ctx.Writer.Write(_classCount); - ctx.Writer.Write(_predictionTensorName); - ctx.Writer.Write(_softmaxTensorName); - - if (isFrozen || _transferLearning) - { - Status status = new Status(); - var buffer = _session.graph.ToGraphDef(status); - ctx.SaveBinaryStream("TFModel", w => - { - w.WriteByteArray(buffer.Data); - }); - } - else + ctx.SaveBinaryStream("TFSavedModel", w => { - ctx.SaveBinaryStream("TFSavedModel", w => + // only these files need to be saved. + string[] modelFilePaths = { - // only these files need to be saved. - string[] modelFilePaths = - { - Path.Combine(_modelLocation, DefaultModelFileNames.Graph), - Path.Combine(_modelLocation, DefaultModelFileNames.VariablesFolder, DefaultModelFileNames.Data), - Path.Combine(_modelLocation, DefaultModelFileNames.VariablesFolder, DefaultModelFileNames.Index), - }; + Path.Combine(_modelLocation, DefaultModelFileNames.Graph), + Path.Combine(_modelLocation, DefaultModelFileNames.VariablesFolder, DefaultModelFileNames.Data), + Path.Combine(_modelLocation, DefaultModelFileNames.VariablesFolder, DefaultModelFileNames.Index), + }; - w.Write(modelFilePaths.Length); + w.Write(modelFilePaths.Length); - foreach (var fullPath in modelFilePaths) - { - var relativePath = fullPath.Substring(_modelLocation.Length + 1); - w.Write(relativePath); + foreach (var fullPath in modelFilePaths) + { + var relativePath = fullPath.Substring(_modelLocation.Length + 1); + w.Write(relativePath); - using (var fs = new FileStream(fullPath, FileMode.Open)) - { - long fileLength = fs.Length; - w.Write(fileLength); - long actualWritten = fs.CopyRange(w.BaseStream, fileLength); - Host.Assert(actualWritten == fileLength); - } + using (var fs = new FileStream(fullPath, FileMode.Open)) + { + long fileLength = fs.Length; + w.Write(fileLength); + long actualWritten = fs.CopyRange(w.BaseStream, fileLength); + Host.Assert(actualWritten == fileLength); } - }); - } + } + }); } - ~DnnTransformer() + ~DnnRetrainTransformer() { Dispose(false); } @@ -1187,13 +749,13 @@ private void Dispose(bool disposing) private sealed class Mapper : MapperBase { - private readonly DnnTransformer _parent; + private readonly DnnRetrainTransformer _parent; private readonly int[] _inputColIndices; private readonly bool[] _isInputVector; private readonly TensorShape[] _fullySpecifiedShapes; private readonly ConcurrentBag _runners; - public Mapper(DnnTransformer parent, DataViewSchema inputSchema) : + public Mapper(DnnRetrainTransformer parent, DataViewSchema inputSchema) : base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent) { Host.CheckValue(parent, nameof(parent)); @@ -1612,28 +1174,11 @@ public Tensor GetBufferedBatchTensor() } } - /// - public sealed class DnnEstimator : IEstimator + /// + public sealed class DnnRetrainEstimator : IEstimator { /// - /// Image classification model. - /// - public enum Architecture - { - ResnetV2101, - InceptionV3 - }; - - /// - /// Backend DNN training framework. - /// - public enum DnnFramework - { - Tensorflow - }; - - /// - /// The options for the . + /// The options for the . /// internal sealed class Options : TransformInputBase { @@ -1729,12 +1274,6 @@ internal sealed class Options : TransformInputBase [Argument(ArgumentType.AtMostOnce, HelpText = "Name of the input in TensorFlow graph that specifiy the location for saving/restoring models from disk.", SortOrder = 14)] public string SaveOperation = "save/control_dependency"; - /// - /// Needed for command line to specify if retraining is requested. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Retrain TensorFlow model.", SortOrder = 15)] - public bool ReTrain = false; - /// /// Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3]. /// @@ -1744,42 +1283,6 @@ internal sealed class Options : TransformInputBase /// [Argument(ArgumentType.AtMostOnce, HelpText = "Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].", SortOrder = 16)] public bool AddBatchDimensionInputs = false; - - /// - /// Indicates if transfer learning is requested. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Transfer learning on a model.", SortOrder = 15)] - public bool TransferLearning = false; - - /// - /// Specifies the model architecture to be used in the case of image classification training using transfer learning. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Model architecture to be used in transfer learning for image classification.", SortOrder = 15)] - public Architecture Arch = Architecture.ResnetV2101; - - /// - /// Name of the tensor that will contain the output scores of the last layer when transfer learning is done. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Softmax tensor of the last layer in transfer learning.", SortOrder = 15)] - public string ScoreColumnName = "Scores"; - - /// - /// Name of the tensor that will contain the predicted label from output scores of the last layer when transfer learning is done. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Argmax tensor of the last layer in transfer learning.", SortOrder = 15)] - public string PredictedLabelColumnName = "PredictedLabel"; - - /// - /// Checkpoint folder to store graph files in the event of transfer learning. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Checkpoint folder to store graph files in the event of transfer learning.", SortOrder = 15)] - public string CheckpointName = "_retrain_checkpoint"; - - /// - /// Use train set to measure model accuracy between each epoch. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Use train set to measure model accuracy between each epoch.", SortOrder = 15)] - public bool MeasureTrainAccuracy = false; } private readonly IHost _host; @@ -1787,25 +1290,16 @@ internal sealed class Options : TransformInputBase private readonly DnnModel _tensorFlowModel; private readonly TF_DataType[] _tfInputTypes; private readonly DataViewType[] _outputTypes; - private DnnTransformer _transformer; + private DnnRetrainTransformer _transformer; - internal DnnEstimator(IHostEnvironment env, Options options, DnnModel tensorFlowModel) + internal DnnRetrainEstimator(IHostEnvironment env, Options options, DnnModel tensorFlowModel) { - _host = Contracts.CheckRef(env, nameof(env)).Register(nameof(DnnEstimator)); + _host = Contracts.CheckRef(env, nameof(env)).Register(nameof(DnnRetrainEstimator)); _options = options; _tensorFlowModel = tensorFlowModel; - - if (options.TransferLearning) - _tfInputTypes = new[] { TF_DataType.TF_FLOAT }; - else - { - var inputTuple = DnnTransformer.GetInputInfo(_host, tensorFlowModel.Session, options.InputColumns); - _tfInputTypes = inputTuple.tfInputTypes; - } - if (options.TransferLearning) - _outputTypes = new[] { new VectorDataViewType(NumberDataViewType.Single), new VectorDataViewType(NumberDataViewType.Single, 1) }; - else - _outputTypes = DnnTransformer.GetOutputInfo(_host, tensorFlowModel.Session, options.OutputColumns).outputTypes; + var inputTuple = DnnRetrainTransformer.GetInputInfo(_host, tensorFlowModel.Session, options.InputColumns); + _tfInputTypes = inputTuple.tfInputTypes; + _outputTypes = DnnRetrainTransformer.GetOutputInfo(_host, tensorFlowModel.Session, options.OutputColumns).outputTypes; } private static Options CreateArguments(DnnModel tensorFlowModel, string[] outputColumnNames, string[] inputColumnName, bool addBatchDimensionInput) @@ -1814,7 +1308,6 @@ private static Options CreateArguments(DnnModel tensorFlowModel, string[] output options.ModelLocation = tensorFlowModel.ModelPath; options.InputColumns = inputColumnName; options.OutputColumns = outputColumnNames; - options.ReTrain = false; options.AddBatchDimensionInputs = addBatchDimensionInput; return options; } @@ -1849,13 +1342,13 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) } /// - /// Trains and returns a . + /// Trains and returns a . /// - public DnnTransformer Fit(IDataView input) + public DnnRetrainTransformer Fit(IDataView input) { _host.CheckValue(input, nameof(input)); if (_transformer == null) - _transformer = new DnnTransformer(_host, _options, _tensorFlowModel, input); + _transformer = new DnnRetrainTransformer(_host, _options, _tensorFlowModel, input); // Validate input schema. _transformer.GetOutputSchema(input.Schema); diff --git a/src/Microsoft.ML.Dnn/DnnUtils.cs b/src/Microsoft.ML.Dnn/DnnUtils.cs index 3aa727f41d..623e103997 100644 --- a/src/Microsoft.ML.Dnn/DnnUtils.cs +++ b/src/Microsoft.ML.Dnn/DnnUtils.cs @@ -381,6 +381,25 @@ public Runner AddInput(string input, Tensor value) return this; } + public Runner AddInput(string input) + { + _inputs.Add(ParseOutput(input)); + return this; + } + + public Runner AddInput(Tensor value, int index) + { + if (_inputValues.Count <= index) + _inputValues.Add(value); + else + { + _inputValues[index].Dispose(); + _inputValues[index] = value; + } + + return this; + } + public Runner AddOutputs(string output) { _outputs.Add(ParseOutput(output)); @@ -444,10 +463,6 @@ public Tensor[] Run() return result; } - public Runner CloneRunner() - { - return new Runner(_session); - } } } diff --git a/src/Microsoft.ML.Dnn/ImageClassificationTransform.cs b/src/Microsoft.ML.Dnn/ImageClassificationTransform.cs new file mode 100644 index 0000000000..4729731972 --- /dev/null +++ b/src/Microsoft.ML.Dnn/ImageClassificationTransform.cs @@ -0,0 +1,1222 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using Google.Protobuf; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Runtime; +using Microsoft.ML.Transforms; +using Microsoft.ML.Transforms.Dnn; +using Tensorflow; +using Tensorflow.Summaries; +using static Microsoft.ML.Data.TextLoader; +using static Microsoft.ML.Transforms.Dnn.DnnUtils; +using static Microsoft.ML.Transforms.ImageClassificationEstimator; +using static Tensorflow.Python; +using Architecture = Microsoft.ML.Transforms.ImageClassificationEstimator.Architecture; + +[assembly: LoadableClass(ImageClassificationTransformer.Summary, typeof(IDataTransform), typeof(ImageClassificationTransformer), + typeof(ImageClassificationEstimator.Options), typeof(SignatureDataTransform), ImageClassificationTransformer.UserName, ImageClassificationTransformer.ShortName)] + +[assembly: LoadableClass(ImageClassificationTransformer.Summary, typeof(IDataTransform), typeof(ImageClassificationTransformer), null, typeof(SignatureLoadDataTransform), + ImageClassificationTransformer.UserName, ImageClassificationTransformer.LoaderSignature)] + +[assembly: LoadableClass(typeof(ImageClassificationTransformer), null, typeof(SignatureLoadModel), + ImageClassificationTransformer.UserName, ImageClassificationTransformer.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(ImageClassificationTransformer), null, typeof(SignatureLoadRowMapper), + ImageClassificationTransformer.UserName, ImageClassificationTransformer.LoaderSignature)] + +namespace Microsoft.ML.Transforms +{ + /// + /// for the . + /// + public sealed class ImageClassificationTransformer : RowToRowTransformerBase + { + private readonly IHostEnvironment _env; + private readonly bool _addBatchDimensionInput; + private Session _session; + private Tensor _bottleneckTensor; + private Operation _trainStep; + private Tensor _softMaxTensor; + private Tensor _crossEntropy; + private Tensor _labelTensor; + private Tensor _evaluationStep; + private Tensor _prediction; + private Tensor _bottleneckInput; + private Tensor _jpegData; + private Tensor _resizedImage; + private string _jpegDataTensorName; + private string _resizedImageTensorName; + private string _inputTensorName; + private readonly int _classCount; + private readonly string _checkpointPath; + private readonly string _bottleneckOperationName; + private Graph Graph => _session.graph; + private readonly string[] _inputs; + private readonly string[] _outputs; + private readonly string _labelColumnName; + private readonly string _finalModelPrefix; + private readonly Architecture _arch; + private readonly string _scoreColumnName; + private readonly string _predictedLabelColumnName; + private readonly float _learningRate; + private readonly string _softmaxTensorName; + private readonly string _predictionTensorName; + internal const string Summary = "Trains Dnn models."; + internal const string UserName = "ImageClassificationTransform"; + internal const string ShortName = "ImgClsTrans"; + internal const string LoaderSignature = "ImageClassificationTrans"; + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "IMGTRANS", + //verWrittenCur: 0x00010001, // Initial + verWrittenCur: 0x00000001, + verReadableCur: 0x00000001, + verWeCanReadBack: 0x00000001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(ImageClassificationTransformer).Assembly.FullName); + } + + // Factory method for SignatureLoadModel. + private static ImageClassificationTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + + // *** Binary format *** + // byte: indicator for frozen models + // byte: indicator for adding batch dimension in input + // int: number of input columns + // for each input column + // int: id of int column name + // int: number of output columns + // for each output column + // int: id of output column name + // stream: tensorFlow model. + + GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool addBatchDimensionInput, + out string labelColumn, out string checkpointName, out Architecture arch, out string scoreColumnName, + out string predictedColumnName, out float learningRate, out int classCount, out string predictionTensorName, out string softMaxTensorName, + out string jpegDataTensorName, out string resizeTensorName); + + byte[] modelBytes = null; + if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray())) + throw env.ExceptDecode(); + + return new ImageClassificationTransformer(env, DnnUtils.LoadTFSession(env, modelBytes), outputs, inputs, + null, addBatchDimensionInput, 1, labelColumn, checkpointName, arch, + scoreColumnName, predictedColumnName, learningRate, null, classCount, true, predictionTensorName, + softMaxTensorName, jpegDataTensorName, resizeTensorName); + + } + + // Factory method for SignatureDataTransform. + internal static IDataTransform Create(IHostEnvironment env, ImageClassificationEstimator.Options options, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(options, nameof(options)); + env.CheckValue(input, nameof(input)); + env.CheckValue(options.InputColumns, nameof(options.InputColumns)); + env.CheckValue(options.OutputColumns, nameof(options.OutputColumns)); + + return new ImageClassificationTransformer(env, options, input).MakeDataTransform(input); + } + + internal ImageClassificationTransformer(IHostEnvironment env, ImageClassificationEstimator.Options options, IDataView input) + : this(env, options, DnnUtils.LoadDnnModel(env, options.ModelLocation), input) + { + } + + internal ImageClassificationTransformer(IHostEnvironment env, ImageClassificationEstimator.Options options, DnnModel tensorFlowModel, IDataView input) + : this(env, tensorFlowModel.Session, options.OutputColumns, options.InputColumns, + options.ModelLocation, null, options.BatchSize, + options.LabelColumn, options.FinalModelPrefix, options.Arch, options.ScoreColumnName, + options.PredictedLabelColumnName, options.LearningRate, input.Schema) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(options, nameof(options)); + env.CheckValue(input, nameof(input)); + CheckTrainingParameters(options); + var imageProcessor = new ImageProcessor(this); + if (!options.ReuseTrainSetBottleneckCachedValues || !File.Exists(options.TrainSetBottleneckCachedValuesFilePath)) + CacheFeaturizedImagesToDisk(input, options.LabelColumn, options.InputColumns[0], imageProcessor, + _inputTensorName, _bottleneckTensor.name, options.TrainSetBottleneckCachedValuesFilePath, + ImageClassificationMetrics.Dataset.Train, options.MetricsCallback); + + if (options.ValidationSet != null && + (!options.ReuseTrainSetBottleneckCachedValues || !File.Exists(options.ValidationSetBottleneckCachedValuesFilePath))) + CacheFeaturizedImagesToDisk(options.ValidationSet, options.LabelColumn, options.InputColumns[0], + imageProcessor, _inputTensorName, _bottleneckTensor.name, options.ValidationSetBottleneckCachedValuesFilePath, + ImageClassificationMetrics.Dataset.Validation, options.MetricsCallback); + + TrainAndEvaluateClassificationLayer(options.TrainSetBottleneckCachedValuesFilePath, options, options.ValidationSetBottleneckCachedValuesFilePath); + } + + private void CheckTrainingParameters(ImageClassificationEstimator.Options options) + { + Host.CheckNonWhiteSpace(options.LabelColumn, nameof(options.LabelColumn)); + Host.CheckNonWhiteSpace(options.TensorFlowLabel, nameof(options.TensorFlowLabel)); + + if (_session.graph.OperationByName(options.TensorFlowLabel) == null) + throw Host.ExceptParam(nameof(options.TensorFlowLabel), $"'{options.TensorFlowLabel}' does not exist in the model"); + } + + private (Tensor, Tensor) AddJpegDecoding(int height, int width, int depth) + { + // height, width, depth + var inputDim = (height, width, depth); + var jpegData = tf.placeholder(tf.@string, name: "DecodeJPGInput"); + var decodedImage = tf.image.decode_jpeg(jpegData, channels: inputDim.Item3); + // Convert from full range of uint8 to range [0,1] of float32. + var decodedImageAsFloat = tf.image.convert_image_dtype(decodedImage, tf.float32); + var decodedImage4d = tf.expand_dims(decodedImageAsFloat, 0); + var resizeShape = tf.stack(new int[] { inputDim.Item1, inputDim.Item2 }); + var resizeShapeAsInt = tf.cast(resizeShape, dtype: tf.int32); + var resizedImage = tf.image.resize_bilinear(decodedImage4d, resizeShapeAsInt, false, "ResizeTensor"); + return (jpegData, resizedImage); + } + + private sealed class ImageProcessor + { + private Runner _imagePreprocessingRunner; + + public ImageProcessor(ImageClassificationTransformer transformer) + { + _imagePreprocessingRunner = new Runner(transformer._session); + _imagePreprocessingRunner.AddInput(transformer._jpegDataTensorName); + _imagePreprocessingRunner.AddOutputs(transformer._resizedImageTensorName); + } + + public Tensor ProcessImage(string path) + { + var imageTensor = new Tensor(File.ReadAllBytes(path), TF_DataType.TF_STRING); + var processedTensor = _imagePreprocessingRunner.AddInput(imageTensor, 0).Run()[0]; + imageTensor.Dispose(); + return processedTensor; + } + } + + private void CacheFeaturizedImagesToDisk(IDataView input, string labelColumnName, string imagepathColumnName, + ImageProcessor imageProcessor, string inputTensorName, string outputTensorName, string cacheFilePath, + ImageClassificationMetrics.Dataset dataset, ImageClassificationMetricsCallback metricsCallback) + { + var labelColumn = input.Schema[labelColumnName]; + + if (labelColumn.Type.RawType != typeof(UInt32)) + throw Host.ExceptSchemaMismatch(nameof(labelColumn), "Label", + labelColumnName, typeof(uint).ToString(), + labelColumn.Type.RawType.ToString()); + + var imagePathColumn = input.Schema[imagepathColumnName]; + Runner runner = new Runner(_session); + runner.AddOutputs(outputTensorName); + + using (TextWriter writer = File.CreateText(cacheFilePath)) + using (var cursor = input.GetRowCursor(input.Schema.Where(c => c.Index == labelColumn.Index || c.Index == imagePathColumn.Index))) + { + var labelGetter = cursor.GetGetter(labelColumn); + var imagePathGetter = cursor.GetGetter>(imagePathColumn); + UInt32 label = UInt32.MaxValue; + ReadOnlyMemory imagePath = default; + runner.AddInput(inputTensorName); + ImageClassificationMetrics metrics = new ImageClassificationMetrics(); + metrics.Bottleneck = new BottleneckMetrics(); + metrics.Bottleneck.DatasetUsed = dataset; + while (cursor.MoveNext()) + { + labelGetter(ref label); + imagePathGetter(ref imagePath); + var imagePathStr = imagePath.ToString(); + var imageTensor = imageProcessor.ProcessImage(imagePathStr); + runner.AddInput(imageTensor, 0); + var featurizedImage = runner.Run()[0]; // Reuse memory? + writer.WriteLine(label - 1 + "," + string.Join(",", featurizedImage.Data())); + featurizedImage.Dispose(); + imageTensor.Dispose(); + metrics.Bottleneck.Index++; + metrics.Bottleneck.Name = imagePathStr; + metricsCallback?.Invoke(metrics); + } + } + } + + private IDataView GetShuffledData(string path) + { + return new RowShufflingTransformer( + _env, + new RowShufflingTransformer.Options + { + ForceShuffle = true, + ForceShuffleSource = true + }, + new TextLoader( + _env, + new TextLoader.Options + { + Separators = new[] { ',' }, + Columns = new[] + { + new Column("Label", DataKind.Int64, 0), + new Column("Features", DataKind.Single, new [] { new Range(1, null) }), + }, + }, + new MultiFileSource(path)) + .Load(new MultiFileSource(path))); + } + + private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath, ImageClassificationEstimator.Options options, + string validationSetBottleneckFilePath) + { + int batchSize = options.BatchSize; + int epochs = options.Epoch; + bool evaluateOnly = !string.IsNullOrEmpty(validationSetBottleneckFilePath); + ImageClassificationMetricsCallback statisticsCallback = options.MetricsCallback; + var trainingSet = GetShuffledData(trainBottleneckFilePath); + IDataView validationSet = null; + if (options.ValidationSet != null && !string.IsNullOrEmpty(validationSetBottleneckFilePath)) + validationSet = GetShuffledData(validationSetBottleneckFilePath); + + long label = long.MaxValue; + VBuffer features = default; + ReadOnlySpan featureValues = default; + var featureColumn = trainingSet.Schema[1]; + int featureLength = featureColumn.Type.GetVectorSize(); + float[] featureBatch = new float[featureLength * batchSize]; + var featureBatchHandle = GCHandle.Alloc(featureBatch, GCHandleType.Pinned); + IntPtr featureBatchPtr = featureBatchHandle.AddrOfPinnedObject(); + int featureBatchSizeInBytes = sizeof(float) * featureBatch.Length; + long[] labelBatch = new long[batchSize]; + var labelBatchHandle = GCHandle.Alloc(labelBatch, GCHandleType.Pinned); + IntPtr labelBatchPtr = labelBatchHandle.AddrOfPinnedObject(); + int labelBatchSizeInBytes = sizeof(long) * labelBatch.Length; + var labelTensorShape = _labelTensor.TensorShape.dims.Select(x => (long)x).ToArray(); + labelTensorShape[0] = batchSize; + int batchIndex = 0; + var runner = new Runner(_session); + var testEvalRunner = new Runner(_session); + testEvalRunner.AddOutputs(_evaluationStep.name); + testEvalRunner.AddOutputs(_crossEntropy.name); + + Runner validationEvalRunner = null; + if (validationSet != null) + { + validationEvalRunner = new Runner(_session); + validationEvalRunner.AddOutputs(_evaluationStep.name); + validationEvalRunner.AddInput(_bottleneckInput.name).AddInput(_labelTensor.name); + } + + runner.AddOperation(_trainStep); + var featureTensorShape = _bottleneckInput.TensorShape.dims.Select(x => (long)x).ToArray(); + featureTensorShape[0] = batchSize; + + Saver trainSaver = null; + FileWriter trainWriter = null; + Tensor merged = tf.summary.merge_all(); + trainWriter = tf.summary.FileWriter(Path.Combine(Directory.GetCurrentDirectory(), "train"), _session.graph); + trainSaver = tf.train.Saver(); + trainSaver.save(_session, _checkpointPath); + + runner.AddInput(_bottleneckInput.name).AddInput(_labelTensor.name); + testEvalRunner.AddInput(_bottleneckInput.name).AddInput(_labelTensor.name); + Dictionary classStatsTrain = new Dictionary(); + Dictionary classStatsValidate = new Dictionary(); + for (int index = 0; index < _classCount; index += 1) + classStatsTrain[index] = classStatsValidate[index] = 0; + + ImageClassificationMetrics metrics = new ImageClassificationMetrics(); + metrics.Train = new TrainMetrics(); + for (int epoch = 0; epoch < epochs; epoch += 1) + { + metrics.Train.Accuracy = 0; + metrics.Train.CrossEntropy = 0; + metrics.Train.BatchProcessedCount = 0; + using (var cursor = trainingSet.GetRowCursor(trainingSet.Schema.ToArray(), new Random())) + { + var labelGetter = cursor.GetGetter(trainingSet.Schema[0]); + var featuresGetter = cursor.GetGetter>(featureColumn); + while (cursor.MoveNext()) + { + labelGetter(ref label); + featuresGetter(ref features); + classStatsTrain[label]++; + + if (featureValues == default) + featureValues = features.GetValues(); + + // Buffer the values. + for (int index = 0; index < featureLength; index += 1) + featureBatch[batchIndex * featureLength + index] = featureValues[index]; + + labelBatch[batchIndex] = label; + batchIndex += 1; + // Train. + if (batchIndex == batchSize) + { + runner.AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, featureBatchSizeInBytes), 0) + .AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, labelBatchSizeInBytes), 1) + .Run(); + + metrics.Train.BatchProcessedCount += 1; + + if (options.TestOnTrainSet && statisticsCallback != null) + { + var outputTensors = testEvalRunner + .AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, featureBatchSizeInBytes), 0) + .AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, labelBatchSizeInBytes), 1) + .Run(); + + metrics.Train.Accuracy += outputTensors[0].Data()[0]; + metrics.Train.CrossEntropy += outputTensors[1].Data()[0]; + + outputTensors[0].Dispose(); + outputTensors[1].Dispose(); + } + + batchIndex = 0; + } + } + + if (options.TestOnTrainSet && statisticsCallback != null) + { + metrics.Train.Epoch = epoch; + metrics.Train.Accuracy /= metrics.Train.BatchProcessedCount; + metrics.Train.CrossEntropy /= metrics.Train.BatchProcessedCount; + metrics.Train.DatasetUsed = ImageClassificationMetrics.Dataset.Train; + statisticsCallback(metrics); + } + } + + if (validationSet == null) + continue; + + batchIndex = 0; + metrics.Train.BatchProcessedCount = 0; + metrics.Train.Accuracy = 0; + metrics.Train.CrossEntropy = 0; + using (var cursor = validationSet.GetRowCursor(validationSet.Schema.ToArray(), new Random())) + { + var labelGetter = cursor.GetGetter(validationSet.Schema[0]); + var featuresGetter = cursor.GetGetter>(featureColumn); + while (cursor.MoveNext()) + { + labelGetter(ref label); + featuresGetter(ref features); + classStatsValidate[label]++; + // Buffer the values. + for (int index = 0; index < featureLength; index += 1) + featureBatch[batchIndex * featureLength + index] = featureValues[index]; + + labelBatch[batchIndex] = label; + batchIndex += 1; + // Evaluate. + if (batchIndex == batchSize) + { + var outputTensors = validationEvalRunner + .AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, featureBatchSizeInBytes), 0) + .AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, labelBatchSizeInBytes), 1) + .Run(); + + metrics.Train.Accuracy += outputTensors[0].Data()[0]; + metrics.Train.BatchProcessedCount += 1; + batchIndex = 0; + + outputTensors[0].Dispose(); + } + } + + if (statisticsCallback != null) + { + metrics.Train.Epoch = epoch; + metrics.Train.Accuracy /= metrics.Train.BatchProcessedCount; + metrics.Train.DatasetUsed = ImageClassificationMetrics.Dataset.Validation; + statisticsCallback(metrics); + } + } + } + + trainSaver.save(_session, _checkpointPath); + UpdateTransferLearningModelOnDisk(options, _classCount); + } + + private (Session, Tensor, Tensor, Tensor) BuildEvaluationSession(ImageClassificationEstimator.Options options, int classCount) + { + var evalGraph = DnnUtils.LoadMetaGraph(options.ModelLocation); + var evalSess = tf.Session(graph: evalGraph); + Tensor evaluationStep = null; + Tensor prediction = null; + Tensor bottleneckTensor = evalGraph.OperationByName(_bottleneckOperationName); + + tf_with(evalGraph.as_default(), graph => + { + var (_, _, groundTruthInput, finalTensor) = AddFinalRetrainOps(classCount, options.LabelColumn, + options.ScoreColumnName, options.LearningRate, bottleneckTensor, false); + + tf.train.Saver().restore(evalSess, _checkpointPath); + (evaluationStep, prediction) = AddEvaluationStep(finalTensor, groundTruthInput); + (_jpegData, _resizedImage) = AddJpegDecoding(299, 299, 3); + }); + + return (evalSess, _labelTensor, evaluationStep, prediction); + } + + private (Tensor, Tensor) AddEvaluationStep(Tensor resultTensor, Tensor groundTruthTensor) + { + Tensor evaluationStep = null; + Tensor correctPrediction = null; + + tf_with(tf.name_scope("accuracy"), scope => + { + tf_with(tf.name_scope("correct_prediction"), delegate + { + _prediction = tf.argmax(resultTensor, 1); + correctPrediction = tf.equal(_prediction, groundTruthTensor); + }); + + tf_with(tf.name_scope("accuracy"), delegate + { + evaluationStep = tf.reduce_mean(tf.cast(correctPrediction, tf.float32)); + }); + }); + + tf.summary.scalar("accuracy", evaluationStep); + return (evaluationStep, _prediction); + } + + private void UpdateTransferLearningModelOnDisk(ImageClassificationEstimator.Options options, int classCount) + { + var (sess, _, _, _) = BuildEvaluationSession(options, classCount); + var graph = sess.graph; + var outputGraphDef = tf.graph_util.convert_variables_to_constants( + sess, graph.as_graph_def(), new string[] { _softMaxTensor.name.Split(':')[0], _prediction.name.Split(':')[0], _jpegData.name.Split(':')[0], _resizedImage.name.Split(':')[0] }); + + string frozenModelPath = _checkpointPath + ".pb"; + File.WriteAllBytes(_checkpointPath + ".pb", outputGraphDef.ToByteArray()); + _session.graph.Dispose(); + _session.Dispose(); + _session = LoadTFSessionByModelFilePath(_env, frozenModelPath, false); + } + + private void VariableSummaries(RefVariable var) + { + tf_with(tf.name_scope("summaries"), delegate + { + var mean = tf.reduce_mean(var); + tf.summary.scalar("mean", mean); + Tensor stddev = null; + tf_with(tf.name_scope("stddev"), delegate + { + stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))); + }); + tf.summary.scalar("stddev", stddev); + tf.summary.scalar("max", tf.reduce_max(var)); + tf.summary.scalar("min", tf.reduce_min(var)); + tf.summary.histogram("histogram", var); + }); + } + + private (Operation, Tensor, Tensor, Tensor) AddFinalRetrainOps(int classCount, string labelColumn, + string scoreColumnName, float learningRate, Tensor bottleneckTensor, bool isTraining) + { + var (batch_size, bottleneck_tensor_size) = (bottleneckTensor.TensorShape.Dimensions[0], bottleneckTensor.TensorShape.Dimensions[1]); + tf_with(tf.name_scope("input"), scope => + { + if (isTraining) + { + _bottleneckInput = tf.placeholder_with_default( + bottleneckTensor, + shape: bottleneckTensor.TensorShape.Dimensions, + name: "BottleneckInputPlaceholder"); + } + + _labelTensor = tf.placeholder(tf.int64, new TensorShape(batch_size), name: labelColumn); + }); + + string layerName = "final_retrain_ops"; + Tensor logits = null; + tf_with(tf.name_scope(layerName), scope => + { + RefVariable layerWeights = null; + tf_with(tf.name_scope("weights"), delegate + { + var initialValue = tf.truncated_normal(new int[] { bottleneck_tensor_size, classCount }, stddev: 0.001f); + layerWeights = tf.Variable(initialValue, name: "final_weights"); + VariableSummaries(layerWeights); + }); + + RefVariable layerBiases = null; + tf_with(tf.name_scope("biases"), delegate + { + layerBiases = tf.Variable(tf.zeros(classCount), name: "final_biases"); + VariableSummaries(layerBiases); + }); + + tf_with(tf.name_scope("Wx_plus_b"), delegate + { + var matmul = tf.matmul(isTraining ? _bottleneckInput : bottleneckTensor, layerWeights); + logits = matmul + layerBiases; + tf.summary.histogram("pre_activations", logits); + }); + }); + + _softMaxTensor = tf.nn.softmax(logits, name: scoreColumnName); + + tf.summary.histogram("activations", _softMaxTensor); + if (!isTraining) + return (null, null, _labelTensor, _softMaxTensor); + + Tensor crossEntropyMean = null; + tf_with(tf.name_scope("cross_entropy"), delegate + { + crossEntropyMean = tf.losses.sparse_softmax_cross_entropy( + labels: _labelTensor, logits: logits); + }); + + tf.summary.scalar("cross_entropy", crossEntropyMean); + + tf_with(tf.name_scope("train"), delegate + { + var optimizer = tf.train.GradientDescentOptimizer(learningRate); + _trainStep = optimizer.minimize(crossEntropyMean); + }); + + return (_trainStep, crossEntropyMean, _labelTensor, _softMaxTensor); + } + + private void AddTransferLearningLayer(string labelColumn, + string scoreColumnName, float learningRate, int classCount) + { + _bottleneckTensor = Graph.OperationByName(_bottleneckOperationName); + tf_with(Graph.as_default(), delegate + { + (_trainStep, _crossEntropy, _labelTensor, _softMaxTensor) = + AddFinalRetrainOps(classCount, labelColumn, scoreColumnName, learningRate, _bottleneckTensor, true); + }); + } + + // Factory method for SignatureLoadDataTransform. + private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => Create(env, ctx).MakeDataTransform(input); + + // Factory method for SignatureLoadRowMapper. + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + + private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out string[] inputs, + out string[] outputs, out bool addBatchDimensionInput, + out string labelColumn, out string checkpointName, out Architecture arch, + out string scoreColumnName, out string predictedColumnName, out float learningRate, out int classCount, out string predictionTensorName, out string softMaxTensorName, + out string jpegDataTensorName, out string resizeTensorName) + { + addBatchDimensionInput = ctx.Reader.ReadBoolByte(); + + var numInputs = ctx.Reader.ReadInt32(); + env.CheckDecode(numInputs > 0); + inputs = new string[numInputs]; + for (int j = 0; j < inputs.Length; j++) + inputs[j] = ctx.LoadNonEmptyString(); + + var numOutputs = ctx.Reader.ReadInt32(); + env.CheckDecode(numOutputs > 0); + outputs = new string[numOutputs]; + for (int j = 0; j < outputs.Length; j++) + outputs[j] = ctx.LoadNonEmptyString(); + + labelColumn = ctx.Reader.ReadString(); + checkpointName = ctx.Reader.ReadString(); + arch = (Architecture)ctx.Reader.ReadInt32(); + scoreColumnName = ctx.Reader.ReadString(); + predictedColumnName = ctx.Reader.ReadString(); + learningRate = ctx.Reader.ReadFloat(); + classCount = ctx.Reader.ReadInt32(); + predictionTensorName = ctx.Reader.ReadString(); + softMaxTensorName = ctx.Reader.ReadString(); + jpegDataTensorName = ctx.Reader.ReadString(); + resizeTensorName = ctx.Reader.ReadString(); + } + + internal ImageClassificationTransformer(IHostEnvironment env, Session session, string[] outputColumnNames, + string[] inputColumnNames, string modelLocation, + bool? addBatchDimensionInput, int batchSize, string labelColumnName, string finalModelPrefix, Architecture arch, + string scoreColumnName, string predictedLabelColumnName, float learningRate, DataViewSchema inputSchema, int? classCount = null, bool loadModel = false, + string predictionTensorName = null, string softMaxTensorName = null, string jpegDataTensorName = null, string resizeTensorName = null) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageClassificationTransformer))) + + { + Host.CheckValue(session, nameof(session)); + Host.CheckNonEmpty(inputColumnNames, nameof(inputColumnNames)); + Host.CheckNonEmpty(outputColumnNames, nameof(outputColumnNames)); + + _env = env; + _session = session; + _addBatchDimensionInput = addBatchDimensionInput ?? arch == Architecture.ResnetV2101; + _inputs = inputColumnNames; + _outputs = outputColumnNames; + _labelColumnName = labelColumnName; + _finalModelPrefix = finalModelPrefix; + _arch = arch; + _scoreColumnName = scoreColumnName; + _predictedLabelColumnName = predictedLabelColumnName; + _learningRate = learningRate; + _softmaxTensorName = softMaxTensorName; + _predictionTensorName = predictionTensorName; + _jpegDataTensorName = jpegDataTensorName; + _resizedImageTensorName = resizeTensorName; + + if (classCount == null) + { + var labelColumn = inputSchema.GetColumnOrNull(labelColumnName).Value; + var labelType = labelColumn.Type; + var labelCount = labelType.GetKeyCount(); + if (labelCount <= 0) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", (string)labelColumn.Name, "Key", (string)labelType.ToString()); + + _classCount = labelCount == 1 ? 2 : (int)labelCount; + } + else + _classCount = classCount.Value; + + _checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), finalModelPrefix + modelLocation); + + // Configure bottleneck tensor based on the model. + if (arch == ImageClassificationEstimator.Architecture.ResnetV2101) + { + _bottleneckOperationName = "resnet_v2_101/SpatialSqueeze"; + _inputTensorName = "input"; + } + else if (arch == ImageClassificationEstimator.Architecture.InceptionV3) + { + _bottleneckOperationName = "module_apply_default/hub_output/feature_vector/SpatialSqueeze"; + _inputTensorName = "Placeholder"; + } + + _outputs = new[] { scoreColumnName, predictedLabelColumnName }; + + if (loadModel == false) + { + (_jpegData, _resizedImage) = AddJpegDecoding(299, 299, 3); + _jpegDataTensorName = _jpegData.name; + _resizedImageTensorName = _resizedImage.name; + + // Add transfer learning layer. + AddTransferLearningLayer(labelColumnName, scoreColumnName, learningRate, _classCount); + + // Initialize the variables. + new Runner(_session).AddOperation(tf.global_variables_initializer()).Run(); + + // Add evaluation layer. + (_evaluationStep, _) = AddEvaluationStep(_softMaxTensor, _labelTensor); + _softmaxTensorName = _softMaxTensor.name; + _predictionTensorName = _prediction.name; + } + } + + private protected override IRowMapper MakeRowMapper(DataViewSchema inputSchema) => new Mapper(this, inputSchema); + + private protected override void SaveModel(ModelSaveContext ctx) + { + Host.AssertValue(ctx); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + ctx.Writer.WriteBoolByte(_addBatchDimensionInput); + + Host.AssertNonEmpty(_inputs); + ctx.Writer.Write(_inputs.Length); + foreach (var colName in _inputs) + ctx.SaveNonEmptyString(colName); + + Host.AssertNonEmpty(_outputs); + ctx.Writer.Write(_outputs.Length); + foreach (var colName in _outputs) + ctx.SaveNonEmptyString(colName); + + ctx.Writer.Write(_labelColumnName); + ctx.Writer.Write(_finalModelPrefix); + ctx.Writer.Write((int)_arch); + ctx.Writer.Write(_scoreColumnName); + ctx.Writer.Write(_predictedLabelColumnName); + ctx.Writer.Write(_learningRate); + ctx.Writer.Write(_classCount); + ctx.Writer.Write(_predictionTensorName); + ctx.Writer.Write(_softmaxTensorName); + ctx.Writer.Write(_jpegDataTensorName); + ctx.Writer.Write(_resizedImageTensorName); + Status status = new Status(); + var buffer = _session.graph.ToGraphDef(status); + ctx.SaveBinaryStream("TFModel", w => + { + w.WriteByteArray(buffer.Data); + }); + status.Check(true); + } + + ~ImageClassificationTransformer() + { + Dispose(false); + } + + private void Dispose(bool disposing) + { + // Ensure that the Session is not null and it's handle is not Zero, as it may have already been disposed/finalized. + // Technically we shouldn't be calling this if disposing == false, since we're running in finalizer + // and the GC doesn't guarantee ordering of finalization of managed objects, but we have to make sure + // that the Session is closed before deleting our temporary directory. + if (_session != null && _session != IntPtr.Zero) + { + _session.close(); + } + } + + private sealed class Mapper : MapperBase + { + private readonly ImageClassificationTransformer _parent; + private readonly int[] _inputColIndices; + + public Mapper(ImageClassificationTransformer parent, DataViewSchema inputSchema) : + base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent) + { + Host.CheckValue(parent, nameof(parent)); + _parent = parent; + _inputColIndices = new int[1]; + if (!inputSchema.TryGetColumnIndex(_parent._inputs[0], out _inputColIndices[0])) + throw Host.ExceptSchemaMismatch(nameof(InputSchema), "source", _parent._inputs[0]); + } + + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); + + private class OutputCache + { + public long Position; + private ValueGetter> _imagePathGetter; + private ReadOnlyMemory _imagePath; + private Runner _runner; + private ImageProcessor _imageProcessor; + public UInt32 PredictedLabel { get; set; } + public float[] ClassProbabilities { get; set; } + private DataViewRow _inputRow; + + public OutputCache(DataViewRow input, ImageClassificationTransformer transformer) + { + _imagePath = default; + _imagePathGetter = input.GetGetter>(input.Schema[transformer._inputs[0]]); + _runner = new Runner(transformer._session); + _runner.AddInput(transformer._inputTensorName); + _runner.AddOutputs(transformer._softmaxTensorName); + _runner.AddOutputs(transformer._predictionTensorName); + _imageProcessor = new ImageProcessor(transformer); + _inputRow = input; + Position = -1; + } + + public void UpdateCacheIfNeeded() + { + lock (this) + { + if (_inputRow.Position != Position) + { + Position = _inputRow.Position; + _imagePathGetter(ref _imagePath); + var processedTensor = _imageProcessor.ProcessImage(_imagePath.ToString()); + var outputTensor = _runner.AddInput(processedTensor, 0).Run(); + ClassProbabilities = outputTensor[0].Data(); + PredictedLabel = (UInt32)outputTensor[1].Data()[0]; + outputTensor[0].Dispose(); + outputTensor[1].Dispose(); + processedTensor.Dispose(); + } + } + } + } + + protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func activeOutput, out Action disposer) + { + disposer = null; + Host.AssertValue(input); + var cache = new OutputCache(input, _parent); + + if (iinfo == 0) + { + ValueGetter> valuegetter = (ref VBuffer dst) => + { + cache.UpdateCacheIfNeeded(); + var editor = VBufferEditor.Create(ref dst, cache.ClassProbabilities.Length); + new Span(cache.ClassProbabilities, 0, cache.ClassProbabilities.Length).CopyTo(editor.Values); + dst = editor.Commit(); + }; + return valuegetter; + } + else + { + ValueGetter valuegetter = (ref UInt32 dst) => + { + cache.UpdateCacheIfNeeded(); + dst = cache.PredictedLabel; + }; + + return valuegetter; + } + } + + private protected override Func GetDependenciesCore(Func activeOutput) + { + return col => Enumerable.Range(0, _parent._outputs.Length).Any(i => activeOutput(i)) && _inputColIndices.Any(i => i == col); + } + + protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() + { + var info = new DataViewSchema.DetachedColumn[_parent._outputs.Length]; + info[0] = new DataViewSchema.DetachedColumn(_parent._outputs[0], new VectorDataViewType(NumberDataViewType.Single, _parent._classCount), null); + info[1] = new DataViewSchema.DetachedColumn(_parent._outputs[1], NumberDataViewType.UInt32, null); + return info; + } + } + } + + /// + public sealed class ImageClassificationEstimator : IEstimator + { + /// + /// Image classification model. + /// + public enum Architecture + { + ResnetV2101, + InceptionV3 + }; + + /// + /// Backend DNN training framework. + /// + public enum DnnFramework + { + Tensorflow + }; + + /// + /// Callback that returns DNN statistics during training phase. + /// + public delegate void ImageClassificationMetricsCallback(ImageClassificationMetrics metrics); + + /// + /// DNN training metrics. + /// + public sealed class TrainMetrics + { + /// + /// Indicates the dataset on which metrics are being reported. + /// + /// + public ImageClassificationMetrics.Dataset DatasetUsed { get; set; } + + /// + /// The number of batches processed in an epoch. + /// + public int BatchProcessedCount { get; set; } + + /// + /// The training epoch index for which this metric is reported. + /// + public int Epoch { get; set; } + + /// + /// Accuracy of the batch on this . Higher the better. + /// + public float Accuracy { get; set; } + + /// + /// Cross-Entropy (loss) of the batch on this . Lower + /// the better. + /// + public float CrossEntropy { get; set; } + + /// + /// String representation of the metrics. + /// + public override string ToString() + { + if (DatasetUsed == ImageClassificationMetrics.Dataset.Train) + return $"Phase: Training, Dataset used: {DatasetUsed.ToString(),10}, Batch Processed Count: {BatchProcessedCount,3}, " + + $"Epoch: {Epoch,3}, Accuracy: {Accuracy,10}, Cross-Entropy: {CrossEntropy,10}"; + else + return $"Phase: Training, Dataset used: {DatasetUsed.ToString(),10}, Batch Processed Count: {BatchProcessedCount,3}, " + + $"Epoch: {Epoch,3}, Accuracy: {Accuracy,10}"; + } + } + + /// + /// Metrics for image featurization values. The input image is passed through + /// the network and features are extracted from second or last layer to + /// train a custom full connected layer that serves as classifier. + /// + public sealed class BottleneckMetrics + { + /// + /// Indicates the dataset on which metrics are being reported. + /// + /// + public ImageClassificationMetrics.Dataset DatasetUsed { get; set; } + + /// + /// Name of the input image. + /// + public string Name { get; set; } + + /// + /// Index of the input image. + /// + public int Index { get; set; } + + /// + /// String representation of the metrics. + /// + public override string ToString() => $"Phase: Bottleneck Computation, Dataset used: {DatasetUsed.ToString(),10}, Image Index: {Index,3}, Image Name: {Name}"; + } + + /// + /// Metrics for image classification training. + /// + public sealed class ImageClassificationMetrics + { + /// + /// Indicates the kind of the dataset of which metric is reported. + /// + public enum Dataset + { + Train, + Validation + }; + + /// + /// Contains train time metrics. + /// + public TrainMetrics Train { get; set; } + + /// + /// Contains pre-train time metrics. These contains metrics on image + /// featurization. + /// + public BottleneckMetrics Bottleneck { get; set; } + + /// + /// String representation of the metrics. + /// + public override string ToString() => Train != null ? Train.ToString() : Bottleneck.ToString(); + } + + /// + /// The options for the . + /// + internal sealed class Options : TransformInputBase + { + /// + /// Location of the TensorFlow model. + /// + [Argument(ArgumentType.Required, HelpText = "TensorFlow model used by the transform. Please see https://www.tensorflow.org/mobile/prepare_models for more details.", SortOrder = 0)] + public string ModelLocation; + + /// + /// The names of the model inputs. + /// + [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The names of the model inputs", ShortName = "inputs", SortOrder = 1)] + public string[] InputColumns; + + /// + /// The names of the requested model outputs. + /// + [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The name of the outputs", ShortName = "outputs", SortOrder = 2)] + public string[] OutputColumns; + + /// + /// The name of the label column in that will be mapped to label node in TensorFlow model. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Training labels.", ShortName = "label", SortOrder = 4)] + public string LabelColumn; + + /// + /// The name of the label in TensorFlow model. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "TensorFlow label node.", ShortName = "TFLabel", SortOrder = 5)] + public string TensorFlowLabel; + + /// + /// Number of samples to use for mini-batch training. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Number of samples to use for mini-batch training.", SortOrder = 9)] + public int BatchSize = 64; + + /// + /// Number of training iterations. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Number of training iterations.", SortOrder = 10)] + public int Epoch = 5; + + /// + /// Learning rate to use during optimization. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate to use during optimization.", SortOrder = 12)] + public float LearningRate = 0.01f; + + /// + /// Specifies the model architecture to be used in the case of image classification training using transfer learning. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Model architecture to be used in transfer learning for image classification.", SortOrder = 15)] + public Architecture Arch = Architecture.InceptionV3; + + /// + /// Name of the tensor that will contain the output scores of the last layer when transfer learning is done. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Softmax tensor of the last layer in transfer learning.", SortOrder = 15)] + public string ScoreColumnName = "Scores"; + + /// + /// Name of the tensor that will contain the predicted label from output scores of the last layer when transfer learning is done. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Argmax tensor of the last layer in transfer learning.", SortOrder = 15)] + public string PredictedLabelColumnName = "PredictedLabel"; + + /// + /// Final model and checkpoint files/folder prefix for storing graph files. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Final model and checkpoint files/folder prefix for storing graph files.", SortOrder = 15)] + public string FinalModelPrefix = "custom_retrained_model_based_on_"; + + /// + /// Callback to report statistics on accuracy/cross entropy during training phase. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Callback to report metrics during training and validation phase.", SortOrder = 15)] + public ImageClassificationMetricsCallback MetricsCallback = null; + + /// + /// Frequency of epochs at which statistics on training phase should be reported. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Frequency of epochs at which statistics on training/validation phase should be reported.", SortOrder = 15)] + public int StatisticsFrequency = 1; + + /// + /// Indicates the choice DNN training framework. Currently only TensorFlow is supported. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates the choice DNN training framework. Currently only TensorFlow is supported.", SortOrder = 15)] + public DnnFramework Framework = DnnFramework.Tensorflow; + + /// + /// Indicates the path where the newly retrained model should be saved. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates the path where the newly retrained model should be saved.", SortOrder = 15)] + public string ModelSavePath = null; + + /// + /// Indicates to evaluate the model on train set after every epoch. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates to evaluate the model on train set after every epoch.", SortOrder = 15)] + public bool TestOnTrainSet; + + /// + /// Indicates to not re-compute cached bottleneck trainset values if already available in the bin folder. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates to not re-compute trained cached bottleneck values if already available in the bin folder.", SortOrder = 15)] + public bool ReuseTrainSetBottleneckCachedValues; + + /// + /// Indicates to not re-compute cached bottleneck validationset values if already available in the bin folder. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates to not re-compute validataionset cached bottleneck validationset values if already available in the bin folder.", SortOrder = 15)] + public bool ReuseValidationSetBottleneckCachedValues; + + /// + /// Validation set. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Validation set.", SortOrder = 15)] + public IDataView ValidationSet; + + /// + /// Indicates the file path to store trainset bottleneck values for caching. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates the file path to store trainset bottleneck values for caching.", SortOrder = 15)] + public string TrainSetBottleneckCachedValuesFilePath; + + /// + /// Indicates the file path to store validationset bottleneck values for caching. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates the file path to store validationset bottleneck values for caching.", SortOrder = 15)] + public string ValidationSetBottleneckCachedValuesFilePath; + } + + private readonly IHost _host; + private readonly Options _options; + private readonly DnnModel _dnnModel; + private readonly TF_DataType[] _tfInputTypes; + private readonly DataViewType[] _outputTypes; + private ImageClassificationTransformer _transformer; + + internal ImageClassificationEstimator(IHostEnvironment env, Options options, DnnModel dnnModel) + { + _host = Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageClassificationEstimator)); + _options = options; + _dnnModel = dnnModel; + _tfInputTypes = new[] { TF_DataType.TF_STRING }; + _outputTypes = new[] { new VectorDataViewType(NumberDataViewType.Single), NumberDataViewType.UInt32.GetItemType() }; + } + + private static Options CreateArguments(DnnModel tensorFlowModel, string[] outputColumnNames, string[] inputColumnName, bool addBatchDimensionInput) + { + var options = new Options(); + options.ModelLocation = tensorFlowModel.ModelPath; + options.InputColumns = inputColumnName; + options.OutputColumns = outputColumnNames; + return options; + } + + /// + /// Returns the of the schema which will be produced by the transformer. + /// Used for schema propagation and verification in a pipeline. + /// + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + _host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.ToDictionary(x => x.Name); + var resultDic = inputSchema.ToDictionary(x => x.Name); + for (var i = 0; i < _options.InputColumns.Length; i++) + { + var input = _options.InputColumns[i]; + if (!inputSchema.TryFindColumn(input, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input); + var expectedType = DnnUtils.Tf2MlNetType(_tfInputTypes[i]); + if (col.ItemType != expectedType) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString()); + } + for (var i = 0; i < _options.OutputColumns.Length; i++) + { + resultDic[_options.OutputColumns[i]] = new SchemaShape.Column(_options.OutputColumns[i], + _outputTypes[i].IsKnownSizeVector() ? SchemaShape.Column.VectorKind.Vector + : SchemaShape.Column.VectorKind.VariableVector, _outputTypes[i].GetItemType(), false); + } + return new SchemaShape(resultDic.Values); + } + + /// + /// Trains and returns a . + /// + public ImageClassificationTransformer Fit(IDataView input) + { + _host.CheckValue(input, nameof(input)); + if (_transformer == null) + _transformer = new ImageClassificationTransformer(_host, _options, _dnnModel, input); + + // Validate input schema. + _transformer.GetOutputSchema(input.Schema); + return _transformer; + } + } +} diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index 7a8d7063dd..e1b8ef452e 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -639,7 +639,10 @@ private void UpdateCacheIfNeeded(long position, ITensorValueGetter[] srcTensorGe // Feed inputs to the graph. for (int i = 0; i < _parent.Inputs.Length; i++) - runner.AddInput(_parent.Inputs[i], srcTensorGetters[i].GetTensor()); + { + var tensor = srcTensorGetters[i].GetTensor(); + runner.AddInput(_parent.Inputs[i], tensor); + } // Add outputs. for (int i = 0; i < _parent.Outputs.Length; i++) @@ -651,8 +654,12 @@ private void UpdateCacheIfNeeded(long position, ITensorValueGetter[] srcTensorGe Contracts.Assert(tensors.Length > 0); for (int j = 0; j < activeOutputColNames.Length; j++) - outputCache.Outputs[activeOutputColNames[j]] = tensors[j]; + { + if (outputCache.Outputs.TryGetValue(activeOutputColNames[j], out Tensor outTensor)) + outTensor.Dispose(); + outputCache.Outputs[activeOutputColNames[j]] = tensors[j]; + } outputCache.Position = position; } } @@ -704,7 +711,6 @@ private class TensorValueGetter : ITensorValueGetter private readonly T[] _bufferedData; private readonly TensorShape _tfShape; private int _position; - private readonly List _tensors; public TensorValueGetter(DataViewRow input, int colIndex, TensorShape tfShape) { @@ -719,7 +725,6 @@ public TensorValueGetter(DataViewRow input, int colIndex, TensorShape tfShape) size *= dim; } _bufferedData = new T[size]; - _tensors = new List(); } public Tensor GetTensor() @@ -728,7 +733,6 @@ public Tensor GetTensor() _srcgetter(ref scalar); var tensor = new Tensor(new[] { scalar }); tensor.SetShape(_tfShape); - _tensors.Add(tensor); return tensor; } @@ -743,7 +747,6 @@ public Tensor GetBufferedBatchTensor() { var tensor = new Tensor(new NDArray(_bufferedData, _tfShape)); _position = 0; - _tensors.Add(tensor); return tensor; } } @@ -757,7 +760,6 @@ private class TensorValueGetterVec : ITensorValueGetter private T[] _bufferedData; private int _position; private long[] _dims; - private readonly List _tensors; private readonly long _bufferedDataSize; public TensorValueGetterVec(DataViewRow input, int colIndex, TensorShape tfShape) @@ -778,7 +780,6 @@ public TensorValueGetterVec(DataViewRow input, int colIndex, TensorShape tfShape _bufferedData = new T[size]; if (_tfShape.Dimensions != null) _dims = _tfShape.Dimensions.Select(x => (long)x).ToArray(); - _tensors = new List(); _bufferedDataSize = size; } @@ -792,7 +793,6 @@ public Tensor GetTensor() _denseData = new T[_vBuffer.Length]; _vBuffer.CopyTo(_denseData); var tensor = CastDataAndReturnAsTensor(_denseData); - _tensors.Add(tensor); return tensor; } @@ -845,7 +845,6 @@ public Tensor GetBufferedBatchTensor() { _position = 0; var tensor = CastDataAndReturnAsTensor(_bufferedData); - _tensors.Add(tensor); _bufferedData = new T[_bufferedDataSize]; return tensor; diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index d5bf30ab96..e9727d13a7 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -3760,21 +3760,6 @@ public void EntryPointWordEmbeddings() } } - [TensorFlowFact] - public void EntryPointTensorFlowTransform() - { - Env.ComponentCatalog.RegisterAssembly(typeof(TensorFlowTransformer).Assembly); - - TestEntryPointPipelineRoutine(GetDataPath("Train-Tiny-28x28.txt"), "col=Label:R4:0 col=Placeholder:R4:1-784", - new[] { "Transforms.TensorFlowScorer" }, - new[] - { - @"'InputColumns': [ 'Placeholder' ], - 'ModelLocation': 'mnist_model/frozen_saved_model.pb', - 'OutputColumns': [ 'Softmax' ]" - }); - } - [Fact(Skip = "Needs real time series dataset. https://github.com/dotnet/machinelearning/issues/1120")] public void EntryPointSsaChangePoint() { @@ -5637,6 +5622,21 @@ public void TestOvaMacroWithUncalibratedLearner() } } + [TensorFlowFact] + public void EntryPointTensorFlowTransform() + { + Env.ComponentCatalog.RegisterAssembly(typeof(TensorFlowTransformer).Assembly); + + TestEntryPointPipelineRoutine(GetDataPath("Train-Tiny-28x28.txt"), "col=Label:R4:0 col=Placeholder:R4:1-784", + new[] { "Transforms.TensorFlowScorer" }, + new[] + { + @"'InputColumns': [ 'Placeholder' ], + 'ModelLocation': 'mnist_model/frozen_saved_model.pb', + 'OutputColumns': [ 'Softmax' ]" + }); + } + [TensorFlowFact] public void TestTensorFlowEntryPoint() { diff --git a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs deleted file mode 100644 index ef62dcae77..0000000000 --- a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs +++ /dev/null @@ -1,99 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System.IO; -using Microsoft.ML.Data; -using Microsoft.ML.Transforms.Image; -using Microsoft.ML.TestFramework.Attributes; -using Microsoft.ML.Transforms; -using Xunit; - -namespace Microsoft.ML.Scenarios -{ - public partial class ScenariosTests - { - [TensorFlowFact] - public void TensorFlowTransforCifarEndToEndTest() - { - var imageHeight = 32; - var imageWidth = 32; - var model_location = "cifar_model/frozen_model.pb"; - var dataFile = GetDataPath("images/images.tsv"); - var imageFolder = Path.GetDirectoryName(dataFile); - - var mlContext = new MLContext(seed: 1); - var data = TextLoader.Create(mlContext, new TextLoader.Options() - { - Columns = new[] - { - new TextLoader.Column("ImagePath", DataKind.String, 0), - new TextLoader.Column("Label", DataKind.String, 1), - } - }, new MultiFileSource(dataFile)); - - var pipeEstimator = new ImageLoadingEstimator(mlContext, imageFolder, ("ImageReal", "ImagePath")) - .Append(new ImageResizingEstimator(mlContext, "ImageCropped", imageHeight, imageWidth, "ImageReal")) - .Append(new ImagePixelExtractingEstimator(mlContext, "Input", "ImageCropped", interleavePixelColors: true)) - .Append(mlContext.Model.LoadTensorFlowModel(model_location).ScoreTensorFlowModel("Output", "Input")) - .Append(new ColumnConcatenatingEstimator(mlContext, "Features", "Output")) - .Append(new ValueToKeyMappingEstimator(mlContext, "Label")) - .AppendCacheCheckpoint(mlContext) - .Append(mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy()); - - - var transformer = pipeEstimator.Fit(data); - var predictions = transformer.Transform(data); - - var metrics = mlContext.MulticlassClassification.Evaluate(predictions); - Assert.Equal(1, metrics.MicroAccuracy, 2); - - var predictFunction = mlContext.Model.CreatePredictionEngine(transformer); - var prediction = predictFunction.Predict(new CifarData() - { - ImagePath = GetDataPath("images/banana.jpg") - }); - Assert.Equal(0, prediction.PredictedScores[0], 2); - Assert.Equal(1, prediction.PredictedScores[1], 2); - Assert.Equal(0, prediction.PredictedScores[2], 2); - - prediction = predictFunction.Predict(new CifarData() - { - ImagePath = GetDataPath("images/hotdog.jpg") - }); - Assert.Equal(0, prediction.PredictedScores[0], 2); - Assert.Equal(0, prediction.PredictedScores[1], 2); - Assert.Equal(1, prediction.PredictedScores[2], 2); - } - } - - public class CifarData - { - [LoadColumn(0)] - public string ImagePath; - - [LoadColumn(1)] - public string Label; - } - - public class CifarPrediction - { - [ColumnName("Score")] - public float[] PredictedScores; - } - - public class ImageNetData - { - [LoadColumn(0)] - public string ImagePath; - - [LoadColumn(1)] - public string Label; - } - - public class ImageNetPrediction - { - [ColumnName("Score")] - public float[] PredictedLabels; - } -} diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 810a308d67..a48ecc6550 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -5,21 +5,30 @@ using System; using System.Collections.Generic; using System.IO; +using System.IO.Compression; using System.Linq; +using System.Net; using System.Runtime.InteropServices; using Microsoft.ML.Data; using Microsoft.ML.RunTests; +using Microsoft.ML.TestFramework; using Microsoft.ML.TestFramework.Attributes; using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.Image; using Microsoft.ML.Transforms.TensorFlow; -using Tensorflow; using Xunit; +using Xunit.Abstractions; +using static Microsoft.ML.DataOperationsCatalog; namespace Microsoft.ML.Scenarios { - public partial class ScenariosTests + [Collection("NoParallelization")] + public sealed class TensorFlowScenariosTests : BaseTestClass { + public TensorFlowScenariosTests(ITestOutputHelper output) : base(output) + { + } + private class TestData { [VectorType(4)] @@ -28,6 +37,89 @@ private class TestData public float[] b; } + public class CifarData + { + [LoadColumn(0)] + public string ImagePath; + + [LoadColumn(1)] + public string Label; + } + + public class CifarPrediction + { + [ColumnName("Score")] + public float[] PredictedScores; + } + + public class ImageNetData + { + [LoadColumn(0)] + public string ImagePath; + + [LoadColumn(1)] + public string Label; + } + + public class ImageNetPrediction + { + [ColumnName("Score")] + public float[] PredictedLabels; + } + + [TensorFlowFact] + public void TensorFlowTransforCifarEndToEndTest2() + { + var imageHeight = 32; + var imageWidth = 32; + var model_location = "cifar_model/frozen_model.pb"; + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + + var mlContext = new MLContext(seed: 1); + var data = TextLoader.Create(mlContext, new TextLoader.Options() + { + Columns = new[] + { + new TextLoader.Column("ImagePath", DataKind.String, 0), + new TextLoader.Column("Label", DataKind.String, 1), + } + }, new MultiFileSource(dataFile)); + + var pipeEstimator = new ImageLoadingEstimator(mlContext, imageFolder, ("ImageReal", "ImagePath")) + .Append(new ImageResizingEstimator(mlContext, "ImageCropped", imageHeight, imageWidth, "ImageReal")) + .Append(new ImagePixelExtractingEstimator(mlContext, "Input", "ImageCropped", interleavePixelColors: true)) + .Append(mlContext.Model.LoadTensorFlowModel(model_location).ScoreTensorFlowModel("Output", "Input")) + .Append(new ColumnConcatenatingEstimator(mlContext, "Features", "Output")) + .Append(new ValueToKeyMappingEstimator(mlContext, "Label")) + .AppendCacheCheckpoint(mlContext) + .Append(mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy()); + + + var transformer = pipeEstimator.Fit(data); + var predictions = transformer.Transform(data); + + var metrics = mlContext.MulticlassClassification.Evaluate(predictions); + Assert.Equal(1, metrics.MicroAccuracy, 2); + + var predictFunction = mlContext.Model.CreatePredictionEngine(transformer); + var prediction = predictFunction.Predict(new CifarData() + { + ImagePath = GetDataPath("images/banana.jpg") + }); + Assert.Equal(0, prediction.PredictedScores[0], 2); + Assert.Equal(1, prediction.PredictedScores[1], 2); + Assert.Equal(0, prediction.PredictedScores[2], 2); + + prediction = predictFunction.Predict(new CifarData() + { + ImagePath = GetDataPath("images/hotdog.jpg") + }); + Assert.Equal(0, prediction.PredictedScores[0], 2); + Assert.Equal(0, prediction.PredictedScores[1], 2); + Assert.Equal(1, prediction.PredictedScores[2], 2); + } + [TensorFlowFact] public void TensorFlowTransformMatrixMultiplicationTest() { @@ -1113,5 +1205,183 @@ public void TensorFlowStringTest() Assert.Equal(input.A[i], textOutput.AOut[i]); Assert.Equal(string.Join(" ", input.B).Replace("/", " "), textOutput.BOut[0]); } + + [TensorFlowFact] + public void TensorFlowImageClassification() + { + string assetsRelativePath = @"assets"; + string assetsPath = GetAbsolutePath(assetsRelativePath); + string imagesDownloadFolderPath = Path.Combine(assetsPath, "inputs", + "images"); + + //Download the image set and unzip + string finalImagesFolderName = DownloadImageSet( + imagesDownloadFolderPath); + + string fullImagesetFolderPath = Path.Combine( + imagesDownloadFolderPath, finalImagesFolderName); + + MLContext mlContext = new MLContext(seed: 1); + + //Load all the original images info + IEnumerable images = LoadImagesFromDirectory( + folder: fullImagesetFolderPath, useFolderNameAsLabel: true); + + IDataView shuffledFullImagesDataset = mlContext.Data.ShuffleRows( + mlContext.Data.LoadFromEnumerable(images), seed: 1); + + shuffledFullImagesDataset = mlContext.Transforms.Conversion + .MapValueToKey("Label") + .Fit(shuffledFullImagesDataset) + .Transform(shuffledFullImagesDataset); + + // Split the data 80:10 into train and test sets, train and evaluate. + TrainTestData trainTestData = mlContext.Data.TrainTestSplit( + shuffledFullImagesDataset, testFraction: 0.2, seed: 1); + + IDataView trainDataset = trainTestData.TrainSet; + IDataView testDataset = trainTestData.TestSet; + + var pipeline = mlContext.Model.ImageClassification( + "ImagePath", "Label", + arch: ImageClassificationEstimator.Architecture.ResnetV2101, + epoch: 5, + batchSize: 5, + learningRate: 0.01f, + testOnTrainSet: false); + + var trainedModel = pipeline.Fit(trainDataset); + + mlContext.Model.Save(trainedModel, shuffledFullImagesDataset.Schema, + "model.zip"); + + ITransformer loadedModel; + DataViewSchema schema; + using (var file = File.OpenRead("model.zip")) + loadedModel = mlContext.Model.Load(file, out schema); + + IDataView predictions = trainedModel.Transform(testDataset); + var metrics = mlContext.MulticlassClassification.Evaluate(predictions); + + // On Ubuntu the results seem to vary quite a bit but they can probably be + // controlled by training more epochs, however that will slow the + // build down. Accuracy values seen were 0.33, 0.66, 0.70+. The model + // seems to be unstable, there could be many reasons, will need to + // investigate this further. + if (!(RuntimeInformation.IsOSPlatform(OSPlatform.Windows) || + (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)))) + { + Assert.InRange(metrics.MicroAccuracy, 0.3, 1); + Assert.InRange(metrics.MacroAccuracy, 0.3, 1); + } + else + { + Assert.Equal(1, metrics.MicroAccuracy); + Assert.Equal(1, metrics.MacroAccuracy); + } + } + + public static IEnumerable LoadImagesFromDirectory(string folder, + bool useFolderNameAsLabel = true) + { + var files = Directory.GetFiles(folder, "*", + searchOption: SearchOption.AllDirectories); + + foreach (var file in files) + { + if (Path.GetExtension(file) != ".jpg") + continue; + + var label = Path.GetFileName(file); + if (useFolderNameAsLabel) + label = Directory.GetParent(file).Name; + else + { + for (int index = 0; index < label.Length; index++) + { + if (!char.IsLetter(label[index])) + { + label = label.Substring(0, index); + break; + } + } + } + + yield return new ImageData() + { + ImagePath = file, + Label = label + }; + + } + } + + public static string DownloadImageSet(string imagesDownloadFolder) + { + string fileName = "flower_photos_tiny_set_for_unit_tests.zip"; + string url = $"https://mlnetfilestorage.file.core.windows.net/imagesets" + + $"/flower_images/flower_photos_tiny_set_for_unit_tests.zip?st=2019" + + $"-08-29T00%3A07%3A21Z&se=2030-08-30T00%3A07%3A00Z&sp=rl&sv=2018" + + $"-03-28&sr=f&sig=N8HbLziTcT61kstprNLmn%2BDC0JoMrNwo6yRWb3hLLag%3D"; + + Download(url, imagesDownloadFolder, fileName); + UnZip(Path.Combine(imagesDownloadFolder, fileName), imagesDownloadFolder); + + return Path.GetFileNameWithoutExtension(fileName); + } + + private static bool Download(string url, string destDir, string destFileName) + { + if (destFileName == null) + destFileName = url.Split(Path.DirectorySeparatorChar).Last(); + + Directory.CreateDirectory(destDir); + + string relativeFilePath = Path.Combine(destDir, destFileName); + + if (File.Exists(relativeFilePath)) + return false; + + new WebClient().DownloadFile(url, relativeFilePath); + return true; + } + + private static void UnZip(String gzArchiveName, String destFolder) + { + var flag = gzArchiveName.Split(Path.DirectorySeparatorChar) + .Last() + .Split('.') + .First() + ".bin"; + + if (File.Exists(Path.Combine(destFolder, flag))) + return; + + ZipFile.ExtractToDirectory(gzArchiveName, destFolder); + File.Create(Path.Combine(destFolder, flag)); + } + + public static string GetAbsolutePath(string relativePath) => + Path.Combine(new FileInfo(typeof( + TensorFlowScenariosTests).Assembly.Location).Directory.FullName, relativePath); + + + public class ImageData + { + [LoadColumn(0)] + public string ImagePath; + + [LoadColumn(1)] + public string Label; + } + + public class ImagePrediction + { + [ColumnName("Score")] + public float[] Score; + + [ColumnName("PredictedLabel")] + public UInt32 PredictedLabel; + } + } } diff --git a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs index b3838a791e..2a78910617 100644 --- a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs @@ -17,6 +17,10 @@ namespace Microsoft.ML.Tests { + [CollectionDefinition("NoParallelization", DisableParallelization = true)] + public class NoParallelizationCollection { } + + [Collection("NoParallelization")] public class TensorFlowEstimatorTests : TestDataPipeBase { private class TestData