diff --git a/docs/samples/Microsoft.ML.Samples.GPU/Microsoft.ML.Samples.GPU.csproj b/docs/samples/Microsoft.ML.Samples.GPU/Microsoft.ML.Samples.GPU.csproj index a227b1e82b..df7275ebb5 100644 --- a/docs/samples/Microsoft.ML.Samples.GPU/Microsoft.ML.Samples.GPU.csproj +++ b/docs/samples/Microsoft.ML.Samples.GPU/Microsoft.ML.Samples.GPU.csproj @@ -11,7 +11,7 @@ - + Dynamic\ImageClassification\%(FileName) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ImageClassificationDefault.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/ImageClassificationDefault.cs similarity index 98% rename from docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ImageClassificationDefault.cs rename to docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/ImageClassificationDefault.cs index e8edfed9c3..71a5cb7db0 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ImageClassificationDefault.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/ImageClassificationDefault.cs @@ -9,7 +9,6 @@ using System.Threading.Tasks; using Microsoft.ML; using Microsoft.ML.Data; -using Microsoft.ML.Transforms; using static Microsoft.ML.DataOperationsCatalog; namespace Samples.Dynamic @@ -60,12 +59,11 @@ public static void Example() IDataView trainDataset = trainTestData.TrainSet; IDataView testDataset = trainTestData.TestSet; - var pipeline = mlContext.Model.ImageClassification("Image", "Label", validationSet: testDataset) - .Append(mlContext.Transforms.Conversion.MapKeyToValue( - outputColumnName: "PredictedLabel", + var pipeline = mlContext.MulticlassClassification.Trainers + .ImageClassification(featureColumnName:"Image", validationSet:testDataset) + .Append(mlContext.Transforms.Conversion.MapKeyToValue(outputColumnName: "PredictedLabel", inputColumnName: "PredictedLabel")); - Console.WriteLine("*** Training the image classification model " + "with DNN Transfer Learning on top of the selected " + "pre-trained model/architecture ***"); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/LearningRateSchedulingCifarResnetTransferLearning.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/LearningRateSchedulingCifarResnetTransferLearning.cs similarity index 97% rename from docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/LearningRateSchedulingCifarResnetTransferLearning.cs rename to docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/LearningRateSchedulingCifarResnetTransferLearning.cs index cb9b3eae79..bf454daad8 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/LearningRateSchedulingCifarResnetTransferLearning.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/LearningRateSchedulingCifarResnetTransferLearning.cs @@ -9,8 +9,8 @@ using System.Threading.Tasks; using Microsoft.ML; using Microsoft.ML.Data; +using Microsoft.ML.Dnn; using Microsoft.ML.Transforms; -using static Microsoft.ML.DataOperationsCatalog; namespace Samples.Dynamic { @@ -65,20 +65,19 @@ public static void Example() .Fit(testDataset) .Transform(testDataset); - var options = new ImageClassificationEstimator.Options() + var options = new ImageClassificationTrainer.Options() { - FeaturesColumnName = "Image", + FeatureColumnName = "Image", LabelColumnName = "Label", // Just by changing/selecting InceptionV3/MobilenetV2 here instead of // ResnetV2101 you can try a different architecture/ // pre-trained model. - Arch = ImageClassificationEstimator.Architecture.ResnetV2101, + Arch = ImageClassificationTrainer.Architecture.ResnetV2101, Epoch = 182, BatchSize = 128, LearningRate = 0.01f, MetricsCallback = (metrics) => Console.WriteLine(metrics), ValidationSet = testDataset, - DisableEarlyStopping = true, ReuseValidationSetBottleneckCachedValues = false, ReuseTrainSetBottleneckCachedValues = false, // Use linear scaling rule and Learning rate decay as an option @@ -88,7 +87,7 @@ public static void Example() LearningRateScheduler = new LsrDecay() }; - var pipeline = mlContext.Model.ImageClassification(options) + var pipeline = mlContext.MulticlassClassification.Trainers.ImageClassification(options) .Append(mlContext.Transforms.Conversion.MapKeyToValue( outputColumnName: "PredictedLabel", inputColumnName: "PredictedLabel")); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearningEarlyStopping.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/ResnetV2101TransferLearningEarlyStopping.cs similarity index 96% rename from docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearningEarlyStopping.cs rename to docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/ResnetV2101TransferLearningEarlyStopping.cs index 6f084f5f66..10f6ee2e33 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearningEarlyStopping.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/ResnetV2101TransferLearningEarlyStopping.cs @@ -11,6 +11,7 @@ using System.IO.Compression; using System.Threading; using System.Net; +using Microsoft.ML.Dnn; namespace Samples.Dynamic { @@ -62,23 +63,24 @@ public static void Example() .Fit(testDataset) .Transform(testDataset); - var options = new ImageClassificationEstimator.Options() + var options = new ImageClassificationTrainer.Options() { - FeaturesColumnName = "Image", + FeatureColumnName = "Image", LabelColumnName = "Label", // Just by changing/selecting InceptionV3/MobilenetV2/ResnetV250 here instead of // ResnetV2101 you can try a different architecture/ // pre-trained model. - Arch = ImageClassificationEstimator.Architecture.ResnetV2101, + Arch = ImageClassificationTrainer.Architecture.ResnetV2101, BatchSize = 10, LearningRate = 0.01f, - EarlyStoppingCriteria = new ImageClassificationEstimator.EarlyStopping(minDelta: 0.001f, patience: 20, metric: ImageClassificationEstimator.EarlyStoppingMetric.Loss), + EarlyStoppingCriteria = new ImageClassificationTrainer.EarlyStopping( + minDelta: 0.001f, patience: 20, metric: ImageClassificationTrainer.EarlyStoppingMetric.Loss), MetricsCallback = (metrics) => Console.WriteLine(metrics), ValidationSet = validationSet }; var pipeline = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, false, "ImagePath") // false indicates we want the image as a VBuffer - .Append(mlContext.Model.ImageClassification(options)); + .Append(mlContext.MulticlassClassification.Trainers.ImageClassification(options)); Console.WriteLine("*** Training the image classification model with " + "DNN Transfer Learning on top of the selected pre-trained " + diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearningTrainTestSplit.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/ResnetV2101TransferLearningTrainTestSplit.cs similarity index 96% rename from docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearningTrainTestSplit.cs rename to docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/ResnetV2101TransferLearningTrainTestSplit.cs index 3fef8d6280..43e29d114b 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearningTrainTestSplit.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/ResnetV2101TransferLearningTrainTestSplit.cs @@ -9,7 +9,7 @@ using System.Threading.Tasks; using Microsoft.ML; using Microsoft.ML.Data; -using Microsoft.ML.Transforms; +using Microsoft.ML.Dnn; using static Microsoft.ML.DataOperationsCatalog; namespace Samples.Dynamic @@ -60,23 +60,22 @@ public static void Example() IDataView trainDataset = trainTestData.TrainSet; IDataView testDataset = trainTestData.TestSet; - var options = new ImageClassificationEstimator.Options() + var options = new ImageClassificationTrainer.Options() { - FeaturesColumnName = "Image", + FeatureColumnName = "Image", LabelColumnName = "Label", // Just by changing/selecting InceptionV3/MobilenetV2/ResnetV250 here instead of // ResnetV2101 you can try a different architecture/ // pre-trained model. - Arch = ImageClassificationEstimator.Architecture.ResnetV2101, + Arch = ImageClassificationTrainer.Architecture.ResnetV2101, Epoch = 50, BatchSize = 10, LearningRate = 0.01f, MetricsCallback = (metrics) => Console.WriteLine(metrics), - ValidationSet = testDataset, - DisableEarlyStopping = true + ValidationSet = testDataset }; - var pipeline = mlContext.Model.ImageClassification(options) + var pipeline = mlContext.MulticlassClassification.Trainers.ImageClassification(options) .Append(mlContext.Transforms.Conversion.MapKeyToValue( outputColumnName: "PredictedLabel", inputColumnName: "PredictedLabel")); diff --git a/src/Microsoft.ML.Data/Scorers/MulticlassClassificationScorer.cs b/src/Microsoft.ML.Data/Scorers/MulticlassClassificationScorer.cs index 1cb1ffc045..05d61f2620 100644 --- a/src/Microsoft.ML.Data/Scorers/MulticlassClassificationScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/MulticlassClassificationScorer.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Threading; using Microsoft.ML; +using Microsoft.ML.CommandLine; using Microsoft.ML.Data; using Microsoft.ML.Data.IO; using Microsoft.ML.Internal.Utilities; @@ -34,6 +35,11 @@ internal sealed class MulticlassClassificationScorer : PredictedLabelScorerBase // between scores and probabilities (using IDistributionPredictor) public sealed class Arguments : ScorerArgumentsBase { + [Argument(ArgumentType.AtMostOnce, HelpText = "Score Column Name.", ShortName = "scn")] + public string ScoreColumnName = AnnotationUtils.Const.ScoreValueKind.Score; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Predicted Label Column Name.", ShortName = "plcn")] + public string PredictedLabelColumnName = DefaultColumnNames.PredictedLabel; } public const string LoaderSignature = "MultiClassScoreTrans"; @@ -486,7 +492,7 @@ internal static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBoun [BestFriend] internal MulticlassClassificationScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) : base(args, env, data, WrapIfNeeded(env, mapper, trainSchema), trainSchema, RegistrationName, AnnotationUtils.Const.ScoreColumnKind.MulticlassClassification, - AnnotationUtils.Const.ScoreValueKind.Score, OutputTypeMatches, GetPredColType) + args.ScoreColumnName, OutputTypeMatches, GetPredColType, args.PredictedLabelColumnName) { } diff --git a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs index 90ec8b0357..c970182062 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs @@ -42,8 +42,8 @@ private protected sealed class BindingsImpl : BindingsBase private readonly AnnotationUtils.AnnotationGetter> _getScoreValueKind; private readonly DataViewSchema.Annotations _predColMetadata; private BindingsImpl(DataViewSchema input, ISchemaBoundRowMapper mapper, string suffix, string scoreColumnKind, - bool user, int scoreColIndex, DataViewType predColType) - : base(input, mapper, suffix, user, DefaultColumnNames.PredictedLabel) + bool user, int scoreColIndex, DataViewType predColType, string predictedLabelColumnName = DefaultColumnNames.PredictedLabel) + : base(input, mapper, suffix, user, predictedLabelColumnName) { Contracts.AssertNonEmpty(scoreColumnKind); Contracts.Assert(DerivedColumnCount == 1); @@ -82,7 +82,7 @@ private static DataViewSchema.Annotations KeyValueMetadataFromMetadata(DataVi } public static BindingsImpl Create(DataViewSchema input, ISchemaBoundRowMapper mapper, string suffix, - string scoreColKind, int scoreColIndex, DataViewType predColType) + string scoreColKind, int scoreColIndex, DataViewType predColType, string predictedLabelColumnName = DefaultColumnNames.PredictedLabel) { Contracts.AssertValue(input); Contracts.AssertValue(mapper); @@ -90,7 +90,7 @@ public static BindingsImpl Create(DataViewSchema input, ISchemaBoundRowMapper ma Contracts.AssertNonEmpty(scoreColKind); return new BindingsImpl(input, mapper, suffix, scoreColKind, true, - scoreColIndex, predColType); + scoreColIndex, predColType, predictedLabelColumnName); } public BindingsImpl ApplyToSchema(DataViewSchema input, ISchemaBindableMapper bindable, IHostEnvironment env) @@ -272,7 +272,7 @@ public override Func GetActiveMapperColumns(bool[] active) [BestFriend] private protected PredictedLabelScorerBase(ScorerArgumentsBase args, IHostEnvironment env, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema, string registrationName, string scoreColKind, string scoreColName, - Func outputTypeMatches, Func getPredColType) + Func outputTypeMatches, Func getPredColType, string predictedLabelColumnName = DefaultColumnNames.PredictedLabel) : base(env, data, registrationName, Contracts.CheckRef(mapper, nameof(mapper)).Bindable) { Host.CheckValue(args, nameof(args)); @@ -292,7 +292,7 @@ private protected PredictedLabelScorerBase(ScorerArgumentsBase args, IHostEnviro Host.Check(outputTypeMatches(scoreType), "Unexpected predictor output type"); var predColType = getPredColType(scoreType, rowMapper); - Bindings = BindingsImpl.Create(data.Schema, rowMapper, args.Suffix, scoreColKind, scoreColIndex, predColType); + Bindings = BindingsImpl.Create(data.Schema, rowMapper, args.Suffix, scoreColKind, scoreColIndex, predColType, predictedLabelColumnName); OutputSchema = Bindings.AsSchema; } diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index 53751614c2..2e4fff8b58 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -5,6 +5,7 @@ using System; using System.IO; using System.Reflection; +using System.Runtime.CompilerServices; using Microsoft.ML; using Microsoft.ML.Data; using Microsoft.ML.Data.IO; @@ -448,44 +449,60 @@ public sealed class MulticlassPredictionTransformer : SingleFeaturePredi where TModel : class { private readonly string _trainLabelColumn; + private readonly string _scoreColumn; + private readonly string _predictedLabelColumn; [BestFriend] - internal MulticlassPredictionTransformer(IHostEnvironment env, TModel model, DataViewSchema inputSchema, string featureColumn, string labelColumn) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MulticlassPredictionTransformer)), model, inputSchema, featureColumn) + internal MulticlassPredictionTransformer(IHostEnvironment env, TModel model, DataViewSchema inputSchema, string featureColumn, string labelColumn, + string scoreColumn = AnnotationUtils.Const.ScoreValueKind.Score, string predictedLabel = DefaultColumnNames.PredictedLabel) : + base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MulticlassPredictionTransformer)), model, inputSchema, featureColumn) { Host.CheckValueOrNull(labelColumn); _trainLabelColumn = labelColumn; + _scoreColumn = scoreColumn; + _predictedLabelColumn = predictedLabel; SetScorer(); } internal MulticlassPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MulticlassPredictionTransformer)), ctx) { - InitializationLogic(ctx, out _trainLabelColumn); + InitializationLogic(ctx, out _trainLabelColumn, out _scoreColumn, out _predictedLabelColumn); } internal MulticlassPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx, IHost host, TModel model) : base(host, ctx, model) { - InitializationLogic(ctx, out _trainLabelColumn); + InitializationLogic(ctx, out _trainLabelColumn, out _scoreColumn, out _predictedLabelColumn); } - private void InitializationLogic(ModelLoadContext ctx, out string trainLabelColumn) + private void InitializationLogic(ModelLoadContext ctx, out string trainLabelColumn, out string scoreColumn, out string predictedLabelColumn) { // *** Binary format *** // // id of string: train label column trainLabelColumn = ctx.LoadStringOrNull(); + if (ctx.Header.ModelVerWritten >= 0x00010002) + { + scoreColumn = ctx.LoadStringOrNull(); + predictedLabelColumn = ctx.LoadStringOrNull(); + } + else + { + scoreColumn = AnnotationUtils.Const.ScoreValueKind.Score; + predictedLabelColumn = DefaultColumnNames.PredictedLabel; + } + SetScorer(); } private void SetScorer() { var schema = new RoleMappedSchema(TrainSchema, _trainLabelColumn, FeatureColumnName); - var args = new MulticlassClassificationScorer.Arguments(); + var args = new MulticlassClassificationScorer.Arguments() { ScoreColumnName = _scoreColumn, PredictedLabelColumnName = _predictedLabelColumn}; Scorer = new MulticlassClassificationScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } @@ -500,13 +517,16 @@ private protected override void SaveCore(ModelSaveContext ctx) base.SaveCore(ctx); ctx.SaveStringOrNull(_trainLabelColumn); + ctx.SaveStringOrNull(_scoreColumn); + ctx.SaveStringOrNull(_predictedLabelColumn); } private static VersionInfo GetVersionInfo() { return new VersionInfo( modelSignature: "MC PRED", - verWrittenCur: 0x00010001, // Initial + //verWrittenCur: 0x00010001, // Initial + verWrittenCur: 0x00010002, // Score and Predicted Label column names. verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: MulticlassPredictionTransformer.LoaderSignature, diff --git a/src/Microsoft.ML.Dnn/DnnCatalog.cs b/src/Microsoft.ML.Dnn/DnnCatalog.cs index 3ca2411822..42559f1a62 100644 --- a/src/Microsoft.ML.Dnn/DnnCatalog.cs +++ b/src/Microsoft.ML.Dnn/DnnCatalog.cs @@ -3,9 +3,9 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.Data; +using Microsoft.ML.Dnn; +using Microsoft.ML.Runtime; using Microsoft.ML.Transforms; -using Microsoft.ML.Transforms.Dnn; -using static Microsoft.ML.Transforms.ImageClassificationEstimator; namespace Microsoft.ML { @@ -80,59 +80,51 @@ internal static DnnRetrainEstimator RetrainDnnModel( /// /// Performs image classification using transfer learning. - /// usage of this API requires additional NuGet dependencies on TensorFlow redist, see linked document for more information. + /// Usage of this API requires additional NuGet dependencies on TensorFlow redist, see linked document + /// for more information. /// /// /// /// - /// - /// The name of the input features column. - /// The name of the labels column. - /// The name of the output score column. - /// The name of the output predicted label columns. - /// Validation set. - - public static ImageClassificationEstimator ImageClassification( - this ModelOperationsCatalog catalog, - string featuresColumnName, - string labelColumnName, - string scoreColumnName = "Score", - string predictedLabelColumnName = "PredictedLabel", - IDataView validationSet = null - ) - { - var options = new ImageClassificationEstimator.Options() - { - FeaturesColumnName = featuresColumnName, - LabelColumnName = labelColumnName, - ScoreColumnName = scoreColumnName, - PredictedLabelColumnName = predictedLabelColumnName, - ValidationSet = validationSet - }; + /// Catalog + /// An object specifying advanced + /// options for . - return ImageClassification(catalog, options); - } + public static ImageClassificationTrainer ImageClassification( + this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + ImageClassificationTrainer.Options options) => + new ImageClassificationTrainer(CatalogUtils.GetEnvironment(catalog), options); /// /// Performs image classification using transfer learning. - /// usage of this API requires additional NuGet dependencies on TensorFlow redist, see linked document for more information. + /// Usage of this API requires additional NuGet dependencies on TensorFlow redist, see linked document for + /// more information. /// /// /// /// - /// - /// An object specifying advanced options for . - public static ImageClassificationEstimator ImageClassification( - this ModelOperationsCatalog catalog, Options options) - { - options.EarlyStoppingCriteria = options.DisableEarlyStopping ? null : options.EarlyStoppingCriteria ?? new EarlyStopping(); + /// Catalog + /// The name of the labels column. + /// The name of the input features column. + /// The name of the output score column. + /// The name of the output predicted label columns. + /// The validation set used while training to improve model quality. - var env = CatalogUtils.GetEnvironment(catalog); - return new ImageClassificationEstimator(env, options, DnnUtils.LoadDnnModel(env, options.Arch, true)); + public static ImageClassificationTrainer ImageClassification( + this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + string labelColumnName = DefaultColumnNames.Label, + string featureColumnName = DefaultColumnNames.Features, + string scoreColumnName = DefaultColumnNames.Score, + string predictedLabelColumnName = DefaultColumnNames.PredictedLabel, + IDataView validationSet = null) + { + Contracts.CheckValue(catalog, nameof(catalog)); + return new ImageClassificationTrainer(CatalogUtils.GetEnvironment(catalog), labelColumnName, + featureColumnName, scoreColumnName, predictedLabelColumnName, validationSet); } } } diff --git a/src/Microsoft.ML.Dnn/DnnModel.cs b/src/Microsoft.ML.Dnn/DnnModel.cs index a5324e9e39..8f8b840800 100644 --- a/src/Microsoft.ML.Dnn/DnnModel.cs +++ b/src/Microsoft.ML.Dnn/DnnModel.cs @@ -3,34 +3,30 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.Runtime; +using Microsoft.ML.Transforms; using Tensorflow; -using static Microsoft.ML.Transforms.DnnRetrainEstimator; -namespace Microsoft.ML.Transforms +namespace Microsoft.ML.Dnn { /// /// This class holds the information related to TensorFlow model and session. /// It provides some convenient methods to query model schema as well as /// creation of object. /// - public sealed class DnnModel + internal sealed class DnnModel { internal Session Session { get; } internal string ModelPath { get; } - private readonly IHostEnvironment _env; - /// /// Instantiates . /// - /// An object. /// TensorFlow session object. /// Location of the model from where was loaded. - internal DnnModel(IHostEnvironment env, Session session, string modelLocation) + internal DnnModel(Session session, string modelLocation) { Session = session; ModelPath = modelLocation; - _env = env; } } } diff --git a/src/Microsoft.ML.Dnn/DnnRetrainTransform.cs b/src/Microsoft.ML.Dnn/DnnRetrainTransform.cs index 5ca961063c..f9b998b8b1 100644 --- a/src/Microsoft.ML.Dnn/DnnRetrainTransform.cs +++ b/src/Microsoft.ML.Dnn/DnnRetrainTransform.cs @@ -11,13 +11,13 @@ using Microsoft.ML; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; +using Microsoft.ML.Dnn; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; using Microsoft.ML.Transforms; -using Microsoft.ML.Transforms.Dnn; using NumSharp; using Tensorflow; -using static Microsoft.ML.Transforms.Dnn.DnnUtils; +using static Microsoft.ML.Dnn.DnnUtils; using static Tensorflow.Binding; [assembly: LoadableClass(DnnRetrainTransformer.Summary, typeof(IDataTransform), typeof(DnnRetrainTransformer), diff --git a/src/Microsoft.ML.Dnn/DnnUtils.cs b/src/Microsoft.ML.Dnn/DnnUtils.cs index 57d5e1b9a2..1e9f2f7ef8 100644 --- a/src/Microsoft.ML.Dnn/DnnUtils.cs +++ b/src/Microsoft.ML.Dnn/DnnUtils.cs @@ -13,9 +13,10 @@ using Microsoft.ML.Data; using Microsoft.ML.Runtime; using Tensorflow; +using static Microsoft.ML.Dnn.ImageClassificationTrainer; using static Tensorflow.Binding; -namespace Microsoft.ML.Transforms.Dnn +namespace Microsoft.ML.Dnn { internal static class DnnUtils { @@ -260,25 +261,15 @@ private static void CreateTempDirectoryWithAcl(string folder, string identity) /// The model to load. /// /// - internal static DnnModel LoadDnnModel(IHostEnvironment env, string modelPath, bool metaGraph = false) - { - var session = GetSession(env, modelPath, metaGraph); - return new DnnModel(env, session, modelPath); - } + internal static DnnModel LoadDnnModel(IHostEnvironment env, string modelPath, bool metaGraph = false) => + new DnnModel(GetSession(env, modelPath, metaGraph), modelPath); - /// - /// Load TensorFlow model into memory. - /// - /// The environment to use. - /// The architecture of the model to load. - /// - /// - internal static DnnModel LoadDnnModel(IHostEnvironment env, ImageClassificationEstimator.Architecture arch, bool metaGraph = false) + internal static DnnModel LoadDnnModel(IHostEnvironment env, Architecture arch, bool metaGraph = false) { - var modelPath = ImageClassificationEstimator.ModelLocation[arch]; + var modelPath = ModelLocation[arch]; if (!File.Exists(modelPath)) { - if (arch == ImageClassificationEstimator.Architecture.InceptionV3) + if (arch == Architecture.InceptionV3) { var baseGitPath = @"https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/InceptionV3.meta"; using (WebClient client = new WebClient()) @@ -293,7 +284,7 @@ internal static DnnModel LoadDnnModel(IHostEnvironment env, ImageClassificationE ZipFile.ExtractToDirectory(Path.Combine(Directory.GetCurrentDirectory(), @"tfhub_modules.zip"), @"tfhub_modules"); } } - else if (arch == ImageClassificationEstimator.Architecture.ResnetV2101) + else if (arch == Architecture.ResnetV2101) { var baseGitPath = @"https://aka.ms/mlnet-resources/image/ResNet101Tensorflow/resnet_v2_101_299.meta"; using (WebClient client = new WebClient()) @@ -301,7 +292,7 @@ internal static DnnModel LoadDnnModel(IHostEnvironment env, ImageClassificationE client.DownloadFile(new Uri($"{baseGitPath}"), @"resnet_v2_101_299.meta"); } } - else if (arch == ImageClassificationEstimator.Architecture.MobilenetV2) + else if (arch == Architecture.MobilenetV2) { var baseGitPath = @"https://tlcresources.blob.core.windows.net/image/MobileNetV2TensorFlow/mobilenet_v2.meta"; using (WebClient client = new WebClient()) @@ -309,7 +300,7 @@ internal static DnnModel LoadDnnModel(IHostEnvironment env, ImageClassificationE client.DownloadFile(new Uri($"{baseGitPath}"), @"mobilenet_v2.meta"); } } - else if (arch == ImageClassificationEstimator.Architecture.ResnetV250) + else if (arch == Architecture.ResnetV250) { var baseGitPath = @"https://tlcresources.blob.core.windows.net/image/ResNetV250TensorFlow/resnet_v2_50_299.meta"; using (WebClient client = new WebClient()) @@ -320,8 +311,7 @@ internal static DnnModel LoadDnnModel(IHostEnvironment env, ImageClassificationE } - var session = GetSession(env, modelPath, metaGraph); - return new DnnModel(env, session, modelPath); + return new DnnModel(GetSession(env, modelPath, metaGraph), modelPath); } internal static Session GetSession(IHostEnvironment env, string modelPath, bool metaGraph = false) diff --git a/src/Microsoft.ML.Dnn/ImageClassificationTransform.cs b/src/Microsoft.ML.Dnn/ImageClassificationTrainer.cs similarity index 58% rename from src/Microsoft.ML.Dnn/ImageClassificationTransform.cs rename to src/Microsoft.ML.Dnn/ImageClassificationTrainer.cs index 19d77b440f..2cbcd5b8ca 100644 --- a/src/Microsoft.ML.Dnn/ImageClassificationTransform.cs +++ b/src/Microsoft.ML.Dnn/ImageClassificationTrainer.cs @@ -11,1598 +11,1444 @@ using Microsoft.ML; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; +using Microsoft.ML.Dnn; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; +using Microsoft.ML.Trainers; using Microsoft.ML.Transforms; -using Microsoft.ML.Transforms.Dnn; using Tensorflow; using Tensorflow.Summaries; using static Microsoft.ML.Data.TextLoader; -using static Microsoft.ML.Transforms.Dnn.DnnUtils; -using static Microsoft.ML.Transforms.ImageClassificationEstimator; +using static Microsoft.ML.Dnn.DnnUtils; using static Tensorflow.Binding; -using Architecture = Microsoft.ML.Transforms.ImageClassificationEstimator.Architecture; +using Column = Microsoft.ML.Data.TextLoader.Column; -[assembly: LoadableClass(ImageClassificationTransformer.Summary, typeof(IDataTransform), typeof(ImageClassificationTransformer), - typeof(ImageClassificationEstimator.Options), typeof(SignatureDataTransform), ImageClassificationTransformer.UserName, ImageClassificationTransformer.ShortName)] +[assembly: LoadableClass(ImageClassificationTrainer.Summary, typeof(ImageClassificationTrainer), + typeof(ImageClassificationTrainer.Options), + new[] { typeof(SignatureMulticlassClassifierTrainer), typeof(SignatureTrainer) }, + ImageClassificationTrainer.UserName, + ImageClassificationTrainer.LoadName, + ImageClassificationTrainer.ShortName)] -[assembly: LoadableClass(ImageClassificationTransformer.Summary, typeof(IDataTransform), typeof(ImageClassificationTransformer), null, typeof(SignatureLoadDataTransform), - ImageClassificationTransformer.UserName, ImageClassificationTransformer.LoaderSignature)] +[assembly: LoadableClass(typeof(ImageClassificationModelParameters), null, typeof(SignatureLoadModel), + "Image classification predictor", ImageClassificationModelParameters.LoaderSignature)] -[assembly: LoadableClass(typeof(ImageClassificationTransformer), null, typeof(SignatureLoadModel), - ImageClassificationTransformer.UserName, ImageClassificationTransformer.LoaderSignature)] - -[assembly: LoadableClass(typeof(IRowMapper), typeof(ImageClassificationTransformer), null, typeof(SignatureLoadRowMapper), - ImageClassificationTransformer.UserName, ImageClassificationTransformer.LoaderSignature)] - -namespace Microsoft.ML.Transforms +namespace Microsoft.ML.Dnn { /// - /// for the . + /// The for training a Deep Neural Network(DNN) to classify images. /// - public sealed class ImageClassificationTransformer : RowToRowTransformerBase + /// + /// . + /// + /// This trainer outputs the following columns: + /// + /// | Output Column Name | Column Type | Description| + /// | -- | -- | -- | + /// | `Score` | Vector of | The scores of all classes.Higher value means higher probability to fall into the associated class. If the i-th element has the largest value, the predicted label index would be i.Note that i is zero-based index. | + /// | `PredictedLabel` | [key](xref:Microsoft.ML.Data.KeyDataViewType) type | The predicted label's index. If its value is i, the actual label would be the i-th category in the key-valued input label type. | + /// + /// ### Trainer Characteristics + /// | | | + /// | -- | -- | + /// | Machine learning task | Multiclass classification | + /// | Is normalization required? | No | + /// | Is caching required? | No | + /// | Required NuGet in addition to Microsoft.ML | Micrsoft.ML.Dnn and SciSharp.TensorFlow.Redist / SciSharp.TensorFlow.Redist-Windows-GPU / SciSharp.TensorFlow.Redist-Linux-GPU | + /// + /// [!include[io](~/../docs/samples/docs/api-reference/tensorflow-usage.md)] + /// + /// ### Training Algorithm Details + /// Trains a Deep Neural Network(DNN) by leveraging an existing pre-trained model such as Resnet50 for the purpose + /// of classifying images. + /// ]]> + /// + /// + public sealed class ImageClassificationTrainer : + TrainerEstimatorBase, + ImageClassificationModelParameters> { - private readonly IHostEnvironment _env; - private readonly bool _addBatchDimensionInput; - private Session _session; - private Tensor _bottleneckTensor; - private Tensor _learningRateInput; - private Operation _trainStep; - private Tensor _softMaxTensor; - private Tensor _crossEntropy; - private Tensor _labelTensor; - private Tensor _evaluationStep; - private Tensor _prediction; - private Tensor _bottleneckInput; - private Tensor _jpegData; - private Tensor _resizedImage; - private string _jpegDataTensorName; - private string _resizedImageTensorName; - private string _inputTensorName; - private readonly int _classCount; - private readonly string _checkpointPath; - private readonly string _bottleneckOperationName; - private Graph Graph => _session.graph; - private readonly string[] _inputs; - private readonly string[] _outputs; - private ReadOnlyMemory[] _keyValueAnnotations; - private readonly string _labelColumnName; - private readonly string _finalModelPrefix; - private readonly Architecture _arch; - private readonly string _scoreColumnName; - private readonly string _predictedLabelColumnName; - private readonly float _learningRate; - private readonly bool _useLRScheduling; - private readonly string _softmaxTensorName; - private readonly string _predictionTensorName; - internal const string Summary = "Trains Dnn models."; - internal const string UserName = "ImageClassificationTransform"; - internal const string ShortName = "ImgClsTrans"; - internal const string LoaderSignature = "ImageClassificationTrans"; + internal const string LoadName = "ImageClassificationTrainer"; + internal const string UserName = "Image Classification Trainer"; + internal const string ShortName = "IMGCLSS"; + internal const string Summary = "Trains a DNN model to classify images."; - private static VersionInfo GetVersionInfo() + /// + /// Image classification model. + /// + public enum Architecture { - return new VersionInfo( - modelSignature: "IMGTRANS", - //verWrittenCur: 0x00010001, // Initial - verWrittenCur: 0x00000001, - verReadableCur: 0x00000001, - verWeCanReadBack: 0x00000001, - loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(ImageClassificationTransformer).Assembly.FullName); - } + ResnetV2101, + InceptionV3, + MobilenetV2, + ResnetV250 + }; - // Factory method for SignatureLoadModel. - private static ImageClassificationTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + /// + /// Dictionary mapping model architecture to model location. + /// + internal static IReadOnlyDictionary ModelLocation = new Dictionary { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(GetVersionInfo()); - - // *** Binary format *** - // byte: indicator for frozen models - // byte: indicator for adding batch dimension in input - // int: number of input columns - // for each input column - // int: id of int column name - // int: number of output columns - // for each output column - // int: id of output column name - // string: value of label column name - // string: prefix pf final model and checkpoint files/folder for storing graph files - // int: value of the utilized model architecture for transfer learning - // string: value of score column name - // string: value of predicted label column name - // float: value of learning rate - // bool: uses learning rate scheduling if true - // int: number of prediction classes - // for each key value annotation column - // string: value of key value annotations - // string: name of prediction tensor - // string: name of softmax tensor - // string: name of JPEG data tensor - // string: name of resized image tensor - // stream (byte): tensorFlow model. - - GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool addBatchDimensionInput, - out string labelColumn, out string checkpointName, out Architecture arch, out string scoreColumnName, - out string predictedColumnName, out float learningRate, out bool useLearningRateScheduling, out int classCount, - out string[] keyValueAnnotations, out string predictionTensorName, out string softMaxTensorName, - out string jpegDataTensorName, out string resizeTensorName); - - byte[] modelBytes = null; - if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray())) - throw env.ExceptDecode(); - - return new ImageClassificationTransformer(env, DnnUtils.LoadTFSession(env, modelBytes), outputs, inputs, - addBatchDimensionInput, 1, labelColumn, checkpointName, arch, - scoreColumnName, predictedColumnName, learningRate, useLearningRateScheduling, null, null, classCount, true, predictionTensorName, - softMaxTensorName, jpegDataTensorName, resizeTensorName, keyValueAnnotations); - - } + { Architecture.ResnetV2101, @"resnet_v2_101_299.meta" }, + { Architecture.InceptionV3, @"InceptionV3.meta" }, + { Architecture.MobilenetV2, @"mobilenet_v2.meta" }, + { Architecture.ResnetV250, @"resnet_v2_50_299.meta" } + }; - // Factory method for SignatureDataTransform. - internal static IDataTransform Create(IHostEnvironment env, ImageClassificationEstimator.Options options, IDataView input) + /// + /// Dictionary mapping model architecture to image input size supported. + /// + internal static IReadOnlyDictionary> ImagePreprocessingSize = + new Dictionary> { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(options, nameof(options)); - env.CheckValue(input, nameof(input)); - env.CheckValue(options.InputColumns, nameof(options.InputColumns)); - env.CheckValue(options.OutputColumns, nameof(options.OutputColumns)); + { Architecture.ResnetV2101, new Tuple(299,299) }, + { Architecture.InceptionV3, new Tuple(299,299) }, + { Architecture.MobilenetV2, new Tuple(224,224) }, + { Architecture.ResnetV250, new Tuple(299,299) } + }; - return new ImageClassificationTransformer(env, options, input).MakeDataTransform(input); + /// + /// Indicates the metric to be monitored to decide Early Stopping criteria. + /// + public enum EarlyStoppingMetric + { + Accuracy, + Loss } - internal ImageClassificationTransformer(IHostEnvironment env, ImageClassificationEstimator.Options options, IDataView input) - : this(env, options, DnnUtils.LoadDnnModel(env, options.Arch), input) + /// + /// DNN training metrics. + /// + public sealed class TrainMetrics { - } + /// + /// Indicates the dataset on which metrics are being reported. + /// + /// + public ImageClassificationMetrics.Dataset DatasetUsed { get; set; } - internal ImageClassificationTransformer(IHostEnvironment env, ImageClassificationEstimator.Options options, DnnModel tensorFlowModel, IDataView input) - : this(env, tensorFlowModel.Session, options.OutputColumns, options.InputColumns, null, options.BatchSize, - options.LabelColumnName, options.FinalModelPrefix, options.Arch, options.ScoreColumnName, - options.PredictedLabelColumnName, options.LearningRate, options.LearningRateScheduler == null? false:true, options.ModelSavePath, input.Schema) + /// + /// The number of batches processed in an epoch. + /// + public int BatchProcessedCount { get; set; } - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(options, nameof(options)); - env.CheckValue(input, nameof(input)); - CheckTrainingParameters(options); - var imageProcessor = new ImageProcessor(this); - int trainingsetSize = -1; - if (!options.ReuseTrainSetBottleneckCachedValues || !File.Exists(options.TrainSetBottleneckCachedValuesFilePath)) - { - trainingsetSize = CacheFeaturizedImagesToDisk(input, options.LabelColumnName, options.InputColumns[0], imageProcessor, - _inputTensorName, _bottleneckTensor.name, options.TrainSetBottleneckCachedValuesFilePath, - ImageClassificationMetrics.Dataset.Train, options.MetricsCallback); - File.WriteAllText("TrainingSetSize.txt", trainingsetSize.ToString()); // Write training set size to a file for use during training + /// + /// The training epoch index for which this metric is reported. + /// + public int Epoch { get; set; } - } + /// + /// Accuracy of the batch on this . Higher the better. + /// + public float Accuracy { get; set; } + + /// + /// Cross-Entropy (loss) of the batch on this . Lower + /// the better. + /// + public float CrossEntropy { get; set; } - if (options.ValidationSet != null && - (!options.ReuseTrainSetBottleneckCachedValues || !File.Exists(options.ValidationSetBottleneckCachedValuesFilePath))) - CacheFeaturizedImagesToDisk(options.ValidationSet, options.LabelColumnName, options.InputColumns[0], - imageProcessor, _inputTensorName, _bottleneckTensor.name, options.ValidationSetBottleneckCachedValuesFilePath, - ImageClassificationMetrics.Dataset.Validation, options.MetricsCallback); + /// + /// Learning Rate used for this . Changes for learning rate scheduling. + /// + public float LearningRate { get; set; } - TrainAndEvaluateClassificationLayer(options.TrainSetBottleneckCachedValuesFilePath, options, options.ValidationSetBottleneckCachedValuesFilePath, trainingsetSize); + /// + /// String representation of the metrics. + /// + public override string ToString() + { + if (DatasetUsed == ImageClassificationMetrics.Dataset.Train) + return $"Phase: Training, Dataset used: {DatasetUsed.ToString(),10}, Batch Processed Count: {BatchProcessedCount,3}, Learning Rate: {LearningRate,10} " + + $"Epoch: {Epoch,3}, Accuracy: {Accuracy,10}, Cross-Entropy: {CrossEntropy,10}"; + else + return $"Phase: Training, Dataset used: {DatasetUsed.ToString(),10}, Batch Processed Count: {BatchProcessedCount,3}, " + + $"Epoch: {Epoch,3}, Accuracy: {Accuracy,10}"; + } } - private void CheckTrainingParameters(ImageClassificationEstimator.Options options) + /// + /// Metrics for image featurization values. The input image is passed through + /// the network and features are extracted from second or last layer to + /// train a custom full connected layer that serves as classifier. + /// + public sealed class BottleneckMetrics { - Host.CheckNonWhiteSpace(options.LabelColumnName, nameof(options.LabelColumnName)); + /// + /// Indicates the dataset on which metrics are being reported. + /// + /// + public ImageClassificationMetrics.Dataset DatasetUsed { get; set; } - if (_session.graph.OperationByName(_labelTensor.name.Split(':')[0]) == null) - throw Host.ExceptParam(nameof(_labelTensor.name), $"'{_labelTensor.name}' does not exist in the model"); - if (options.EarlyStoppingCriteria != null && options.ValidationSet == null && options.TestOnTrainSet == false) - throw Host.ExceptParam(nameof(options.EarlyStoppingCriteria), $"Early stopping enabled but unable to find a validation" + - $" set and/or train set testing disabled. Please disable early stopping or either provide a validation set or enable train set training."); - } + /// + /// Index of the input image. + /// + public int Index { get; set; } - private (Tensor, Tensor) AddJpegDecoding(int height, int width, int depth) - { - // height, width, depth - var inputDim = (height, width, depth); - var jpegData = tf.placeholder(tf.@string, name: "DecodeJPGInput"); - var decodedImage = tf.image.decode_jpeg(jpegData, channels: inputDim.Item3); - // Convert from full range of uint8 to range [0,1] of float32. - var decodedImageAsFloat = tf.image.convert_image_dtype(decodedImage, tf.float32); - var decodedImage4d = tf.expand_dims(decodedImageAsFloat, 0); - var resizeShape = tf.stack(new int[] { inputDim.Item1, inputDim.Item2 }); - var resizeShapeAsInt = tf.cast(resizeShape, dtype: tf.int32); - var resizedImage = tf.image.resize_bilinear(decodedImage4d, resizeShapeAsInt, false, "ResizeTensor"); - return (jpegData, resizedImage); + /// + /// String representation of the metrics. + /// + public override string ToString() => $"Phase: Bottleneck Computation, Dataset used: {DatasetUsed.ToString(),10}, Image Index: {Index,3}"; } - private static Tensor EncodeByteAsString(VBuffer buffer) + /// + /// Early Stopping feature stops training when monitored quantity stops improving'. + /// Modeled after https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/ + /// tensorflow/python/keras/callbacks.py#L1143 + /// + public sealed class EarlyStopping { - int length = buffer.Length; - var size = c_api.TF_StringEncodedSize((UIntPtr)length); - var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); + /// + /// Best value of metric seen so far. + /// + private float _bestMetricValue; - IntPtr tensor = c_api.TF_TensorData(handle); - Marshal.WriteInt64(tensor, 0); + /// + /// Current counter for number of epochs where there has been no improvement. + /// + private int _wait; - var status = new Status(); - unsafe - { - fixed (byte* src = buffer.GetValues()) - c_api.TF_StringEncode(src, (UIntPtr)length, (sbyte*)(tensor + sizeof(Int64)), size, status); - } + /// + /// The metric to be monitored (eg Accuracy, Loss). + /// + private EarlyStoppingMetric _metric; - status.Check(true); - status.Dispose(); - return new Tensor(handle); - } + /// + /// Minimum change in the monitored quantity to be considered as an improvement. + /// + public float MinDelta { get; set; } - private sealed class ImageProcessor - { - private Runner _imagePreprocessingRunner; + /// + /// Number of epochs to wait after no improvement is seen consecutively + /// before stopping the training. + /// + public int Patience { get; set; } - public ImageProcessor(ImageClassificationTransformer transformer) - { - _imagePreprocessingRunner = new Runner(transformer._session); - _imagePreprocessingRunner.AddInput(transformer._jpegDataTensorName); - _imagePreprocessingRunner.AddOutputs(transformer._resizedImageTensorName); - } + /// + /// Whether the monitored quantity is to be increasing (eg. Accuracy, CheckIncreasing = true) + /// or decreasing (eg. Loss, CheckIncreasing = false). + /// + public bool CheckIncreasing { get; set; } - public Tensor ProcessImage(in VBuffer imageBuffer) + /// + /// + /// + /// + public EarlyStopping(float minDelta = 0.01f, int patience = 20, EarlyStoppingMetric metric = EarlyStoppingMetric.Accuracy, bool checkIncreasing = true) { - var imageTensor = EncodeByteAsString(imageBuffer); - var processedTensor = _imagePreprocessingRunner.AddInput(imageTensor, 0).Run()[0]; - imageTensor.Dispose(); - return processedTensor; - } - } - - private int CacheFeaturizedImagesToDisk(IDataView input, string labelColumnName, string imageColumnName, - ImageProcessor imageProcessor, string inputTensorName, string outputTensorName, string cacheFilePath, - ImageClassificationMetrics.Dataset dataset, ImageClassificationMetricsCallback metricsCallback) - { - var labelColumn = input.Schema[labelColumnName]; + _bestMetricValue = 0.0f; + _wait = 0; + _metric = metric; + MinDelta = Math.Abs(minDelta); + Patience = patience; + CheckIncreasing = checkIncreasing; - if (labelColumn.Type.RawType != typeof(UInt32)) - throw Host.ExceptSchemaMismatch(nameof(labelColumn), "Label", - labelColumnName, typeof(uint).ToString(), - labelColumn.Type.RawType.ToString()); + //Set the CheckIncreasing according to the metric being monitored + if (metric == EarlyStoppingMetric.Accuracy) + CheckIncreasing = true; + else if (metric == EarlyStoppingMetric.Loss) + CheckIncreasing = false; + } - var imageColumn = input.Schema[imageColumnName]; - Runner runner = new Runner(_session); - runner.AddOutputs(outputTensorName); - int datasetsize = 0; - using (TextWriter writer = File.CreateText(cacheFilePath)) - using (var cursor = input.GetRowCursor(input.Schema.Where(c => c.Index == labelColumn.Index || c.Index == imageColumn.Index))) + /// + /// To be called at the end of every epoch to check if training should stop. + /// For increasing metric(eg.: Accuracy), if metric stops increasing, stop training if + /// value of metric doesn't increase within 'patience' number of epochs. + /// For decreasing metric(eg.: Loss), stop training if value of metric doesn't decrease + /// within 'patience' number of epochs. + /// Any change in the value of metric of less than 'minDelta' is not considered a change. + /// + public bool ShouldStop(TrainMetrics currentMetrics) { - var labelGetter = cursor.GetGetter(labelColumn); - var imageGetter = cursor.GetGetter>(imageColumn); - UInt32 label = UInt32.MaxValue; - VBuffer image = default; - runner.AddInput(inputTensorName); - ImageClassificationMetrics metrics = new ImageClassificationMetrics(); - metrics.Bottleneck = new BottleneckMetrics(); - metrics.Bottleneck.DatasetUsed = dataset; - float[] imageArray = null; - while (cursor.MoveNext()) - { - labelGetter(ref label); - imageGetter(ref image); - if (image.Length <= 0) - continue; //Empty Image + float currentMetricValue = _metric == EarlyStoppingMetric.Accuracy ? currentMetrics.Accuracy : currentMetrics.CrossEntropy; - var imageTensor = imageProcessor.ProcessImage(image); - runner.AddInput(imageTensor, 0); - var featurizedImage = runner.Run()[0]; // Reuse memory - featurizedImage.ToArray(ref imageArray); - Host.Assert((int)featurizedImage.size == imageArray.Length); - writer.WriteLine(label - 1 + "," + string.Join(",", imageArray)); - featurizedImage.Dispose(); - imageTensor.Dispose(); - metrics.Bottleneck.Index++; - metricsCallback?.Invoke(metrics); + if (CheckIncreasing) + { + if ((currentMetricValue - _bestMetricValue) < MinDelta) + { + _wait += 1; + if (_wait >= Patience) + return true; + } + else + { + _wait = 0; + _bestMetricValue = currentMetricValue; + } } - datasetsize = metrics.Bottleneck.Index; - } - return datasetsize; - } - - private IDataView GetShuffledData(string path) - { - return new RowShufflingTransformer( - _env, - new RowShufflingTransformer.Options + else { - ForceShuffle = true, - ForceShuffleSource = true - }, - new TextLoader( - _env, - new TextLoader.Options + if ((_bestMetricValue - currentMetricValue) < MinDelta) { - Separators = new[] { ',' }, - Columns = new[] - { - new Column("Label", DataKind.Int64, 0), - new Column("Features", DataKind.Single, new [] { new Range(1, null) }), - }, - }, - new MultiFileSource(path)) - .Load(new MultiFileSource(path))); - } - - private int GetNumSamples(string path) - { - using var reader = File.OpenText(path); - return int.Parse(reader.ReadLine()); + _wait += 1; + if (_wait >= Patience) + return true; + } + else + { + _wait = 0; + _bestMetricValue = currentMetricValue; + } + } + return false; + } } - private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath, ImageClassificationEstimator.Options options, - string validationSetBottleneckFilePath, int trainingsetSize) + /// + /// Metrics for image classification bottlenect phase and training. + /// Train metrics may be null when bottleneck phase is running, so have check! + /// + public sealed class ImageClassificationMetrics { - int batchSize = options.BatchSize; - int epochs = options.Epoch; - float learningRate = options.LearningRate; - bool evaluateOnly = !string.IsNullOrEmpty(validationSetBottleneckFilePath); - ImageClassificationMetricsCallback statisticsCallback = options.MetricsCallback; - var trainingSet = GetShuffledData(trainBottleneckFilePath); - IDataView validationSet = null; - if (options.ValidationSet != null && !string.IsNullOrEmpty(validationSetBottleneckFilePath)) - validationSet = GetShuffledData(validationSetBottleneckFilePath); - - long label = long.MaxValue; - VBuffer features = default; - ReadOnlySpan featureValues = default; - var featureColumn = trainingSet.Schema[1]; - int featureLength = featureColumn.Type.GetVectorSize(); - float[] featureBatch = new float[featureLength * batchSize]; - var featureBatchHandle = GCHandle.Alloc(featureBatch, GCHandleType.Pinned); - IntPtr featureBatchPtr = featureBatchHandle.AddrOfPinnedObject(); - int featureBatchSizeInBytes = sizeof(float) * featureBatch.Length; - long[] labelBatch = new long[batchSize]; - var labelBatchHandle = GCHandle.Alloc(labelBatch, GCHandleType.Pinned); - IntPtr labelBatchPtr = labelBatchHandle.AddrOfPinnedObject(); - int labelBatchSizeInBytes = sizeof(long) * labelBatch.Length; - var labelTensorShape = _labelTensor.TensorShape.dims.Select(x => (long)x).ToArray(); - labelTensorShape[0] = batchSize; - int batchIndex = 0; - var runner = new Runner(_session); - var testEvalRunner = new Runner(_session); - testEvalRunner.AddOutputs(_evaluationStep.name); - testEvalRunner.AddOutputs(_crossEntropy.name); - - Runner validationEvalRunner = null; - if (validationSet != null) + /// + /// Indicates the kind of the dataset of which metric is reported. + /// + public enum Dataset { - validationEvalRunner = new Runner(_session); - validationEvalRunner.AddOutputs(_evaluationStep.name); - validationEvalRunner.AddInput(_bottleneckInput.name).AddInput(_labelTensor.name); + Train, + Validation } - runner.AddOperation(_trainStep); - var featureTensorShape = _bottleneckInput.TensorShape.dims.Select(x => (long)x).ToArray(); - featureTensorShape[0] = batchSize; - - Saver trainSaver = null; - FileWriter trainWriter = null; - Tensor merged = tf.summary.merge_all(); - trainWriter = tf.summary.FileWriter(Path.Combine(Directory.GetCurrentDirectory(), "train"), _session.graph); - trainSaver = tf.train.Saver(); - trainSaver.save(_session, _checkpointPath); - - runner.AddInput(_bottleneckInput.name).AddInput(_labelTensor.name); - if(options.LearningRateScheduler != null) - runner.AddInput(_learningRateInput.name); - testEvalRunner.AddInput(_bottleneckInput.name).AddInput(_labelTensor.name); - Dictionary classStatsTrain = new Dictionary(); - Dictionary classStatsValidate = new Dictionary(); - for (int index = 0; index < _classCount; index += 1) - classStatsTrain[index] = classStatsValidate[index] = 0; - - ImageClassificationMetrics metrics = new ImageClassificationMetrics(); - metrics.Train = new TrainMetrics(); - float accuracy = 0; - float crossentropy = 0; - TrainState trainstate = new TrainState - { - BatchSize = options.BatchSize, - BatchesPerEpoch = (trainingsetSize < 0 ? GetNumSamples("TrainingSetSize.txt") : trainingsetSize) / options.BatchSize - }; - - for (int epoch = 0; epoch < epochs; epoch += 1) - { - batchIndex = 0; - metrics.Train.Accuracy = 0; - metrics.Train.CrossEntropy = 0; - metrics.Train.BatchProcessedCount = 0; - metrics.Train.LearningRate = learningRate; - // Update train state. - trainstate.CurrentEpoch = epoch; - using (var cursor = trainingSet.GetRowCursor(trainingSet.Schema.ToArray(), new Random())) - { - var labelGetter = cursor.GetGetter(trainingSet.Schema[0]); - var featuresGetter = cursor.GetGetter>(featureColumn); - while (cursor.MoveNext()) - { - labelGetter(ref label); - featuresGetter(ref features); - classStatsTrain[label]++; - - if (featureValues == default) - featureValues = features.GetValues(); - - // Buffer the values. - for (int index = 0; index < featureLength; index += 1) - featureBatch[batchIndex * featureLength + index] = featureValues[index]; - - labelBatch[batchIndex] = label; - batchIndex += 1; - trainstate.CurrentBatchIndex = batchIndex; - // Train. - if (batchIndex == batchSize) - { - runner.AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, featureBatchSizeInBytes), 0) - .AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, labelBatchSizeInBytes), 1); - if (options.LearningRateScheduler != null) - { - // Add learning rate as a placeholder only when learning rate scheduling is used. - learningRate = options.LearningRateScheduler.GetLearningRate(trainstate); - metrics.Train.LearningRate = learningRate; - runner.AddInput(new Tensor(learningRate, TF_DataType.TF_FLOAT), 2); - } - runner.Run(); + /// + /// Contains train time metrics. + /// + public TrainMetrics Train { get; set; } - metrics.Train.BatchProcessedCount += 1; - if (options.TestOnTrainSet && statisticsCallback != null) - { - var outputTensors = testEvalRunner - .AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, featureBatchSizeInBytes), 0) - .AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, labelBatchSizeInBytes), 1) - .Run(); + /// + /// Contains pre-train time metrics. These contains metrics on image + /// featurization. + /// + public BottleneckMetrics Bottleneck { get; set; } - outputTensors[0].ToScalar(ref accuracy); - outputTensors[1].ToScalar(ref crossentropy); - metrics.Train.Accuracy += accuracy; - metrics.Train.CrossEntropy += crossentropy; + /// + /// String representation of the metrics. + /// + public override string ToString() => Train != null ? Train.ToString() : Bottleneck.ToString(); + } - outputTensors[0].Dispose(); - outputTensors[1].Dispose(); - } + /// + /// Options class for . + /// + public sealed class Options : TrainerInputBaseWithLabel + { + /// + /// Number of samples to use for mini-batch training. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Number of samples to use for mini-batch training.", SortOrder = 9)] + public int BatchSize = 64; - batchIndex = 0; - } - } + /// + /// Number of training iterations. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Number of training iterations.", SortOrder = 10)] + public int Epoch = 100; - //Process last incomplete batch - if (batchIndex > 0) - { - featureTensorShape[0] = batchIndex; - featureBatchSizeInBytes = sizeof(float) * featureLength * batchIndex; - labelTensorShape[0] = batchIndex; - labelBatchSizeInBytes = sizeof(long) * batchIndex; - runner.AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, featureBatchSizeInBytes), 0) - .AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, labelBatchSizeInBytes), 1); - if (options.LearningRateScheduler != null) - { - // Add learning rate as a placeholder only when learning rate scheduling is used. - learningRate = options.LearningRateScheduler.GetLearningRate(trainstate); - metrics.Train.LearningRate = learningRate; - runner.AddInput(new Tensor(learningRate, TF_DataType.TF_FLOAT), 2); - } - runner.Run(); + /// + /// Learning rate to use during optimization. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate to use during optimization.", SortOrder = 12)] + public float LearningRate = 0.01f; - metrics.Train.BatchProcessedCount += 1; + /// + /// Early stopping technique parameters to be used to terminate training when training metric stops improving. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Early stopping technique parameters to be used to terminate training when training metric stops improving.", SortOrder = 15)] + public EarlyStopping EarlyStoppingCriteria = new EarlyStopping(); - if (options.TestOnTrainSet && statisticsCallback != null) - { - var outputTensors = testEvalRunner - .AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, featureBatchSizeInBytes), 0) - .AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, labelBatchSizeInBytes), 1) - .Run(); + /// + /// Specifies the model architecture to be used in the case of image classification training using transfer learning. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Model architecture to be used in transfer learning for image classification.", SortOrder = 15)] + public Architecture Arch = Architecture.InceptionV3; - outputTensors[0].ToScalar(ref accuracy); - outputTensors[1].ToScalar(ref crossentropy); - metrics.Train.Accuracy += accuracy; - metrics.Train.CrossEntropy += crossentropy; + /// + /// Name of the tensor that will contain the output scores of the last layer when transfer learning is done. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Softmax tensor of the last layer in transfer learning.", SortOrder = 15)] + public string ScoreColumnName = "Score"; - outputTensors[0].Dispose(); - outputTensors[1].Dispose(); - } + /// + /// Name of the tensor that will contain the predicted label from output scores of the last layer when transfer learning is done. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Argmax tensor of the last layer in transfer learning.", SortOrder = 15)] + public string PredictedLabelColumnName = "PredictedLabel"; - batchIndex = 0; - featureTensorShape[0] = batchSize; - featureBatchSizeInBytes = sizeof(float) * featureBatch.Length; - labelTensorShape[0] = batchSize; - labelBatchSizeInBytes = sizeof(long) * batchSize; - } + /// + /// Final model and checkpoint files/folder prefix for storing graph files. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Final model and checkpoint files/folder prefix for storing graph files.", SortOrder = 15)] + public string FinalModelPrefix = "custom_retrained_model_based_on_"; - if (options.TestOnTrainSet && statisticsCallback != null) - { - metrics.Train.Epoch = epoch; - metrics.Train.Accuracy /= metrics.Train.BatchProcessedCount; - metrics.Train.CrossEntropy /= metrics.Train.BatchProcessedCount; - metrics.Train.DatasetUsed = ImageClassificationMetrics.Dataset.Train; - statisticsCallback(metrics); - } - } + /// + /// Callback to report statistics on accuracy/cross entropy during training phase. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Callback to report metrics during training and validation phase.", SortOrder = 15)] + public Action MetricsCallback = null; - if (validationSet == null) - { - //Early stopping check - if (options.EarlyStoppingCriteria != null) - { - if (options.EarlyStoppingCriteria.ShouldStop(metrics.Train)) - break; - } - continue; - } + /// + /// Indicates the path where the newly retrained model should be saved. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates the path where the newly retrained model should be saved.", SortOrder = 15)] + public string ModelSavePath = null; - batchIndex = 0; - metrics.Train.BatchProcessedCount = 0; - metrics.Train.Accuracy = 0; - metrics.Train.CrossEntropy = 0; - using (var cursor = validationSet.GetRowCursor(validationSet.Schema.ToArray(), new Random())) - { - var labelGetter = cursor.GetGetter(validationSet.Schema[0]); - var featuresGetter = cursor.GetGetter>(featureColumn); - while (cursor.MoveNext()) - { - labelGetter(ref label); - featuresGetter(ref features); - classStatsValidate[label]++; - // Buffer the values. - for (int index = 0; index < featureLength; index += 1) - featureBatch[batchIndex * featureLength + index] = featureValues[index]; + /// + /// Indicates to evaluate the model on train set after every epoch. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates to evaluate the model on train set after every epoch.", SortOrder = 15)] + public bool TestOnTrainSet = true; - labelBatch[batchIndex] = label; - batchIndex += 1; - // Evaluate. - if (batchIndex == batchSize) - { - var outputTensors = validationEvalRunner - .AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, featureBatchSizeInBytes), 0) - .AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, labelBatchSizeInBytes), 1) - .Run(); + /// + /// Indicates to not re-compute cached bottleneck trainset values if already available in the bin folder. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates to not re-compute trained cached bottleneck values if already available in the bin folder.", SortOrder = 15)] + public bool ReuseTrainSetBottleneckCachedValues = false; - outputTensors[0].ToScalar(ref accuracy); - metrics.Train.Accuracy += accuracy; - metrics.Train.BatchProcessedCount += 1; - batchIndex = 0; + /// + /// Indicates to not re-compute cached bottleneck validationset values if already available in the bin folder. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates to not re-compute validataionset cached bottleneck validationset values if already available in the bin folder.", SortOrder = 15)] + public bool ReuseValidationSetBottleneckCachedValues = false; - outputTensors[0].Dispose(); - } - } + /// + /// Validation set. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Validation set.", SortOrder = 15)] + public IDataView ValidationSet; - //Process last incomplete batch - if(batchIndex > 0) - { - featureTensorShape[0] = batchIndex; - featureBatchSizeInBytes = sizeof(float) * featureLength * batchIndex; - labelTensorShape[0] = batchIndex; - labelBatchSizeInBytes = sizeof(long) * batchIndex; - var outputTensors = validationEvalRunner - .AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, featureBatchSizeInBytes), 0) - .AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, labelBatchSizeInBytes), 1) - .Run(); + /// + /// Indicates the file path to store trainset bottleneck values for caching. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates the file path to store trainset bottleneck values for caching.", SortOrder = 15)] + public string TrainSetBottleneckCachedValuesFilePath = "trainSetBottleneckFile.csv"; - outputTensors[0].ToScalar(ref accuracy); - metrics.Train.Accuracy += accuracy; - metrics.Train.BatchProcessedCount += 1; - batchIndex = 0; + /// + /// Indicates the file path to store validationset bottleneck values for caching. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates the file path to store validationset bottleneck values for caching.", SortOrder = 15)] + public string ValidationSetBottleneckCachedValuesFilePath = "validationSetBottleneckFile.csv"; - featureTensorShape[0] = batchSize; - featureBatchSizeInBytes = sizeof(float) * featureBatch.Length; - labelTensorShape[0] = batchSize; - labelBatchSizeInBytes = sizeof(long) * batchSize; + /// + /// A class that performs learning rate scheduling. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "A class that performs learning rate scheduling.", SortOrder = 15)] + public LearningRateScheduler LearningRateScheduler; + } - outputTensors[0].Dispose(); - } + /// Return the type of prediction task. + private protected override PredictionKind PredictionKind => PredictionKind.MulticlassClassification; - if (statisticsCallback != null) - { - metrics.Train.Epoch = epoch; - metrics.Train.Accuracy /= metrics.Train.BatchProcessedCount; - metrics.Train.DatasetUsed = ImageClassificationMetrics.Dataset.Validation; - statisticsCallback(metrics); - } - } + private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false); - //Early stopping check - if (options.EarlyStoppingCriteria != null) - { - if (options.EarlyStoppingCriteria.ShouldStop(metrics.Train)) - break; - } - } + /// + /// Auxiliary information about the trainer in terms of its capabilities + /// and requirements. + /// + public override TrainerInfo Info => _info; - trainSaver.save(_session, _checkpointPath); - UpdateTransferLearningModelOnDisk(options, _classCount); - } + private readonly Options _options; + private Session _session; + private Operation _trainStep; + private Tensor _bottleneckTensor; + private Tensor _learningRateInput; + private Tensor _softMaxTensor; + private Tensor _crossEntropy; + private Tensor _labelTensor; + private Tensor _evaluationStep; + private Tensor _prediction; + private Tensor _bottleneckInput; + private Tensor _jpegData; + private Tensor _resizedImage; + private string _jpegDataTensorName; + private string _resizedImageTensorName; + private string _inputTensorName; + private string _softmaxTensorName; + private readonly string _checkpointPath; + private readonly string _bottleneckOperationName; + private readonly bool _useLRScheduling; + private int _classCount; + private Graph Graph => _session.graph; - private (Session, Tensor, Tensor, Tensor) BuildEvaluationSession(ImageClassificationEstimator.Options options, int classCount) + /// + /// Initializes a new instance of + /// + /// The environment to use. + /// The name of the label column. + /// The name of the feature column. + /// The name of score column. + /// The name of the predicted label column. + /// The validation set used while training to improve model quality. + internal ImageClassificationTrainer(IHostEnvironment env, + string labelColumn = DefaultColumnNames.Label, + string featureColumn = DefaultColumnNames.Features, + string scoreColumn = DefaultColumnNames.Score, + string predictedLabelColumn = DefaultColumnNames.PredictedLabel, + IDataView validationSet = null) + : this(env, new Options() + { + FeatureColumnName = featureColumn, + LabelColumnName = labelColumn, + ScoreColumnName = scoreColumn, + PredictedLabelColumnName = predictedLabelColumn, + ValidationSet = validationSet + }) { - var evalGraph = DnnUtils.LoadMetaGraph(ModelLocation[options.Arch]); - var evalSess = tf.Session(graph: evalGraph); - Tensor evaluationStep = null; - Tensor prediction = null; - Tensor bottleneckTensor = evalGraph.OperationByName(_bottleneckOperationName); - evalGraph.as_default(); - var (_, _, groundTruthInput, finalTensor) = AddFinalRetrainOps(classCount, options.LabelColumnName, - options.ScoreColumnName, bottleneckTensor, false, (options.LearningRateScheduler == null ? false:true), options.LearningRate); - tf.train.Saver().restore(evalSess, _checkpointPath); - (evaluationStep, prediction) = AddEvaluationStep(finalTensor, groundTruthInput); - var imageSize = ImageClassificationEstimator.ImagePreprocessingSize[options.Arch]; - (_jpegData, _resizedImage) = AddJpegDecoding(imageSize.Item1, imageSize.Item2, 3); - return (evalSess, _labelTensor, evaluationStep, prediction); } - private (Tensor, Tensor) AddEvaluationStep(Tensor resultTensor, Tensor groundTruthTensor) + /// + /// Initializes a new instance of + /// + internal ImageClassificationTrainer(IHostEnvironment env, Options options) + : base(Contracts.CheckRef(env, nameof(env)).Register(LoadName), + new SchemaShape.Column(options.FeatureColumnName, SchemaShape.Column.VectorKind.VariableVector, + NumberDataViewType.Byte, false), + TrainerUtils.MakeU4ScalarColumn(options.LabelColumnName)) { - Tensor evaluationStep = null; - Tensor correctPrediction = null; + Host.CheckValue(options, nameof(options)); + Host.CheckNonEmpty(options.FeatureColumnName, nameof(options.FeatureColumnName)); + Host.CheckNonEmpty(options.LabelColumnName, nameof(options.LabelColumnName)); + Host.CheckNonEmpty(options.ScoreColumnName, nameof(options.ScoreColumnName)); + Host.CheckNonEmpty(options.PredictedLabelColumnName, nameof(options.PredictedLabelColumnName)); + + _options = options; + _session = DnnUtils.LoadDnnModel(env, _options.Arch, true).Session; + _useLRScheduling = _options.LearningRateScheduler != null; + _checkpointPath = _options.ModelSavePath ?? + Path.Combine(Directory.GetCurrentDirectory(), _options.FinalModelPrefix + + ModelLocation[_options.Arch]); + + // Configure bottleneck tensor based on the model. + var arch = _options.Arch; + if (arch == Architecture.ResnetV2101) + { + _bottleneckOperationName = "resnet_v2_101/SpatialSqueeze"; + _inputTensorName = "input"; + } + else if (arch == Architecture.InceptionV3) + { + _bottleneckOperationName = "module_apply_default/hub_output/feature_vector/SpatialSqueeze"; + _inputTensorName = "Placeholder"; + } + else if (arch == Architecture.MobilenetV2) + { + _bottleneckOperationName = "import/MobilenetV2/Logits/Squeeze"; + _inputTensorName = "import/input"; + } + else if (arch == Architecture.ResnetV250) + { + _bottleneckOperationName = "resnet_v2_50/SpatialSqueeze"; + _inputTensorName = "input"; + } + } - tf_with(tf.name_scope("accuracy"), scope => + private void InitializeTrainingGraph(IDataView input) + { + var labelColumn = input.Schema.GetColumnOrNull(_options.LabelColumnName).Value; + var labelType = labelColumn.Type; + var labelCount = labelType.GetKeyCount(); + if (labelCount <= 0) { - tf_with(tf.name_scope("correct_prediction"), delegate - { - _prediction = tf.argmax(resultTensor, 1); - correctPrediction = tf.equal(_prediction, groundTruthTensor); - }); + throw Host.ExceptSchemaMismatch(nameof(input.Schema), "label", (string)labelColumn.Name, "Key", + (string)labelType.ToString()); + } - tf_with(tf.name_scope("accuracy"), delegate - { - evaluationStep = tf.reduce_mean(tf.cast(correctPrediction, tf.float32)); - }); - }); + _classCount = labelCount == 1 ? 2 : (int)labelCount; + var imageSize = ImagePreprocessingSize[_options.Arch]; + (_jpegData, _resizedImage) = AddJpegDecoding(imageSize.Item1, imageSize.Item2, 3); + _jpegDataTensorName = _jpegData.name; + _resizedImageTensorName = _resizedImage.name; - tf.summary.scalar("accuracy", evaluationStep); - return (evaluationStep, _prediction); - } + // Add transfer learning layer. + AddTransferLearningLayer(_options.LabelColumnName, _options.ScoreColumnName, _options.LearningRate, + _useLRScheduling, _classCount); - private void UpdateTransferLearningModelOnDisk(ImageClassificationEstimator.Options options, int classCount) - { - var (sess, _, _, _) = BuildEvaluationSession(options, classCount); - var graph = sess.graph; - var outputGraphDef = tf.graph_util.convert_variables_to_constants( - sess, graph.as_graph_def(), new string[] { _softMaxTensor.name.Split(':')[0], _prediction.name.Split(':')[0], _jpegData.name.Split(':')[0], _resizedImage.name.Split(':')[0] }); + // Initialize the variables. + new Runner(_session).AddOperation(tf.global_variables_initializer()).Run(); - string frozenModelPath = _checkpointPath + ".pb"; - File.WriteAllBytes(_checkpointPath + ".pb", outputGraphDef.ToByteArray()); - _session.graph.Dispose(); - _session.Dispose(); - _session = LoadTFSessionByModelFilePath(_env, frozenModelPath, false); + // Add evaluation layer. + (_evaluationStep, _) = AddEvaluationStep(_softMaxTensor, _labelTensor); + _softmaxTensorName = _softMaxTensor.name; } - private void VariableSummaries(RefVariable var) + private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) { - tf_with(tf.name_scope("summaries"), delegate + bool success = inputSchema.TryFindColumn(_options.LabelColumnName, out _); + Contracts.Assert(success); + var metadata = new List(); + metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector, + TextDataViewType.Instance, false)); + + return new[] { - var mean = tf.reduce_mean(var); - tf.summary.scalar("mean", mean); - Tensor stddev = null; - tf_with(tf.name_scope("stddev"), delegate - { - stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))); - }); - tf.summary.scalar("stddev", stddev); - tf.summary.scalar("max", tf.reduce_max(var)); - tf.summary.scalar("min", tf.reduce_min(var)); - tf.summary.histogram("histogram", var); - }); + new SchemaShape.Column(_options.ScoreColumnName, SchemaShape.Column.VectorKind.Vector, + NumberDataViewType.Single, false), + new SchemaShape.Column(_options.PredictedLabelColumnName, SchemaShape.Column.VectorKind.Scalar, + NumberDataViewType.UInt32, true, new SchemaShape(metadata.ToArray())) + }; } - private (Operation, Tensor, Tensor, Tensor) AddFinalRetrainOps(int classCount, string labelColumn, - string scoreColumnName, Tensor bottleneckTensor, bool isTraining, bool useLearningRateScheduler, float learningRate) + private protected override MulticlassPredictionTransformer MakeTransformer( + ImageClassificationModelParameters model, DataViewSchema trainSchema) + => new MulticlassPredictionTransformer(Host, model, trainSchema, + FeatureColumn.Name, LabelColumn.Name, _options.ScoreColumnName, _options.PredictedLabelColumnName); + + private protected override ImageClassificationModelParameters TrainModelCore(TrainContext trainContext) { - var bottleneckTensorDims = bottleneckTensor.TensorShape.dims; - var (batch_size, bottleneck_tensor_size) = (bottleneckTensorDims[0], bottleneckTensorDims[1]); - tf_with(tf.name_scope("input"), scope => + InitializeTrainingGraph(trainContext.TrainingSet.Data); + CheckTrainingParameters(_options); + var validationSet = trainContext.ValidationSet?.Data ?? _options.ValidationSet; + var imageProcessor = new ImageProcessor(_session, _jpegDataTensorName, _resizedImageTensorName); + int trainingsetSize = -1; + if (!_options.ReuseTrainSetBottleneckCachedValues || + !File.Exists(_options.TrainSetBottleneckCachedValuesFilePath)) { - if (isTraining) - { - _bottleneckInput = tf.placeholder_with_default( - bottleneckTensor, - shape: bottleneckTensorDims, - name: "BottleneckInputPlaceholder"); - if (useLearningRateScheduler) - _learningRateInput = tf.placeholder(tf.float32, null, name: "learningRateInputPlaceholder"); + trainingsetSize = CacheFeaturizedImagesToDisk(trainContext.TrainingSet.Data, _options.LabelColumnName, + _options.FeatureColumnName, imageProcessor, + _inputTensorName, _bottleneckTensor.name, _options.TrainSetBottleneckCachedValuesFilePath, + ImageClassificationMetrics.Dataset.Train, _options.MetricsCallback); - } - _labelTensor = tf.placeholder(tf.int64, new TensorShape(batch_size), name: labelColumn); - }); + // Write training set size to a file for use during training + File.WriteAllText("TrainingSetSize.txt", trainingsetSize.ToString()); + } - string layerName = "final_retrain_ops"; - Tensor logits = null; - tf_with(tf.name_scope(layerName), scope => + if (validationSet != null && + (!_options.ReuseTrainSetBottleneckCachedValues || + !File.Exists(_options.ValidationSetBottleneckCachedValuesFilePath))) { - RefVariable layerWeights = null; - tf_with(tf.name_scope("weights"), delegate - { - var initialValue = tf.truncated_normal(new int[] { bottleneck_tensor_size, classCount }, stddev: 0.001f); - layerWeights = tf.Variable(initialValue, name: "final_weights"); - VariableSummaries(layerWeights); - }); - - RefVariable layerBiases = null; - tf_with(tf.name_scope("biases"), delegate - { - TensorShape shape = new TensorShape(classCount); - layerBiases = tf.Variable(tf.zeros(shape), name: "final_biases"); - VariableSummaries(layerBiases); - }); + CacheFeaturizedImagesToDisk(validationSet, _options.LabelColumnName, + _options.FeatureColumnName, imageProcessor, _inputTensorName, _bottleneckTensor.name, + _options.ValidationSetBottleneckCachedValuesFilePath, + ImageClassificationMetrics.Dataset.Validation, _options.MetricsCallback); + } - tf_with(tf.name_scope("Wx_plus_b"), delegate - { - var matmul = tf.matmul(isTraining ? _bottleneckInput : bottleneckTensor, layerWeights); - logits = matmul + layerBiases; - tf.summary.histogram("pre_activations", logits); - }); - }); + TrainAndEvaluateClassificationLayer(_options.TrainSetBottleneckCachedValuesFilePath, _options, + _options.ValidationSetBottleneckCachedValuesFilePath, trainingsetSize); - _softMaxTensor = tf.nn.softmax(logits, name: scoreColumnName); + // Leave the ownership of _session so that it is not disposed/closed when this object goes out of scope + // since it will be used by ImageClassificationModelParameters class (new owner that will take care of + // disposing). + var session = _session; + _session = null; + return new ImageClassificationModelParameters(Host, session, _classCount, _jpegDataTensorName, + _resizedImageTensorName, _inputTensorName, _softmaxTensorName); + } - tf.summary.histogram("activations", _softMaxTensor); - if (!isTraining) - return (null, null, _labelTensor, _softMaxTensor); + private void CheckTrainingParameters(Options options) + { + Host.CheckNonWhiteSpace(options.LabelColumnName, nameof(options.LabelColumnName)); - Tensor crossEntropyMean = null; - tf_with(tf.name_scope("cross_entropy"), delegate + if (_session.graph.OperationByName(_labelTensor.name.Split(':')[0]) == null) { - crossEntropyMean = tf.losses.sparse_softmax_cross_entropy( - labels: _labelTensor, logits: logits); - }); - - tf.summary.scalar("cross_entropy", crossEntropyMean); + throw Host.ExceptParam(nameof(_labelTensor.name), $"'{_labelTensor.name}' does not" + + $"exist in the model"); + } - tf_with(tf.name_scope("train"), delegate + if (options.EarlyStoppingCriteria != null && options.ValidationSet == null && + options.TestOnTrainSet == false) { - var optimizer = useLearningRateScheduler ? tf.train.GradientDescentOptimizer(_learningRateInput) : tf.train.GradientDescentOptimizer(learningRate); - _trainStep = optimizer.minimize(crossEntropyMean); - }); + throw Host.ExceptParam(nameof(options.EarlyStoppingCriteria), $"Early stopping enabled but unable to" + + $"find a validation set and/or train set testing disabled. Please disable early stopping " + + $"or either provide a validation set or enable train set training."); + } - return (_trainStep, crossEntropyMean, _labelTensor, _softMaxTensor); } - private void AddTransferLearningLayer(string labelColumn, - string scoreColumnName, float learningRate, bool useLearningRateScheduling, int classCount) + private (Tensor, Tensor) AddJpegDecoding(int height, int width, int depth) { - _bottleneckTensor = Graph.OperationByName(_bottleneckOperationName); - (_trainStep, _crossEntropy, _labelTensor, _softMaxTensor) = - AddFinalRetrainOps(classCount, labelColumn, scoreColumnName, _bottleneckTensor, true, useLearningRateScheduling, learningRate); - + // height, width, depth + var inputDim = (height, width, depth); + var jpegData = tf.placeholder(tf.@string, name: "DecodeJPGInput"); + var decodedImage = tf.image.decode_jpeg(jpegData, channels: inputDim.Item3); + // Convert from full range of uint8 to range [0,1] of float32. + var decodedImageAsFloat = tf.image.convert_image_dtype(decodedImage, tf.float32); + var decodedImage4d = tf.expand_dims(decodedImageAsFloat, 0); + var resizeShape = tf.stack(new int[] { inputDim.Item1, inputDim.Item2 }); + var resizeShapeAsInt = tf.cast(resizeShape, dtype: tf.int32); + var resizedImage = tf.image.resize_bilinear(decodedImage4d, resizeShapeAsInt, false, "ResizeTensor"); + return (jpegData, resizedImage); } - // Factory method for SignatureLoadDataTransform. - private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) - => Create(env, ctx).MakeDataTransform(input); + private static Tensor EncodeByteAsString(VBuffer buffer) + { + int length = buffer.Length; + var size = c_api.TF_StringEncodedSize((UIntPtr)length); + var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); + + IntPtr tensor = c_api.TF_TensorData(handle); + Marshal.WriteInt64(tensor, 0); + + var status = new Status(); + unsafe + { + fixed (byte* src = buffer.GetValues()) + c_api.TF_StringEncode(src, (UIntPtr)length, (sbyte*)(tensor + sizeof(Int64)), size, status); + } - // Factory method for SignatureLoadRowMapper. - private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema) - => Create(env, ctx).MakeRowMapper(inputSchema); + status.Check(true); + status.Dispose(); + return new Tensor(handle); + } - private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out string[] inputs, - out string[] outputs, out bool addBatchDimensionInput, - out string labelColumn, out string checkpointName, out Architecture arch, - out string scoreColumnName, out string predictedColumnName, out float learningRate, out bool useLearningRateScheduling, out int classCount, - out string[] keyValueAnnotations, out string predictionTensorName, out string softMaxTensorName, - out string jpegDataTensorName, out string resizeTensorName) + internal sealed class ImageProcessor { - addBatchDimensionInput = ctx.Reader.ReadBoolByte(); - - var numInputs = ctx.Reader.ReadInt32(); - env.CheckDecode(numInputs > 0); - inputs = new string[numInputs]; - for (int j = 0; j < inputs.Length; j++) - inputs[j] = ctx.LoadNonEmptyString(); - - var numOutputs = ctx.Reader.ReadInt32(); - env.CheckDecode(numOutputs > 0); - outputs = new string[numOutputs]; - for (int j = 0; j < outputs.Length; j++) - outputs[j] = ctx.LoadNonEmptyString(); - - labelColumn = ctx.Reader.ReadString(); - checkpointName = ctx.Reader.ReadString(); - arch = (Architecture)ctx.Reader.ReadInt32(); - scoreColumnName = ctx.Reader.ReadString(); - predictedColumnName = ctx.Reader.ReadString(); - learningRate = ctx.Reader.ReadFloat(); - useLearningRateScheduling = ctx.Reader.ReadBoolByte(); - classCount = ctx.Reader.ReadInt32(); - - env.CheckDecode(classCount > 0); - keyValueAnnotations = new string[classCount]; - for (int j = 0; j < keyValueAnnotations.Length; j++) - keyValueAnnotations[j] = ctx.LoadNonEmptyString(); - - predictionTensorName = ctx.Reader.ReadString(); - softMaxTensorName = ctx.Reader.ReadString(); - jpegDataTensorName = ctx.Reader.ReadString(); - resizeTensorName = ctx.Reader.ReadString(); - } + private Runner _imagePreprocessingRunner; + + public ImageProcessor(Session session, string jpegDataTensorName, string resizeImageTensorName) + { + _imagePreprocessingRunner = new Runner(session); + _imagePreprocessingRunner.AddInput(jpegDataTensorName); + _imagePreprocessingRunner.AddOutputs(resizeImageTensorName); + } - internal ImageClassificationTransformer(IHostEnvironment env, Session session, string[] outputColumnNames, - string[] inputColumnNames, - bool? addBatchDimensionInput, int batchSize, string labelColumnName, string finalModelPrefix, Architecture arch, - string scoreColumnName, string predictedLabelColumnName, float learningRate, bool useLearningRateScheduling, string modelSavePath, - DataViewSchema inputSchema, int? classCount = null, bool loadModel = false, - string predictionTensorName = null, string softMaxTensorName = null, string jpegDataTensorName = null, string resizeTensorName = null, string[] labelAnnotations = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageClassificationTransformer))) + public Tensor ProcessImage(in VBuffer imageBuffer) + { + var imageTensor = EncodeByteAsString(imageBuffer); + var processedTensor = _imagePreprocessingRunner.AddInput(imageTensor, 0).Run()[0]; + imageTensor.Dispose(); + return processedTensor; + } + } + private int CacheFeaturizedImagesToDisk(IDataView input, string labelColumnName, string imageColumnName, + ImageProcessor imageProcessor, string inputTensorName, string outputTensorName, string cacheFilePath, + ImageClassificationMetrics.Dataset dataset, Action metricsCallback) { - Host.CheckValue(session, nameof(session)); - Host.CheckNonEmpty(inputColumnNames, nameof(inputColumnNames)); - Host.CheckNonEmpty(outputColumnNames, nameof(outputColumnNames)); + var labelColumn = input.Schema[labelColumnName]; - _env = env; - _session = session; - _addBatchDimensionInput = addBatchDimensionInput ?? arch == Architecture.ResnetV2101; - _inputs = inputColumnNames; - _outputs = outputColumnNames; - _labelColumnName = labelColumnName; - _finalModelPrefix = finalModelPrefix; - _arch = arch; - _scoreColumnName = scoreColumnName; - _predictedLabelColumnName = predictedLabelColumnName; - _learningRate = learningRate; - _softmaxTensorName = softMaxTensorName; - _predictionTensorName = predictionTensorName; - _jpegDataTensorName = jpegDataTensorName; - _resizedImageTensorName = resizeTensorName; - _useLRScheduling = useLearningRateScheduling; - - if (classCount == null) + if (labelColumn.Type.RawType != typeof(UInt32)) + throw Host.ExceptSchemaMismatch(nameof(labelColumn), "Label", + labelColumnName, typeof(uint).ToString(), + labelColumn.Type.RawType.ToString()); + + var imageColumn = input.Schema[imageColumnName]; + Runner runner = new Runner(_session); + runner.AddOutputs(outputTensorName); + int datasetsize = 0; + using (TextWriter writer = File.CreateText(cacheFilePath)) + using (var cursor = input.GetRowCursor( + input.Schema.Where(c => c.Index == labelColumn.Index || c.Index == imageColumn.Index))) { - var labelColumn = inputSchema.GetColumnOrNull(labelColumnName).Value; - var labelType = labelColumn.Type; - var labelCount = labelType.GetKeyCount(); - if (labelCount <= 0) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", (string)labelColumn.Name, "Key", (string)labelType.ToString()); + var labelGetter = cursor.GetGetter(labelColumn); + var imageGetter = cursor.GetGetter>(imageColumn); + UInt32 label = UInt32.MaxValue; + VBuffer image = default; + runner.AddInput(inputTensorName); + ImageClassificationMetrics metrics = new ImageClassificationMetrics(); + metrics.Bottleneck = new BottleneckMetrics(); + metrics.Bottleneck.DatasetUsed = dataset; + float[] imageArray = null; + while (cursor.MoveNext()) + { + labelGetter(ref label); + imageGetter(ref image); + if (image.Length <= 0) + continue; //Empty Image - _classCount = labelCount == 1 ? 2 : (int)labelCount; + var imageTensor = imageProcessor.ProcessImage(image); + runner.AddInput(imageTensor, 0); + var featurizedImage = runner.Run()[0]; // Reuse memory + featurizedImage.ToArray(ref imageArray); + Host.Assert((int)featurizedImage.size == imageArray.Length); + writer.WriteLine(label - 1 + "," + string.Join(",", imageArray)); + featurizedImage.Dispose(); + imageTensor.Dispose(); + metrics.Bottleneck.Index++; + metricsCallback?.Invoke(metrics); + } + datasetsize = metrics.Bottleneck.Index; } - else - _classCount = classCount.Value; + return datasetsize; + } + + private IDataView GetShuffledData(string path) + { + return new RowShufflingTransformer( + Host, + new RowShufflingTransformer.Options + { + ForceShuffle = true, + ForceShuffleSource = true + }, + new TextLoader( + Host, + new TextLoader.Options + { + Separators = new[] { ',' }, + Columns = new[] + { + new Column("Label", DataKind.Int64, 0), + new Column("Features", DataKind.Single, new [] { new Range(1, null) }), + }, + }, + new MultiFileSource(path)) + .Load(new MultiFileSource(path))); + } + + private int GetNumSamples(string path) + { + using var reader = File.OpenText(path); + return int.Parse(reader.ReadLine()); + } + + private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath, Options options, + string validationSetBottleneckFilePath, int trainingsetSize) + { + int batchSize = options.BatchSize; + int epochs = options.Epoch; + float learningRate = options.LearningRate; + bool evaluateOnly = !string.IsNullOrEmpty(validationSetBottleneckFilePath); + Action statisticsCallback = _options.MetricsCallback; + var trainingSet = GetShuffledData(trainBottleneckFilePath); + IDataView validationSet = null; + if (options.ValidationSet != null && !string.IsNullOrEmpty(validationSetBottleneckFilePath)) + validationSet = GetShuffledData(validationSetBottleneckFilePath); - _checkpointPath = modelSavePath != null ? modelSavePath : Path.Combine(Directory.GetCurrentDirectory(), finalModelPrefix + ModelLocation[arch]); + long label = long.MaxValue; + VBuffer features = default; + ReadOnlySpan featureValues = default; + var featureColumn = trainingSet.Schema[1]; + int featureLength = featureColumn.Type.GetVectorSize(); + float[] featureBatch = new float[featureLength * batchSize]; + var featureBatchHandle = GCHandle.Alloc(featureBatch, GCHandleType.Pinned); + IntPtr featureBatchPtr = featureBatchHandle.AddrOfPinnedObject(); + int featureBatchSizeInBytes = sizeof(float) * featureBatch.Length; + long[] labelBatch = new long[batchSize]; + var labelBatchHandle = GCHandle.Alloc(labelBatch, GCHandleType.Pinned); + IntPtr labelBatchPtr = labelBatchHandle.AddrOfPinnedObject(); + int labelBatchSizeInBytes = sizeof(long) * labelBatch.Length; + var labelTensorShape = _labelTensor.TensorShape.dims.Select(x => (long)x).ToArray(); + labelTensorShape[0] = batchSize; + int batchIndex = 0; + var runner = new Runner(_session); + var testEvalRunner = new Runner(_session); + testEvalRunner.AddOutputs(_evaluationStep.name); + testEvalRunner.AddOutputs(_crossEntropy.name); - // Configure bottleneck tensor based on the model. - if (arch == ImageClassificationEstimator.Architecture.ResnetV2101) - { - _bottleneckOperationName = "resnet_v2_101/SpatialSqueeze"; - _inputTensorName = "input"; - } - else if (arch == ImageClassificationEstimator.Architecture.InceptionV3) - { - _bottleneckOperationName = "module_apply_default/hub_output/feature_vector/SpatialSqueeze"; - _inputTensorName = "Placeholder"; - } - else if(arch == ImageClassificationEstimator.Architecture.MobilenetV2) - { - _bottleneckOperationName = "import/MobilenetV2/Logits/Squeeze"; - _inputTensorName = "import/input"; - } - else if (arch == ImageClassificationEstimator.Architecture.ResnetV250) + Runner validationEvalRunner = null; + if (validationSet != null) { - _bottleneckOperationName = "resnet_v2_50/SpatialSqueeze"; - _inputTensorName = "input"; + validationEvalRunner = new Runner(_session); + validationEvalRunner.AddOutputs(_evaluationStep.name); + validationEvalRunner.AddInput(_bottleneckInput.name).AddInput(_labelTensor.name); } - _outputs = new[] { scoreColumnName, predictedLabelColumnName }; + runner.AddOperation(_trainStep); + var featureTensorShape = _bottleneckInput.TensorShape.dims.Select(x => (long)x).ToArray(); + featureTensorShape[0] = batchSize; - if (loadModel == false) - { - var imageSize = ImageClassificationEstimator.ImagePreprocessingSize[arch]; - (_jpegData, _resizedImage) = AddJpegDecoding(imageSize.Item1, imageSize.Item2, 3); - _jpegDataTensorName = _jpegData.name; - _resizedImageTensorName = _resizedImage.name; + Saver trainSaver = null; + FileWriter trainWriter = null; + Tensor merged = tf.summary.merge_all(); + trainWriter = tf.summary.FileWriter(Path.Combine(Directory.GetCurrentDirectory(), "train"), + _session.graph); - // Add transfer learning layer. - AddTransferLearningLayer(labelColumnName, scoreColumnName, learningRate, useLearningRateScheduling,_classCount); + trainSaver = tf.train.Saver(); + trainSaver.save(_session, _checkpointPath); - // Initialize the variables. - new Runner(_session).AddOperation(tf.global_variables_initializer()).Run(); + runner.AddInput(_bottleneckInput.name).AddInput(_labelTensor.name); + if (options.LearningRateScheduler != null) + runner.AddInput(_learningRateInput.name); + testEvalRunner.AddInput(_bottleneckInput.name).AddInput(_labelTensor.name); + Dictionary classStatsTrain = new Dictionary(); + Dictionary classStatsValidate = new Dictionary(); + for (int index = 0; index < _classCount; index += 1) + classStatsTrain[index] = classStatsValidate[index] = 0; - // Add evaluation layer. - (_evaluationStep, _) = AddEvaluationStep(_softMaxTensor, _labelTensor); - _softmaxTensorName = _softMaxTensor.name; - _predictionTensorName = _prediction.name; + ImageClassificationMetrics metrics = new ImageClassificationMetrics(); + metrics.Train = new TrainMetrics(); + float accuracy = 0; + float crossentropy = 0; + TrainState trainstate = new TrainState + { + BatchSize = options.BatchSize, + BatchesPerEpoch = + (trainingsetSize < 0 ? GetNumSamples("TrainingSetSize.txt") : trainingsetSize) / options.BatchSize + }; - // Add annotations as key values, if they exist. - VBuffer> keysVBuffer = default; - if (inputSchema[labelColumnName].HasKeyValues()) - { - inputSchema[labelColumnName].GetKeyValues(ref keysVBuffer); - _keyValueAnnotations = keysVBuffer.DenseValues().ToArray(); - } - else - { - _keyValueAnnotations = Enumerable.Range(0, _classCount).Select(x => x.ToString().AsMemory()).ToArray(); - } - } - else + for (int epoch = 0; epoch < epochs; epoch += 1) { - // Load annotations as key values, if they exist - if (labelAnnotations != null) - _keyValueAnnotations = labelAnnotations.Select(v => v.AsMemory()).ToArray(); - } - } + batchIndex = 0; + metrics.Train.Accuracy = 0; + metrics.Train.CrossEntropy = 0; + metrics.Train.BatchProcessedCount = 0; + metrics.Train.LearningRate = learningRate; + // Update train state. + trainstate.CurrentEpoch = epoch; + using (var cursor = trainingSet.GetRowCursor(trainingSet.Schema.ToArray(), new Random())) + { + var labelGetter = cursor.GetGetter(trainingSet.Schema[0]); + var featuresGetter = cursor.GetGetter>(featureColumn); + while (cursor.MoveNext()) + { + labelGetter(ref label); + featuresGetter(ref features); + classStatsTrain[label]++; - private protected override IRowMapper MakeRowMapper(DataViewSchema inputSchema) => new Mapper(this, inputSchema); + if (featureValues == default) + featureValues = features.GetValues(); - private protected override void SaveModel(ModelSaveContext ctx) - { - // *** Binary format *** - // byte: indicator for frozen models - // byte: indicator for adding batch dimension in input - // int: number of input columns - // for each input column - // int: id of int column name - // int: number of output columns - // for each output column - // int: id of output column name - // string: value of label column name - // string: prefix pf final model and checkpoint files/folder for storing graph files - // int: value of the utilized model architecture for transfer learning - // string: value of score column name - // string: value of predicted label column name - // float: value of learning rate - // int: number of prediction classes - // for each key value annotation column - // string: value of key value annotations - // string: name of prediction tensor - // string: name of softmax tensor - // string: name of JPEG data tensor - // string: name of resized image tensor - // stream (byte): tensorFlow model. - - Host.AssertValue(ctx); - ctx.CheckAtModel(); - ctx.SetVersionInfo(GetVersionInfo()); - ctx.Writer.WriteBoolByte(_addBatchDimensionInput); - - Host.AssertNonEmpty(_inputs); - ctx.Writer.Write(_inputs.Length); - foreach (var colName in _inputs) - ctx.SaveNonEmptyString(colName); - - Host.AssertNonEmpty(_outputs); - ctx.Writer.Write(_outputs.Length); - foreach (var colName in _outputs) - ctx.SaveNonEmptyString(colName); - - ctx.Writer.Write(_labelColumnName); - ctx.Writer.Write(_finalModelPrefix); - ctx.Writer.Write((int)_arch); - ctx.Writer.Write(_scoreColumnName); - ctx.Writer.Write(_predictedLabelColumnName); - ctx.Writer.Write(_learningRate); - ctx.Writer.Write(_useLRScheduling); - ctx.Writer.Write(_classCount); + // Buffer the values. + for (int index = 0; index < featureLength; index += 1) + featureBatch[batchIndex * featureLength + index] = featureValues[index]; - Host.AssertNonEmpty(_keyValueAnnotations); - Host.Assert(_keyValueAnnotations.Length == _classCount); - for (int j = 0; j < _classCount; j++) - ctx.SaveNonEmptyString(_keyValueAnnotations[j]); + labelBatch[batchIndex] = label; + batchIndex += 1; + trainstate.CurrentBatchIndex = batchIndex; + // Train. + if (batchIndex == batchSize) + { + runner.AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, + featureBatchSizeInBytes), 0) + .AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, + labelBatchSizeInBytes), 1); - ctx.Writer.Write(_predictionTensorName); - ctx.Writer.Write(_softmaxTensorName); - ctx.Writer.Write(_jpegDataTensorName); - ctx.Writer.Write(_resizedImageTensorName); - Status status = new Status(); - var buffer = _session.graph.ToGraphDef(status); - ctx.SaveBinaryStream("TFModel", w => - { - w.WriteByteArray(buffer.MemoryBlock.ToArray()); - }); - status.Check(true); - } + if (options.LearningRateScheduler != null) + { + // Add learning rate as a placeholder only when learning rate scheduling is used. + learningRate = options.LearningRateScheduler.GetLearningRate(trainstate); + metrics.Train.LearningRate = learningRate; + runner.AddInput(new Tensor(learningRate, TF_DataType.TF_FLOAT), 2); + } + runner.Run(); - ~ImageClassificationTransformer() - { - Dispose(false); - } + metrics.Train.BatchProcessedCount += 1; + if (options.TestOnTrainSet && statisticsCallback != null) + { + var outputTensors = testEvalRunner + .AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, + featureBatchSizeInBytes), 0) + .AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, + labelBatchSizeInBytes), 1) + .Run(); - private void Dispose(bool disposing) - { - // Ensure that the Session is not null and it's handle is not Zero, as it may have already been disposed/finalized. - // Technically we shouldn't be calling this if disposing == false, since we're running in finalizer - // and the GC doesn't guarantee ordering of finalization of managed objects, but we have to make sure - // that the Session is closed before deleting our temporary directory. - if (_session != null && _session != IntPtr.Zero) - { - _session.close(); - } - } + outputTensors[0].ToScalar(ref accuracy); + outputTensors[1].ToScalar(ref crossentropy); + metrics.Train.Accuracy += accuracy; + metrics.Train.CrossEntropy += crossentropy; - private sealed class Mapper : MapperBase - { - private readonly ImageClassificationTransformer _parent; - private readonly int[] _inputColIndices; + outputTensors[0].Dispose(); + outputTensors[1].Dispose(); + } - public Mapper(ImageClassificationTransformer parent, DataViewSchema inputSchema) : - base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent) - { - Host.CheckValue(parent, nameof(parent)); - _parent = parent; - _inputColIndices = new int[1]; - if (!inputSchema.TryGetColumnIndex(_parent._inputs[0], out _inputColIndices[0])) - throw Host.ExceptSchemaMismatch(nameof(InputSchema), "source", _parent._inputs[0]); - } + batchIndex = 0; + } + } - private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); + //Process last incomplete batch + if (batchIndex > 0) + { + featureTensorShape[0] = batchIndex; + featureBatchSizeInBytes = sizeof(float) * featureLength * batchIndex; + labelTensorShape[0] = batchIndex; + labelBatchSizeInBytes = sizeof(long) * batchIndex; + runner.AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, + featureBatchSizeInBytes), 0) + .AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, + labelBatchSizeInBytes), 1); + if (options.LearningRateScheduler != null) + { + // Add learning rate as a placeholder only when learning rate scheduling is used. + learningRate = options.LearningRateScheduler.GetLearningRate(trainstate); + metrics.Train.LearningRate = learningRate; + runner.AddInput(new Tensor(learningRate, TF_DataType.TF_FLOAT), 2); + } + runner.Run(); - private class OutputCache - { - public long Position; - private ValueGetter> _imageGetter; - private VBuffer _image; - private Runner _runner; - private ImageProcessor _imageProcessor; - private long _predictedLabel; - public UInt32 PredictedLabel => (uint)_predictedLabel; - private float[] _classProbability; - public float[] ClassProbabilities => _classProbability; - private DataViewRow _inputRow; - - public OutputCache(DataViewRow input, ImageClassificationTransformer transformer) - { - _image = default; - _imageGetter = input.GetGetter>(input.Schema[transformer._inputs[0]]); - _runner = new Runner(transformer._session); - _runner.AddInput(transformer._inputTensorName); - _runner.AddOutputs(transformer._softmaxTensorName); - _runner.AddOutputs(transformer._predictionTensorName); - _imageProcessor = new ImageProcessor(transformer); - _inputRow = input; - Position = -1; - } + metrics.Train.BatchProcessedCount += 1; - public void UpdateCacheIfNeeded() - { - lock (this) - { - if (_inputRow.Position != Position) + if (options.TestOnTrainSet && statisticsCallback != null) { - Position = _inputRow.Position; - _imageGetter(ref _image); - var processedTensor = _imageProcessor.ProcessImage(_image); - var outputTensor = _runner.AddInput(processedTensor, 0).Run(); - outputTensor[0].ToArray(ref _classProbability); - outputTensor[1].ToScalar(ref _predictedLabel); - _predictedLabel += 1; - outputTensor[0].Dispose(); - outputTensor[1].Dispose(); - processedTensor.Dispose(); + var outputTensors = testEvalRunner + .AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, + featureBatchSizeInBytes), 0) + .AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, + labelBatchSizeInBytes), 1) + .Run(); + + outputTensors[0].ToScalar(ref accuracy); + outputTensors[1].ToScalar(ref crossentropy); + metrics.Train.Accuracy += accuracy; + metrics.Train.CrossEntropy += crossentropy; + + outputTensors[0].Dispose(); + outputTensors[1].Dispose(); } - } - } - } - protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func activeOutput, out Action disposer) - { - disposer = null; - _parent._session.graph.as_default(); - Host.AssertValue(input); - var cache = new OutputCache(input, _parent); + batchIndex = 0; + featureTensorShape[0] = batchSize; + featureBatchSizeInBytes = sizeof(float) * featureBatch.Length; + labelTensorShape[0] = batchSize; + labelBatchSizeInBytes = sizeof(long) * batchSize; + } - if (iinfo == 0) - { - ValueGetter> valuegetter = (ref VBuffer dst) => + if (options.TestOnTrainSet && statisticsCallback != null) { - cache.UpdateCacheIfNeeded(); - var editor = VBufferEditor.Create(ref dst, cache.ClassProbabilities.Length); - new Span(cache.ClassProbabilities, 0, cache.ClassProbabilities.Length).CopyTo(editor.Values); - dst = editor.Commit(); - }; - return valuegetter; + metrics.Train.Epoch = epoch; + metrics.Train.Accuracy /= metrics.Train.BatchProcessedCount; + metrics.Train.CrossEntropy /= metrics.Train.BatchProcessedCount; + metrics.Train.DatasetUsed = ImageClassificationMetrics.Dataset.Train; + statisticsCallback(metrics); + } } - else + + if (validationSet == null) { - ValueGetter valuegetter = (ref UInt32 dst) => + //Early stopping check + if (options.EarlyStoppingCriteria != null) { - cache.UpdateCacheIfNeeded(); - dst = cache.PredictedLabel; - }; - - return valuegetter; + if (options.EarlyStoppingCriteria.ShouldStop(metrics.Train)) + break; + } + continue; } - } - - private protected override Func GetDependenciesCore(Func activeOutput) - { - return col => Enumerable.Range(0, _parent._outputs.Length).Any(i => activeOutput(i)) && _inputColIndices.Any(i => i == col); - } - protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() - { - var annotationBuilder = new DataViewSchema.Annotations.Builder(); - annotationBuilder.AddKeyValues(_parent._classCount, TextDataViewType.Instance, (ref VBuffer> dst) => + batchIndex = 0; + metrics.Train.BatchProcessedCount = 0; + metrics.Train.Accuracy = 0; + metrics.Train.CrossEntropy = 0; + using (var cursor = validationSet.GetRowCursor(validationSet.Schema.ToArray(), new Random())) { - var editor = VBufferEditor.Create(ref dst, _parent._classCount); - for (int i = 0; i < _parent._classCount; i++) - editor.Values[i] = _parent._keyValueAnnotations[i]; - dst = editor.Commit(); - }); - - var info = new DataViewSchema.DetachedColumn[_parent._outputs.Length]; - info[0] = new DataViewSchema.DetachedColumn(_parent._scoreColumnName, new VectorDataViewType(NumberDataViewType.Single, _parent._classCount), null); - info[1] = new DataViewSchema.DetachedColumn(_parent._predictedLabelColumnName, new KeyDataViewType(typeof(uint), _parent._classCount), annotationBuilder.ToAnnotations()); - return info; - } - } - } - - /// - public sealed class ImageClassificationEstimator : IEstimator - { - /// - /// Image classification model. - /// - public enum Architecture - { - ResnetV2101, - InceptionV3, - MobilenetV2, - ResnetV250 - }; - - /// - /// Dictionary mapping model architecture to model location. - /// - internal static IReadOnlyDictionary ModelLocation = new Dictionary - { - { Architecture.ResnetV2101, @"resnet_v2_101_299.meta" }, - { Architecture.InceptionV3, @"InceptionV3.meta" }, - { Architecture.MobilenetV2, @"mobilenet_v2.meta" }, - { Architecture.ResnetV250, @"resnet_v2_50_299.meta" } - }; - - /// - /// Dictionary mapping model architecture to image input size supported. - /// - internal static IReadOnlyDictionary> ImagePreprocessingSize = new Dictionary> - { - { Architecture.ResnetV2101, new Tuple(299,299) }, - { Architecture.InceptionV3, new Tuple(299,299) }, - { Architecture.MobilenetV2, new Tuple(224,224) }, - { Architecture.ResnetV250, new Tuple(299,299) } - }; - - /// - /// Backend DNN training framework. - /// - public enum DnnFramework - { - Tensorflow - }; - - /// - /// Indicates the metric to be monitored to decide Early Stopping criteria. - /// - public enum EarlyStoppingMetric - { - Accuracy, - Loss - } + var labelGetter = cursor.GetGetter(validationSet.Schema[0]); + var featuresGetter = cursor.GetGetter>(featureColumn); + while (cursor.MoveNext()) + { + labelGetter(ref label); + featuresGetter(ref features); + classStatsValidate[label]++; + // Buffer the values. + for (int index = 0; index < featureLength; index += 1) + featureBatch[batchIndex * featureLength + index] = featureValues[index]; - /// - /// Callback that returns DNN statistics during bottlenack phase and training phase. - /// Train metrics may be null when bottleneck phase is running, so have check! - /// - public delegate void ImageClassificationMetricsCallback(ImageClassificationMetrics metrics); + labelBatch[batchIndex] = label; + batchIndex += 1; + // Evaluate. + if (batchIndex == batchSize) + { + var outputTensors = validationEvalRunner + .AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, + featureBatchSizeInBytes), 0) + .AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, + labelBatchSizeInBytes), 1) + .Run(); - /// - /// DNN training metrics. - /// - public sealed class TrainMetrics - { - /// - /// Indicates the dataset on which metrics are being reported. - /// - /// - public ImageClassificationMetrics.Dataset DatasetUsed { get; set; } + outputTensors[0].ToScalar(ref accuracy); + metrics.Train.Accuracy += accuracy; + metrics.Train.BatchProcessedCount += 1; + batchIndex = 0; - /// - /// The number of batches processed in an epoch. - /// - public int BatchProcessedCount { get; set; } + outputTensors[0].Dispose(); + } + } - /// - /// The training epoch index for which this metric is reported. - /// - public int Epoch { get; set; } + //Process last incomplete batch + if (batchIndex > 0) + { + featureTensorShape[0] = batchIndex; + featureBatchSizeInBytes = sizeof(float) * featureLength * batchIndex; + labelTensorShape[0] = batchIndex; + labelBatchSizeInBytes = sizeof(long) * batchIndex; + var outputTensors = validationEvalRunner + .AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, + featureBatchSizeInBytes), 0) + .AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, + labelBatchSizeInBytes), 1) + .Run(); - /// - /// Accuracy of the batch on this . Higher the better. - /// - public float Accuracy { get; set; } + outputTensors[0].ToScalar(ref accuracy); + metrics.Train.Accuracy += accuracy; + metrics.Train.BatchProcessedCount += 1; + batchIndex = 0; - /// - /// Cross-Entropy (loss) of the batch on this . Lower - /// the better. - /// - public float CrossEntropy { get; set; } + featureTensorShape[0] = batchSize; + featureBatchSizeInBytes = sizeof(float) * featureBatch.Length; + labelTensorShape[0] = batchSize; + labelBatchSizeInBytes = sizeof(long) * batchSize; - /// - /// Learning Rate used for this . Changes for learning rate scheduling. - /// - public float LearningRate { get; set; } + outputTensors[0].Dispose(); + } - /// - /// String representation of the metrics. - /// - public override string ToString() - { - if (DatasetUsed == ImageClassificationMetrics.Dataset.Train) - return $"Phase: Training, Dataset used: {DatasetUsed.ToString(),10}, Batch Processed Count: {BatchProcessedCount,3}, Learning Rate: {LearningRate,10} " + - $"Epoch: {Epoch,3}, Accuracy: {Accuracy,10}, Cross-Entropy: {CrossEntropy,10}"; - else - return $"Phase: Training, Dataset used: {DatasetUsed.ToString(),10}, Batch Processed Count: {BatchProcessedCount,3}, " + - $"Epoch: {Epoch,3}, Accuracy: {Accuracy,10}"; + if (statisticsCallback != null) + { + metrics.Train.Epoch = epoch; + metrics.Train.Accuracy /= metrics.Train.BatchProcessedCount; + metrics.Train.DatasetUsed = ImageClassificationMetrics.Dataset.Validation; + statisticsCallback(metrics); + } + } + + //Early stopping check + if (options.EarlyStoppingCriteria != null) + { + if (options.EarlyStoppingCriteria.ShouldStop(metrics.Train)) + break; + } } + + trainSaver.save(_session, _checkpointPath); + UpdateTransferLearningModelOnDisk(_classCount); } - /// - /// Metrics for image featurization values. The input image is passed through - /// the network and features are extracted from second or last layer to - /// train a custom full connected layer that serves as classifier. - /// - public sealed class BottleneckMetrics + private (Session, Tensor, Tensor, Tensor) BuildEvaluationSession(int classCount) { - /// - /// Indicates the dataset on which metrics are being reported. - /// - /// - public ImageClassificationMetrics.Dataset DatasetUsed { get; set; } - - /// - /// Index of the input image. - /// - public int Index { get; set; } - - /// - /// String representation of the metrics. - /// - public override string ToString() => $"Phase: Bottleneck Computation, Dataset used: {DatasetUsed.ToString(),10}, Image Index: {Index,3}"; + var evalGraph = DnnUtils.LoadMetaGraph(ModelLocation[_options.Arch]); + var evalSess = tf.Session(graph: evalGraph); + Tensor evaluationStep = null; + Tensor prediction = null; + Tensor bottleneckTensor = evalGraph.OperationByName(_bottleneckOperationName); + evalGraph.as_default(); + var (_, _, groundTruthInput, finalTensor) = AddFinalRetrainOps(classCount, _options.LabelColumnName, + _options.ScoreColumnName, bottleneckTensor, false, (_options.LearningRateScheduler == null ? false : true), _options.LearningRate); + tf.train.Saver().restore(evalSess, _checkpointPath); + (evaluationStep, prediction) = AddEvaluationStep(finalTensor, groundTruthInput); + var imageSize = ImagePreprocessingSize[_options.Arch]; + (_jpegData, _resizedImage) = AddJpegDecoding(imageSize.Item1, imageSize.Item2, 3); + return (evalSess, _labelTensor, evaluationStep, prediction); } - /// - /// Early Stopping feature stops training when monitored quantity stops improving'. - /// Modeled after https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/callbacks.py#L1143 - /// - public sealed class EarlyStopping + private (Tensor, Tensor) AddEvaluationStep(Tensor resultTensor, Tensor groundTruthTensor) { - /// - /// Best value of metric seen so far. - /// - private float _bestMetricValue; - - /// - /// Current counter for number of epochs where there has been no improvement. - /// - private int _wait; - - /// - /// The metric to be monitored (eg Accuracy, Loss). - /// - private EarlyStoppingMetric _metric; + Tensor evaluationStep = null; + Tensor correctPrediction = null; - /// - /// Minimum change in the monitored quantity to be considered as an improvement. - /// - public float MinDelta { get; set; } + tf_with(tf.name_scope("accuracy"), scope => + { + tf_with(tf.name_scope("correct_prediction"), delegate + { + _prediction = tf.argmax(resultTensor, 1); + correctPrediction = tf.equal(_prediction, groundTruthTensor); + }); - /// - /// Number of epochs to wait after no improvement is seen consecutively - /// before stopping the training. - /// - public int Patience { get; set; } + tf_with(tf.name_scope("accuracy"), delegate + { + evaluationStep = tf.reduce_mean(tf.cast(correctPrediction, tf.float32)); + }); + }); - /// - /// Whether the monitored quantity is to be increasing (eg. Accuracy, CheckIncreasing = true) - /// or decreasing (eg. Loss, CheckIncreasing = false). - /// - public bool CheckIncreasing { get; set; } + tf.summary.scalar("accuracy", evaluationStep); + return (evaluationStep, _prediction); + } - /// - /// - /// - /// - public EarlyStopping(float minDelta = 0.01f, int patience = 20, EarlyStoppingMetric metric = EarlyStoppingMetric.Accuracy, bool checkIncreasing = true) - { - _bestMetricValue = 0.0f; - _wait = 0; - _metric = metric; - MinDelta = Math.Abs(minDelta); - Patience = patience; - CheckIncreasing = checkIncreasing; + private void UpdateTransferLearningModelOnDisk(int classCount) + { + var (sess, _, _, _) = BuildEvaluationSession(classCount); + var graph = sess.graph; + var outputGraphDef = tf.graph_util.convert_variables_to_constants( + sess, graph.as_graph_def(), new string[] { _softMaxTensor.name.Split(':')[0], + _prediction.name.Split(':')[0], _jpegData.name.Split(':')[0], _resizedImage.name.Split(':')[0] }); - //Set the CheckIncreasing according to the metric being monitored - if (metric == EarlyStoppingMetric.Accuracy) - CheckIncreasing = true; - else if (metric == EarlyStoppingMetric.Loss) - CheckIncreasing = false; - } + string frozenModelPath = _checkpointPath + ".pb"; + File.WriteAllBytes(_checkpointPath + ".pb", outputGraphDef.ToByteArray()); + _session.graph.Dispose(); + _session.Dispose(); + _session = LoadTFSessionByModelFilePath(Host, frozenModelPath, false); + } - /// - /// To be called at the end of every epoch to check if training should stop. - /// For increasing metric(eg.: Accuracy), if metric stops increasing, stop training if - /// value of metric doesn't increase within 'patience' number of epochs. - /// For decreasing metric(eg.: Loss), stop training if value of metric doesn't decrease - /// within 'patience' number of epochs. - /// Any change in the value of metric of less than 'minDelta' is not considered a change. - /// - public bool ShouldStop(TrainMetrics currentMetrics) + private void VariableSummaries(RefVariable var) + { + tf_with(tf.name_scope("summaries"), delegate { - float currentMetricValue = _metric == EarlyStoppingMetric.Accuracy ? currentMetrics.Accuracy : currentMetrics.CrossEntropy; - - if(CheckIncreasing) - { - if((currentMetricValue- _bestMetricValue) < MinDelta) - { - _wait += 1; - if(_wait >= Patience) - return true; - } - else - { - _wait = 0; - _bestMetricValue = currentMetricValue; - } - } - else + var mean = tf.reduce_mean(var); + tf.summary.scalar("mean", mean); + Tensor stddev = null; + tf_with(tf.name_scope("stddev"), delegate { - if ((_bestMetricValue - currentMetricValue) < MinDelta) - { - _wait += 1; - if (_wait >= Patience) - return true; - } - else - { - _wait = 0; - _bestMetricValue = currentMetricValue; - } - } - return false; - } + stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))); + }); + tf.summary.scalar("stddev", stddev); + tf.summary.scalar("max", tf.reduce_max(var)); + tf.summary.scalar("min", tf.reduce_min(var)); + tf.summary.histogram("histogram", var); + }); } - /// - /// Metrics for image classification bottlenect phase and training. - /// Train metrics may be null when bottleneck phase is running, so have check! - /// - public sealed class ImageClassificationMetrics + private (Operation, Tensor, Tensor, Tensor) AddFinalRetrainOps(int classCount, string labelColumn, + string scoreColumnName, Tensor bottleneckTensor, bool isTraining, bool useLearningRateScheduler, + float learningRate) { - /// - /// Indicates the kind of the dataset of which metric is reported. - /// - public enum Dataset + var bottleneckTensorDims = bottleneckTensor.TensorShape.dims; + var (batch_size, bottleneck_tensor_size) = (bottleneckTensorDims[0], bottleneckTensorDims[1]); + tf_with(tf.name_scope("input"), scope => { - Train, - Validation - }; - - /// - /// Contains train time metrics. - /// - public TrainMetrics Train { get; set; } + if (isTraining) + { + _bottleneckInput = tf.placeholder_with_default( + bottleneckTensor, + shape: bottleneckTensorDims, + name: "BottleneckInputPlaceholder"); + if (useLearningRateScheduler) + _learningRateInput = tf.placeholder(tf.float32, null, name: "learningRateInputPlaceholder"); - /// - /// Contains pre-train time metrics. These contains metrics on image - /// featurization. - /// - public BottleneckMetrics Bottleneck { get; set; } + } + _labelTensor = tf.placeholder(tf.int64, new TensorShape(batch_size), name: labelColumn); + }); - /// - /// String representation of the metrics. - /// - public override string ToString() => Train != null ? Train.ToString() : Bottleneck.ToString(); - } + string layerName = "final_retrain_ops"; + Tensor logits = null; + tf_with(tf.name_scope(layerName), scope => + { + RefVariable layerWeights = null; + tf_with(tf.name_scope("weights"), delegate + { + var initialValue = tf.truncated_normal(new int[] { bottleneck_tensor_size, classCount }, + stddev: 0.001f); - /// - /// The options for the . - /// - public sealed class Options - { - /// - /// The names of the model inputs. - /// - [Argument(ArgumentType.Multiple , HelpText = "The names of the model inputs", ShortName = "inputs", SortOrder = 1)] - internal string[] InputColumns; + layerWeights = tf.Variable(initialValue, name: "final_weights"); + VariableSummaries(layerWeights); + }); - /// - /// The names of the requested model outputs. - /// - [Argument(ArgumentType.Multiple , HelpText = "The name of the outputs", ShortName = "outputs", SortOrder = 2)] - internal string[] OutputColumns; + RefVariable layerBiases = null; + tf_with(tf.name_scope("biases"), delegate + { + TensorShape shape = new TensorShape(classCount); + layerBiases = tf.Variable(tf.zeros(shape), name: "final_biases"); + VariableSummaries(layerBiases); + }); - /// - /// The names of the model input features. - /// - [Argument(ArgumentType.AtMostOnce | ArgumentType.Required, HelpText = "The names of the model inputs", ShortName = "features", SortOrder = 1)] - public string FeaturesColumnName; + tf_with(tf.name_scope("Wx_plus_b"), delegate + { + var matmul = tf.matmul(isTraining ? _bottleneckInput : bottleneckTensor, layerWeights); + logits = matmul + layerBiases; + tf.summary.histogram("pre_activations", logits); + }); + }); - /// - /// The name of the label column in that will be mapped to label node in TensorFlow model. - /// - [Argument(ArgumentType.AtMostOnce | ArgumentType.Required, HelpText = "Training labels.", ShortName = "label", SortOrder = 4)] - public string LabelColumnName; + _softMaxTensor = tf.nn.softmax(logits, name: scoreColumnName); - /// - /// Number of samples to use for mini-batch training. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Number of samples to use for mini-batch training.", SortOrder = 9)] - public int BatchSize = 64; + tf.summary.histogram("activations", _softMaxTensor); + if (!isTraining) + return (null, null, _labelTensor, _softMaxTensor); - /// - /// Number of training iterations. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Number of training iterations.", SortOrder = 10)] - public int Epoch = 100; + Tensor crossEntropyMean = null; + tf_with(tf.name_scope("cross_entropy"), delegate + { + crossEntropyMean = tf.losses.sparse_softmax_cross_entropy( + labels: _labelTensor, logits: logits); + }); - /// - /// Learning rate to use during optimization. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate to use during optimization.", SortOrder = 12)] - public float LearningRate = 0.01f; + tf.summary.scalar("cross_entropy", crossEntropyMean); - /// - /// Whether to disable use of early stopping technique. Training will go on for the full epoch count. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to disable use of early stopping technique. Training will go on for the full epoch count.", SortOrder = 15)] - public bool DisableEarlyStopping = false; + tf_with(tf.name_scope("train"), delegate + { + var optimizer = useLearningRateScheduler ? tf.train.GradientDescentOptimizer(_learningRateInput) : + tf.train.GradientDescentOptimizer(learningRate); - /// - /// Early stopping technique parameters to be used to terminate training when training metric stops improving. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Early stopping technique parameters to be used to terminate training when training metric stops improving.", SortOrder = 15)] - public EarlyStopping EarlyStoppingCriteria; + _trainStep = optimizer.minimize(crossEntropyMean); + }); - /// - /// Specifies the model architecture to be used in the case of image classification training using transfer learning. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Model architecture to be used in transfer learning for image classification.", SortOrder = 15)] - public Architecture Arch = Architecture.InceptionV3; + return (_trainStep, crossEntropyMean, _labelTensor, _softMaxTensor); + } - /// - /// Name of the tensor that will contain the output scores of the last layer when transfer learning is done. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Softmax tensor of the last layer in transfer learning.", SortOrder = 15)] - public string ScoreColumnName = "Score"; + private void AddTransferLearningLayer(string labelColumn, + string scoreColumnName, float learningRate, bool useLearningRateScheduling, int classCount) + { + _bottleneckTensor = Graph.OperationByName(_bottleneckOperationName); + (_trainStep, _crossEntropy, _labelTensor, _softMaxTensor) = + AddFinalRetrainOps(classCount, labelColumn, scoreColumnName, _bottleneckTensor, true, + useLearningRateScheduling, learningRate); - /// - /// Name of the tensor that will contain the predicted label from output scores of the last layer when transfer learning is done. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Argmax tensor of the last layer in transfer learning.", SortOrder = 15)] - public string PredictedLabelColumnName = "PredictedLabel"; + } - /// - /// Final model and checkpoint files/folder prefix for storing graph files. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Final model and checkpoint files/folder prefix for storing graph files.", SortOrder = 15)] - public string FinalModelPrefix = "custom_retrained_model_based_on_"; + ~ImageClassificationTrainer() + { + Dispose(false); + } - /// - /// Callback to report statistics on accuracy/cross entropy during training phase. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Callback to report metrics during training and validation phase.", SortOrder = 15)] - public ImageClassificationMetricsCallback MetricsCallback = null; + private void Dispose(bool disposing) + { + // Ensure that the Session is not null and it's handle is not Zero, as it may have already been + // disposed/finalized. Technically we shouldn't be calling this if disposing == false, + // since we're running in finalizer and the GC doesn't guarantee ordering of finalization of managed + // objects, but we have to make sure that the Session is closed before deleting our temporary directory. + if (_session != null && _session != IntPtr.Zero) + { + _session.close(); + } + } - /// - /// Indicates the path where the newly retrained model should be saved. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates the path where the newly retrained model should be saved.", SortOrder = 15)] - public string ModelSavePath = null; + /// + /// Trains a using both training and validation data, + /// returns a . + /// + /// The training data set. + /// The validation data set. + public MulticlassPredictionTransformer Fit( + IDataView trainData, IDataView validationData) => TrainTransformer(trainData, validationData); + } - /// - /// Indicates to evaluate the model on train set after every epoch. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates to evaluate the model on train set after every epoch.", SortOrder = 15)] - public bool TestOnTrainSet = true; + /// + /// Image Classification predictor. This class encapsulates the trained Deep Neural Network(DNN) model + /// and is used to score images. + /// + public sealed class ImageClassificationModelParameters : ModelParametersBase>, IValueMapper + { + internal const string LoaderSignature = "ImageClassificationPred"; + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "IMAGPRED", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(ImageClassificationModelParameters).Assembly.FullName); + } - /// - /// Indicates to not re-compute cached bottleneck trainset values if already available in the bin folder. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates to not re-compute trained cached bottleneck values if already available in the bin folder.", SortOrder = 15)] - public bool ReuseTrainSetBottleneckCachedValues = false; + private readonly VectorDataViewType _inputType; + private readonly VectorDataViewType _outputType; + private readonly int _classCount; + private readonly string _imagePreprocessorTensorInput; + private readonly string _imagePreprocessorTensorOutput; + private readonly string _graphInputTensor; + private readonly string _graphOutputTensor; + private readonly Session _session; + + internal ImageClassificationModelParameters(IHostEnvironment env, Session session, int classCount, + string imagePreprocessorTensorInput, string imagePreprocessorTensorOutput, string graphInputTensor, + string graphOutputTensor) : base(env, LoaderSignature) + { + Host.AssertValue(session); + Host.Assert(classCount > 1); + Host.AssertNonEmpty(imagePreprocessorTensorInput); + Host.AssertNonEmpty(imagePreprocessorTensorOutput); + Host.AssertNonEmpty(graphInputTensor); + Host.AssertNonEmpty(graphOutputTensor); + + _inputType = new VectorDataViewType(NumberDataViewType.Byte); + _outputType = new VectorDataViewType(NumberDataViewType.Single, classCount); + _classCount = classCount; + _session = session; + _imagePreprocessorTensorInput = imagePreprocessorTensorInput; + _imagePreprocessorTensorOutput = imagePreprocessorTensorOutput; + _graphInputTensor = graphInputTensor; + _graphOutputTensor = graphOutputTensor; + } - /// - /// Indicates to not re-compute cached bottleneck validationset values if already available in the bin folder. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates to not re-compute validataionset cached bottleneck validationset values if already available in the bin folder.", SortOrder = 15)] - public bool ReuseValidationSetBottleneckCachedValues = false; + /// Return the type of prediction task. + private protected override PredictionKind PredictionKind => PredictionKind.MulticlassClassification; - /// - /// Validation set. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Validation set.", SortOrder = 15)] - public IDataView ValidationSet; + DataViewType IValueMapper.InputType => _inputType; - /// - /// Indicates the file path to store trainset bottleneck values for caching. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates the file path to store trainset bottleneck values for caching.", SortOrder = 15)] - public string TrainSetBottleneckCachedValuesFilePath = "trainSetBottleneckFile.csv"; + DataViewType IValueMapper.OutputType => _outputType; - /// - /// Indicates the file path to store validationset bottleneck values for caching. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates the file path to store validationset bottleneck values for caching.", SortOrder = 15)] - public string ValidationSetBottleneckCachedValuesFilePath = "validationSetBottleneckFile.csv"; + private ImageClassificationModelParameters(IHostEnvironment env, ModelLoadContext ctx) + : base(env, LoaderSignature, ctx) + { + // *** Binary format *** + // int: _classCount + // string: _imagePreprocessorTensorInput + // string: _imagePreprocessorTensorOutput + // string: _graphInputTensor + // string: _graphOutputTensor + // Graph. + + _classCount = ctx.Reader.ReadInt32(); + _imagePreprocessorTensorInput = ctx.Reader.ReadString(); + _imagePreprocessorTensorOutput = ctx.Reader.ReadString(); + _graphInputTensor = ctx.Reader.ReadString(); + _graphOutputTensor = ctx.Reader.ReadString(); + byte[] modelBytes = null; + if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray())) + throw env.ExceptDecode(); - /// - /// A class that performs learning rate scheduling. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "A class that performs learning rate scheduling.", SortOrder = 15)] - public LearningRateScheduler LearningRateScheduler; + _session = LoadTFSession(env, modelBytes); + _inputType = new VectorDataViewType(NumberDataViewType.Byte); + _outputType = new VectorDataViewType(NumberDataViewType.Single, _classCount); } - private readonly IHost _host; - private readonly Options _options; - private readonly DnnModel _dnnModel; - private readonly DataViewType[] _inputTypes; - private ImageClassificationTransformer _transformer; - - internal ImageClassificationEstimator(IHostEnvironment env, Options options, DnnModel dnnModel) + private static ImageClassificationModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { - _host = Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageClassificationEstimator)); - _options = options; - _dnnModel = dnnModel; - _inputTypes = new[] { new VectorDataViewType(NumberDataViewType.Byte) }; - options.InputColumns = new[] { options.FeaturesColumnName }; - options.OutputColumns = new[] { options.ScoreColumnName, options.PredictedLabelColumnName }; + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + return new ImageClassificationModelParameters(env, ctx); } - private static Options CreateArguments(DnnModel tensorFlowModel, string[] outputColumnNames, string[] inputColumnName, bool addBatchDimensionInput) + private protected override void SaveCore(ModelSaveContext ctx) { - var options = new Options(); - options.InputColumns = inputColumnName; - options.OutputColumns = outputColumnNames; - return options; + base.SaveCore(ctx); + ctx.SetVersionInfo(GetVersionInfo()); + + // *** Binary format *** + // int: _classCount + // string: _imagePreprocessorTensorInput + // string: _imagePreprocessorTensorOutput + // string: _graphInputTensor + // string: _graphOutputTensor + // Graph. + + ctx.Writer.Write(_classCount); + ctx.Writer.Write(_imagePreprocessorTensorInput); + ctx.Writer.Write(_imagePreprocessorTensorOutput); + ctx.Writer.Write(_graphInputTensor); + ctx.Writer.Write(_graphOutputTensor); + + Status status = new Status(); + var buffer = _session.graph.ToGraphDef(status); + ctx.SaveBinaryStream("TFModel", w => + { + w.WriteByteArray(buffer.MemoryBlock.ToArray()); + }); + status.Check(true); } - /// - /// Returns the of the schema which will be produced by the transformer. - /// Used for schema propagation and verification in a pipeline. - /// - public SchemaShape GetOutputSchema(SchemaShape inputSchema) + private class Classifier { - _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.ToDictionary(x => x.Name); - var resultDic = inputSchema.ToDictionary(x => x.Name); - for (var i = 0; i < _options.InputColumns.Length; i++) + private Runner _runner; + private ImageClassificationTrainer.ImageProcessor _imageProcessor; + + public Classifier(ImageClassificationModelParameters model) { - var input = _options.InputColumns[i]; - if (!inputSchema.TryFindColumn(input, out var col)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input); - var expectedType = _inputTypes[i]; - if (!col.ItemType.Equals(expectedType.GetItemType())) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString()); + _runner = new Runner(model._session); + _runner.AddInput(model._graphInputTensor); + _runner.AddOutputs(model._graphOutputTensor); + _imageProcessor = new ImageClassificationTrainer.ImageProcessor(model._session, + model._imagePreprocessorTensorInput, model._imagePreprocessorTensorOutput); } - resultDic[_options.OutputColumns[0]] = new SchemaShape.Column(_options.OutputColumns[0], - SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false); - - var metadata = new List(); - metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false)); + public void Score(in VBuffer image, Span classProbabilities) + { + var processedTensor = _imageProcessor.ProcessImage(image); + var outputTensor = _runner.AddInput(processedTensor, 0).Run(); + outputTensor[0].CopyTo(classProbabilities); + outputTensor[0].Dispose(); + processedTensor.Dispose(); + } + } - resultDic[_options.OutputColumns[1]] = new SchemaShape.Column(_options.OutputColumns[1], - SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true, new SchemaShape(metadata.ToArray())); + ValueMapper IValueMapper.GetMapper() + { + Host.Check(typeof(TSrc) == typeof(VBuffer)); + Host.Check(typeof(TDst) == typeof(VBuffer)); + _session.graph.as_default(); + Classifier classifier = new Classifier(this); + ValueMapper, VBuffer> del = (in VBuffer src, ref VBuffer dst) => + { + var editor = VBufferEditor.Create(ref dst, _classCount); + classifier.Score(src, editor.Values); + dst = editor.Commit(); + }; - return new SchemaShape(resultDic.Values); + return (ValueMapper)(Delegate)del; } - /// - /// Trains and returns a . - /// - public ImageClassificationTransformer Fit(IDataView input) + ~ImageClassificationModelParameters() { - _host.CheckValue(input, nameof(input)); - _transformer = new ImageClassificationTransformer(_host, _options, _dnnModel, input); + Dispose(false); + } - // Validate input schema. - _transformer.GetOutputSchema(input.Schema); - return _transformer; + private void Dispose(bool disposing) + { + // Ensure that the Session is not null and it's handle is not Zero, as it may have already been + // disposed/finalized. Technically we shouldn't be calling this if disposing == false, + // since we're running in finalizer and the GC doesn't guarantee ordering of finalization of managed + // objects, but we have to make sure that the Session is closed before deleting our temporary directory. + if (_session != null && _session != IntPtr.Zero) + { + _session.close(); + } } } } diff --git a/src/Microsoft.ML.Dnn/LearningRateScheduler.cs b/src/Microsoft.ML.Dnn/LearningRateScheduler.cs index 1d20c5941e..4cc11d9380 100644 --- a/src/Microsoft.ML.Dnn/LearningRateScheduler.cs +++ b/src/Microsoft.ML.Dnn/LearningRateScheduler.cs @@ -1,8 +1,11 @@ -using System; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; using System.Collections.Generic; -using System.Text; -namespace Microsoft.ML.Transforms +namespace Microsoft.ML.Dnn { /// /// A class that contains the current train state to use for learning rate scheduling. diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs index 6a981b6b70..5d58c18f2d 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs @@ -430,7 +430,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) if (_type) result[outputColumnName] = new SchemaShape.Column(outputColumnName, SchemaShape.Column.VectorKind.Scalar, new ImageDataViewType(), false); else - result[outputColumnName] = new SchemaShape.Column(outputColumnName, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Byte, false); + result[outputColumnName] = new SchemaShape.Column(outputColumnName, SchemaShape.Column.VectorKind.VariableVector, NumberDataViewType.Byte, false); } return new SchemaShape(result.Values); diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index 3f351eb119..5d58c35de6 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -12,15 +12,15 @@ using Microsoft.ML; using Microsoft.ML.CommandLine; using Microsoft.ML.Data; +using Microsoft.ML.Dnn; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; using Microsoft.ML.Transforms; -using Microsoft.ML.Transforms.Dnn; using Microsoft.ML.Transforms.TensorFlow; using NumSharp; using Tensorflow; -using static Microsoft.ML.Transforms.Dnn.DnnUtils; +using static Microsoft.ML.Dnn.DnnUtils; using static Tensorflow.Binding; [assembly: LoadableClass(TensorFlowTransformer.Summary, typeof(IDataTransform), typeof(TensorFlowTransformer), diff --git a/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs index 4679cf2470..b56464d28a 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs @@ -6,9 +6,9 @@ using System.IO; using System.Linq; using Microsoft.ML.Data; +using Microsoft.ML.Dnn; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; -using Microsoft.ML.Transforms.Dnn; using Tensorflow; namespace Microsoft.ML.Transforms.TensorFlow diff --git a/test/Microsoft.ML.Benchmarks/ImageClassificationBench.cs b/test/Microsoft.ML.Benchmarks/ImageClassificationBench.cs index 9916be3fe9..a07e717829 100644 --- a/test/Microsoft.ML.Benchmarks/ImageClassificationBench.cs +++ b/test/Microsoft.ML.Benchmarks/ImageClassificationBench.cs @@ -7,15 +7,13 @@ using System.IO.Compression; using System.Collections.Generic; using System.Linq; -using System.Net; -using System.Threading; using System.Threading.Tasks; using Microsoft.ML.Data; using Microsoft.ML.Transforms; using BenchmarkDotNet.Attributes; using static Microsoft.ML.DataOperationsCatalog; using System.Net.Http; -using System.Diagnostics; +using Microsoft.ML.Dnn; namespace Microsoft.ML.Benchmarks { @@ -80,20 +78,19 @@ public void SetupData() [Benchmark] public TransformerChain TrainResnetV250() { - var options = new ImageClassificationEstimator.Options() + var options = new ImageClassificationTrainer.Options() { - FeaturesColumnName = "Image", + FeatureColumnName = "Image", LabelColumnName = "Label", - Arch = ImageClassificationEstimator.Architecture.ResnetV250, + Arch = ImageClassificationTrainer.Architecture.ResnetV250, Epoch = 50, BatchSize = 10, LearningRate = 0.01f, - EarlyStoppingCriteria = new ImageClassificationEstimator.EarlyStopping(minDelta: 0.001f, patience: 20, metric: ImageClassificationEstimator.EarlyStoppingMetric.Loss), + EarlyStoppingCriteria = new ImageClassificationTrainer.EarlyStopping(minDelta: 0.001f, patience: 20, metric: ImageClassificationTrainer.EarlyStoppingMetric.Loss), ValidationSet = testDataset, - ModelSavePath = assetsPath, - DisableEarlyStopping = true + ModelSavePath = assetsPath }; - var pipeline = mlContext.Model.ImageClassification(options) + var pipeline = mlContext.MulticlassClassification.Trainers.ImageClassification(options) .Append(mlContext.Transforms.Conversion.MapKeyToValue( outputColumnName: "PredictedLabel", inputColumnName: "PredictedLabel")); diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 4250fb5d1b..395a76358c 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -10,6 +10,7 @@ using System.Net; using System.Runtime.InteropServices; using Microsoft.ML.Data; +using Microsoft.ML.Dnn; using Microsoft.ML.RunTests; using Microsoft.ML.TestFramework; using Microsoft.ML.TestFramework.Attributes; @@ -1247,7 +1248,7 @@ public void TensorFlowImageClassificationDefault() .Transform(testDataset); var pipeline = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, false, "ImagePath") // false indicates we want the image as a VBuffer - .Append(mlContext.Model.ImageClassification("Image", "Label", validationSet: validationSet) + .Append(mlContext.MulticlassClassification.Trainers.ImageClassification("Label", "Image", validationSet: validationSet) .Append(mlContext.Transforms.Conversion.MapKeyToValue(outputColumnName: "PredictedLabel", inputColumnName: "PredictedLabel"))); ; var trainedModel = pipeline.Fit(trainDataset); @@ -1284,10 +1285,10 @@ public void TensorFlowImageClassificationDefault() } [TensorFlowTheory] - [InlineData(ImageClassificationEstimator.Architecture.ResnetV2101)] - [InlineData(ImageClassificationEstimator.Architecture.MobilenetV2)] - [InlineData(ImageClassificationEstimator.Architecture.ResnetV250)] - public void TensorFlowImageClassification(ImageClassificationEstimator.Architecture arch) + [InlineData(ImageClassificationTrainer.Architecture.ResnetV2101)] + [InlineData(ImageClassificationTrainer.Architecture.MobilenetV2)] + [InlineData(ImageClassificationTrainer.Architecture.ResnetV250)] + public void TensorFlowImageClassification(ImageClassificationTrainer.Architecture arch) { string assetsRelativePath = @"assets"; string assetsPath = GetAbsolutePath(assetsRelativePath); @@ -1325,9 +1326,9 @@ public void TensorFlowImageClassification(ImageClassificationEstimator.Architect .Fit(testDataset) .Transform(testDataset); - var options = new ImageClassificationEstimator.Options() + var options = new ImageClassificationTrainer.Options() { - FeaturesColumnName = "Image", + FeatureColumnName = "Image", LabelColumnName = "Label", // Just by changing/selecting InceptionV3/MobilenetV2 here instead of // ResnetV2101 you can try a different architecture/ @@ -1338,12 +1339,11 @@ public void TensorFlowImageClassification(ImageClassificationEstimator.Architect LearningRate = 0.01f, MetricsCallback = (metrics) => Console.WriteLine(metrics), TestOnTrainSet = false, - ValidationSet = validationSet, - DisableEarlyStopping = true + ValidationSet = validationSet }; var pipeline = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, false, "ImagePath") // false indicates we want the image as a VBuffer - .Append(mlContext.Model.ImageClassification(options) + .Append(mlContext.MulticlassClassification.Trainers.ImageClassification(options) .Append(mlContext.Transforms.Conversion.MapKeyToValue(outputColumnName: "PredictedLabel", inputColumnName: "PredictedLabel"))); var trainedModel = pipeline.Fit(trainDataset); @@ -1373,8 +1373,8 @@ public void TensorFlowImageClassification(ImageClassificationEstimator.Architect } else { - Assert.Equal(1, metrics.MicroAccuracy); - Assert.Equal(1, metrics.MacroAccuracy); + Assert.InRange(metrics.MicroAccuracy, 0.8, 1); + Assert.InRange(metrics.MacroAccuracy, 0.8, 1); } // Testing TrySinglePrediction: Utilizing PredictionEngine for single @@ -1463,22 +1463,22 @@ public void TensorFlowImageClassificationWithLRScheduling() .Fit(testDataset) .Transform(testDataset); - var options = new ImageClassificationEstimator.Options() + var options = new ImageClassificationTrainer.Options() { - FeaturesColumnName = "Image", + FeatureColumnName = "Image", LabelColumnName = "Label", // Just by changing/selecting InceptionV3/MobilenetV2 here instead of // ResnetV2101 you can try a different architecture/ // pre-trained model. - Arch = ImageClassificationEstimator.Architecture.ResnetV2101, + Arch = ImageClassificationTrainer.Architecture.ResnetV2101, Epoch = 50, BatchSize = 10, LearningRate = 0.01f, MetricsCallback = (metrics) => Console.WriteLine(metrics), ValidationSet = validationSet, - DisableEarlyStopping = true, ReuseValidationSetBottleneckCachedValues = false, ReuseTrainSetBottleneckCachedValues = false, + EarlyStoppingCriteria = null, // Using Exponential Decay for learning rate scheduling // You can also try other types of Learning rate scheduling methods // available in LearningRateScheduler.cs @@ -1486,7 +1486,7 @@ public void TensorFlowImageClassificationWithLRScheduling() }; var pipeline = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, false, "ImagePath") // false indicates we want the image as a VBuffer - .Append(mlContext.Model.ImageClassification(options)) + .Append(mlContext.MulticlassClassification.Trainers.ImageClassification(options)) .Append(mlContext.Transforms.Conversion.MapKeyToValue( outputColumnName: "PredictedLabel", inputColumnName: "PredictedLabel")); @@ -1517,8 +1517,8 @@ public void TensorFlowImageClassificationWithLRScheduling() } else { - Assert.Equal(1, metrics.MicroAccuracy); - Assert.Equal(1, metrics.MacroAccuracy); + Assert.InRange(metrics.MicroAccuracy, 0.8, 1); + Assert.InRange(metrics.MacroAccuracy, 0.8, 1); } // Testing TrySinglePrediction: Utilizing PredictionEngine for single @@ -1609,25 +1609,24 @@ public void TensorFlowImageClassificationEarlyStoppingIncreasing() .Fit(testDataset) .Transform(testDataset); - var options = new ImageClassificationEstimator.Options() + var options = new ImageClassificationTrainer.Options() { - FeaturesColumnName = "Image", + FeatureColumnName = "Image", LabelColumnName = "Label", // Just by changing/selecting InceptionV3/MobilenetV2 here instead of // ResnetV2101 you can try a different architecture/ // pre-trained model. - Arch = ImageClassificationEstimator.Architecture.ResnetV2101, + Arch = ImageClassificationTrainer.Architecture.ResnetV2101, Epoch = 100, BatchSize = 5, LearningRate = 0.01f, - EarlyStoppingCriteria = new ImageClassificationEstimator.EarlyStopping(), MetricsCallback = (metrics) => { Console.WriteLine(metrics); lastEpoch = metrics.Train != null ? metrics.Train.Epoch : 0; }, TestOnTrainSet = false, - ValidationSet = validationSet, + ValidationSet = validationSet }; var pipeline = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, false, "ImagePath") // false indicates we want the image as a VBuffer - .Append(mlContext.Model.ImageClassification(options)); + .Append(mlContext.MulticlassClassification.Trainers.ImageClassification(options)); var trainedModel = pipeline.Fit(trainDataset); mlContext.Model.Save(trainedModel, shuffledFullImagesDataset.Schema, @@ -1654,8 +1653,8 @@ public void TensorFlowImageClassificationEarlyStoppingIncreasing() } else { - Assert.Equal(1, metrics.MicroAccuracy); - Assert.Equal(1, metrics.MacroAccuracy); + Assert.InRange(metrics.MicroAccuracy, 0.8, 1); + Assert.InRange(metrics.MacroAccuracy, 0.8, 1); } //Assert that the training ran and stopped within half epochs due to EarlyStopping @@ -1703,25 +1702,25 @@ public void TensorFlowImageClassificationEarlyStoppingDecreasing() .Fit(testDataset) .Transform(testDataset); - var options = new ImageClassificationEstimator.Options() + var options = new ImageClassificationTrainer.Options() { - FeaturesColumnName = "Image", + FeatureColumnName = "Image", LabelColumnName = "Label", // Just by changing/selecting InceptionV3/MobilenetV2 here instead of // ResnetV2101 you can try a different architecture/ // pre-trained model. - Arch = ImageClassificationEstimator.Architecture.ResnetV2101, + Arch = ImageClassificationTrainer.Architecture.ResnetV2101, Epoch = 100, BatchSize = 5, LearningRate = 0.01f, - EarlyStoppingCriteria = new ImageClassificationEstimator.EarlyStopping(metric: ImageClassificationEstimator.EarlyStoppingMetric.Loss), + EarlyStoppingCriteria = new ImageClassificationTrainer.EarlyStopping(metric: ImageClassificationTrainer.EarlyStoppingMetric.Loss), MetricsCallback = (metrics) => { Console.WriteLine(metrics); lastEpoch = metrics.Train != null ? metrics.Train.Epoch : 0; }, TestOnTrainSet = false, ValidationSet = validationSet, }; var pipeline = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, false, "ImagePath") // false indicates we want the image as a VBuffer - .Append(mlContext.Model.ImageClassification(options)); + .Append(mlContext.MulticlassClassification.Trainers.ImageClassification(options)); var trainedModel = pipeline.Fit(trainDataset); mlContext.Model.Save(trainedModel, shuffledFullImagesDataset.Schema, @@ -1748,8 +1747,8 @@ public void TensorFlowImageClassificationEarlyStoppingDecreasing() } else { - Assert.Equal(1, metrics.MicroAccuracy); - Assert.Equal(1, metrics.MacroAccuracy); + Assert.InRange(metrics.MicroAccuracy, 0.8, 1); + Assert.InRange(metrics.MacroAccuracy, 0.8, 1); } //Assert that the training ran and stopped within half epochs due to EarlyStopping