Skip to content

Make ImageClassification API an ITrainerEstimator and refactor code. #4372

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Oct 25, 2019
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

<ItemGroup>
<Compile Include="..\Microsoft.ML.Samples\Program.cs" Link="Program.cs" />
<Compile Include="..\Microsoft.ML.Samples\Dynamic\ImageClassification\*.cs">
<Compile Include="..\Microsoft.ML.Samples\Dynamic\Trainers\MulticlassClassification\ImageClassification\*.cs">
<Link>Dynamic\ImageClassification\%(FileName)</Link>
</Compile>
<Compile Include="..\Microsoft.ML.Samples\Dynamic\TensorFlow\*.cs">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms;
using static Microsoft.ML.DataOperationsCatalog;

namespace Samples.Dynamic
Expand Down Expand Up @@ -60,12 +59,11 @@ public static void Example()
IDataView trainDataset = trainTestData.TrainSet;
IDataView testDataset = trainTestData.TestSet;

var pipeline = mlContext.Model.ImageClassification("Image", "Label", validationSet: testDataset)
.Append(mlContext.Transforms.Conversion.MapKeyToValue(
outputColumnName: "PredictedLabel",
var pipeline = mlContext.MulticlassClassification.Trainers
.ImageClassification(featureColumnName:"Image", validationSet:testDataset)
.Append(mlContext.Transforms.Conversion.MapKeyToValue(outputColumnName: "PredictedLabel",
inputColumnName: "PredictedLabel"));


Console.WriteLine("*** Training the image classification model " +
"with DNN Transfer Learning on top of the selected " +
"pre-trained model/architecture ***");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Dnn;
using Microsoft.ML.Transforms;
using static Microsoft.ML.DataOperationsCatalog;

namespace Samples.Dynamic
{
Expand Down Expand Up @@ -65,20 +65,19 @@ public static void Example()
.Fit(testDataset)
.Transform(testDataset);

var options = new ImageClassificationEstimator.Options()
var options = new ImageClassificationTrainer.Options()
{
FeaturesColumnName = "Image",
FeatureColumnName = "Image",
LabelColumnName = "Label",
// Just by changing/selecting InceptionV3/MobilenetV2 here instead of
// ResnetV2101 you can try a different architecture/
// pre-trained model.
Arch = ImageClassificationEstimator.Architecture.ResnetV2101,
Arch = ImageClassificationTrainer.Architecture.ResnetV2101,
Epoch = 182,
BatchSize = 128,
LearningRate = 0.01f,
MetricsCallback = (metrics) => Console.WriteLine(metrics),
ValidationSet = testDataset,
DisableEarlyStopping = true,
ReuseValidationSetBottleneckCachedValues = false,
ReuseTrainSetBottleneckCachedValues = false,
// Use linear scaling rule and Learning rate decay as an option
Expand All @@ -88,7 +87,7 @@ public static void Example()
LearningRateScheduler = new LsrDecay()
};

var pipeline = mlContext.Model.ImageClassification(options)
var pipeline = mlContext.MulticlassClassification.Trainers.ImageClassification(options)
.Append(mlContext.Transforms.Conversion.MapKeyToValue(
outputColumnName: "PredictedLabel",
inputColumnName: "PredictedLabel"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using System.IO.Compression;
using System.Threading;
using System.Net;
using Microsoft.ML.Dnn;

namespace Samples.Dynamic
{
Expand Down Expand Up @@ -62,23 +63,24 @@ public static void Example()
.Fit(testDataset)
.Transform(testDataset);

var options = new ImageClassificationEstimator.Options()
var options = new ImageClassificationTrainer.Options()
{
FeaturesColumnName = "Image",
FeatureColumnName = "Image",
LabelColumnName = "Label",
// Just by changing/selecting InceptionV3/MobilenetV2/ResnetV250 here instead of
// ResnetV2101 you can try a different architecture/
// pre-trained model.
Arch = ImageClassificationEstimator.Architecture.ResnetV2101,
Arch = ImageClassificationTrainer.Architecture.ResnetV2101,
BatchSize = 10,
LearningRate = 0.01f,
EarlyStoppingCriteria = new ImageClassificationEstimator.EarlyStopping(minDelta: 0.001f, patience: 20, metric: ImageClassificationEstimator.EarlyStoppingMetric.Loss),
EarlyStoppingCriteria = new ImageClassificationTrainer.EarlyStopping(
minDelta: 0.001f, patience: 20, metric: ImageClassificationTrainer.EarlyStoppingMetric.Loss),
MetricsCallback = (metrics) => Console.WriteLine(metrics),
ValidationSet = validationSet
};

var pipeline = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, false, "ImagePath") // false indicates we want the image as a VBuffer<byte>
.Append(mlContext.Model.ImageClassification(options));
.Append(mlContext.MulticlassClassification.Trainers.ImageClassification(options));

Console.WriteLine("*** Training the image classification model with " +
"DNN Transfer Learning on top of the selected pre-trained " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms;
using Microsoft.ML.Dnn;
using static Microsoft.ML.DataOperationsCatalog;

namespace Samples.Dynamic
Expand Down Expand Up @@ -60,23 +60,22 @@ public static void Example()
IDataView trainDataset = trainTestData.TrainSet;
IDataView testDataset = trainTestData.TestSet;

var options = new ImageClassificationEstimator.Options()
var options = new ImageClassificationTrainer.Options()
{
FeaturesColumnName = "Image",
FeatureColumnName = "Image",
LabelColumnName = "Label",
// Just by changing/selecting InceptionV3/MobilenetV2/ResnetV250 here instead of
// ResnetV2101 you can try a different architecture/
// pre-trained model.
Arch = ImageClassificationEstimator.Architecture.ResnetV2101,
Arch = ImageClassificationTrainer.Architecture.ResnetV2101,
Epoch = 50,
BatchSize = 10,
LearningRate = 0.01f,
MetricsCallback = (metrics) => Console.WriteLine(metrics),
ValidationSet = testDataset,
DisableEarlyStopping = true
ValidationSet = testDataset
};

var pipeline = mlContext.Model.ImageClassification(options)
var pipeline = mlContext.MulticlassClassification.Trainers.ImageClassification(options)
.Append(mlContext.Transforms.Conversion.MapKeyToValue(
outputColumnName: "PredictedLabel",
inputColumnName: "PredictedLabel"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Collections.Generic;
using System.Threading;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Internal.Utilities;
Expand Down Expand Up @@ -34,6 +35,11 @@ internal sealed class MulticlassClassificationScorer : PredictedLabelScorerBase
// between scores and probabilities (using IDistributionPredictor)
// Scorer arguments. Extends ScorerArgumentsBase with the names to use for the
// score and predicted-label output columns.
public sealed class Arguments : ScorerArgumentsBase
{
// Forwarded by the scorer constructor to the base class as the score column name.
// Defaults to AnnotationUtils.Const.ScoreValueKind.Score.
[Argument(ArgumentType.AtMostOnce, HelpText = "Score Column Name.", ShortName = "scn")]
public string ScoreColumnName = AnnotationUtils.Const.ScoreValueKind.Score;

// Forwarded by the scorer constructor to the base class as the predicted label column name.
// Defaults to DefaultColumnNames.PredictedLabel.
[Argument(ArgumentType.AtMostOnce, HelpText = "Predicted Label Column Name.", ShortName = "plcn")]
public string PredictedLabelColumnName = DefaultColumnNames.PredictedLabel;
}

public const string LoaderSignature = "MultiClassScoreTrans";
Expand Down Expand Up @@ -486,7 +492,7 @@ internal static ISchemaBoundMapper WrapCore<T>(IHostEnvironment env, ISchemaBoun
// Constructs the multiclass scorer by delegating everything to the base class:
//  - the mapper is first passed through WrapIfNeeded together with the training schema;
//  - the score column kind is fixed to ScoreColumnKind.MulticlassClassification;
//  - the score column name and predicted label column name are taken from args.
// The body is intentionally empty; all setup happens in the base constructor.
[BestFriend]
internal MulticlassClassificationScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
: base(args, env, data, WrapIfNeeded(env, mapper, trainSchema), trainSchema, RegistrationName, AnnotationUtils.Const.ScoreColumnKind.MulticlassClassification,
AnnotationUtils.Const.ScoreValueKind.Score, OutputTypeMatches, GetPredColType)
args.ScoreColumnName, OutputTypeMatches, GetPredColType, args.PredictedLabelColumnName)
{
}

Expand Down
12 changes: 6 additions & 6 deletions src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ private protected sealed class BindingsImpl : BindingsBase
private readonly AnnotationUtils.AnnotationGetter<ReadOnlyMemory<char>> _getScoreValueKind;
private readonly DataViewSchema.Annotations _predColMetadata;
private BindingsImpl(DataViewSchema input, ISchemaBoundRowMapper mapper, string suffix, string scoreColumnKind,
bool user, int scoreColIndex, DataViewType predColType)
: base(input, mapper, suffix, user, DefaultColumnNames.PredictedLabel)
bool user, int scoreColIndex, DataViewType predColType, string predictedLabelColumnName = DefaultColumnNames.PredictedLabel)
: base(input, mapper, suffix, user, predictedLabelColumnName)
{
Contracts.AssertNonEmpty(scoreColumnKind);
Contracts.Assert(DerivedColumnCount == 1);
Expand Down Expand Up @@ -82,15 +82,15 @@ private static DataViewSchema.Annotations KeyValueMetadataFromMetadata<T>(DataVi
}

// Factory for BindingsImpl. Asserts that input and mapper are non-null, that suffix is
// either null or a valid value, and that scoreColKind is non-empty, then constructs the
// bindings with the 'user' flag set to true.
// predictedLabelColumnName is optional and defaults to DefaultColumnNames.PredictedLabel,
// so callers that do not supply it get the default predicted-label column name.
public static BindingsImpl Create(DataViewSchema input, ISchemaBoundRowMapper mapper, string suffix,
string scoreColKind, int scoreColIndex, DataViewType predColType)
string scoreColKind, int scoreColIndex, DataViewType predColType, string predictedLabelColumnName = DefaultColumnNames.PredictedLabel)
{
Contracts.AssertValue(input);
Contracts.AssertValue(mapper);
Contracts.AssertValueOrNull(suffix);
Contracts.AssertNonEmpty(scoreColKind);

return new BindingsImpl(input, mapper, suffix, scoreColKind, true,
scoreColIndex, predColType);
scoreColIndex, predColType, predictedLabelColumnName);
}

public BindingsImpl ApplyToSchema(DataViewSchema input, ISchemaBindableMapper bindable, IHostEnvironment env)
Expand Down Expand Up @@ -272,7 +272,7 @@ public override Func<int, bool> GetActiveMapperColumns(bool[] active)
[BestFriend]
private protected PredictedLabelScorerBase(ScorerArgumentsBase args, IHostEnvironment env, IDataView data,
ISchemaBoundMapper mapper, RoleMappedSchema trainSchema, string registrationName, string scoreColKind, string scoreColName,
Func<DataViewType, bool> outputTypeMatches, Func<DataViewType, ISchemaBoundRowMapper, DataViewType> getPredColType)
Func<DataViewType, bool> outputTypeMatches, Func<DataViewType, ISchemaBoundRowMapper, DataViewType> getPredColType, string predictedLabelColumnName = DefaultColumnNames.PredictedLabel)
: base(env, data, registrationName, Contracts.CheckRef(mapper, nameof(mapper)).Bindable)
{
Host.CheckValue(args, nameof(args));
Expand All @@ -292,7 +292,7 @@ private protected PredictedLabelScorerBase(ScorerArgumentsBase args, IHostEnviro
Host.Check(outputTypeMatches(scoreType), "Unexpected predictor output type");
var predColType = getPredColType(scoreType, rowMapper);

Bindings = BindingsImpl.Create(data.Schema, rowMapper, args.Suffix, scoreColKind, scoreColIndex, predColType);
Bindings = BindingsImpl.Create(data.Schema, rowMapper, args.Suffix, scoreColKind, scoreColIndex, predColType, predictedLabelColumnName);
OutputSchema = Bindings.AsSchema;
}

Expand Down
34 changes: 27 additions & 7 deletions src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using System.IO;
using System.Reflection;
using System.Runtime.CompilerServices;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
Expand Down Expand Up @@ -448,44 +449,60 @@ public sealed class MulticlassPredictionTransformer<TModel> : SingleFeaturePredi
where TModel : class
{
private readonly string _trainLabelColumn;
private readonly string _scoreColumn;
private readonly string _predictedLabelColumn;

// Creates the transformer from an in-memory model.
// labelColumn may be null. scoreColumn and predictedLabel are optional and default to
// AnnotationUtils.Const.ScoreValueKind.Score and DefaultColumnNames.PredictedLabel.
[BestFriend]
internal MulticlassPredictionTransformer(IHostEnvironment env, TModel model, DataViewSchema inputSchema, string featureColumn, string labelColumn)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MulticlassPredictionTransformer<TModel>)), model, inputSchema, featureColumn)
internal MulticlassPredictionTransformer(IHostEnvironment env, TModel model, DataViewSchema inputSchema, string featureColumn, string labelColumn,
string scoreColumn = AnnotationUtils.Const.ScoreValueKind.Score, string predictedLabel = DefaultColumnNames.PredictedLabel) :
base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MulticlassPredictionTransformer<TModel>)), model, inputSchema, featureColumn)
{
Host.CheckValueOrNull(labelColumn);

_trainLabelColumn = labelColumn;
_scoreColumn = scoreColumn;
_predictedLabelColumn = predictedLabel;
// SetScorer reads the column-name fields assigned above, so it is called last.
SetScorer();
}

// Load-time constructor: the base class is constructed from ctx with a host registered
// under this type's name, then InitializationLogic reads this class's own fields from
// ctx directly into the instance fields passed as out arguments.
internal MulticlassPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MulticlassPredictionTransformer<TModel>)), ctx)
{
InitializationLogic(ctx, out _trainLabelColumn);
InitializationLogic(ctx, out _trainLabelColumn, out _scoreColumn, out _predictedLabelColumn);
}

// Load-time constructor for callers that already have the host and the model: both are
// handed to the base class along with ctx, then InitializationLogic reads this class's
// own fields from ctx directly into the instance fields passed as out arguments.
internal MulticlassPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx, IHost host, TModel model)
: base(host, ctx, model)
{

InitializationLogic(ctx, out _trainLabelColumn);
InitializationLogic(ctx, out _trainLabelColumn, out _scoreColumn, out _predictedLabelColumn);
}

// Shared tail of the load-time constructors: reads the column names owned by this class
// from ctx and then builds the scorer.
// The train label column is always read. The score and predicted label column names are
// only present in models written with version 0x00010002 or later; for older models they
// fall back to AnnotationUtils.Const.ScoreValueKind.Score and DefaultColumnNames.PredictedLabel.
private void InitializationLogic(ModelLoadContext ctx, out string trainLabelColumn)
private void InitializationLogic(ModelLoadContext ctx, out string trainLabelColumn, out string scoreColumn, out string predictedLabelColumn)
{
// *** Binary format ***
// <base info>
// id of string: train label column
// id of string: score column (only when written version >= 0x00010002)
// id of string: predicted label column (only when written version >= 0x00010002)

trainLabelColumn = ctx.LoadStringOrNull();
if (ctx.Header.ModelVerWritten >= 0x00010002)
{
scoreColumn = ctx.LoadStringOrNull();
predictedLabelColumn = ctx.LoadStringOrNull();
}
else
{
scoreColumn = AnnotationUtils.Const.ScoreValueKind.Score;
predictedLabelColumn = DefaultColumnNames.PredictedLabel;
}

// The constructors pass the instance fields as the out arguments and SetScorer reads
// those fields, so it runs only after every out parameter has been assigned.
SetScorer();
}

// Builds the Scorer for this transformer.
// A RoleMappedSchema is created from TrainSchema, the train label column and the feature
// column; the bindable mapper is bound to it; and a MulticlassClassificationScorer is
// constructed over an EmptyDataView that carries TrainSchema. The scorer arguments carry
// the score and predicted label column names held by this transformer.
private void SetScorer()
{
var schema = new RoleMappedSchema(TrainSchema, _trainLabelColumn, FeatureColumnName);
var args = new MulticlassClassificationScorer.Arguments();
var args = new MulticlassClassificationScorer.Arguments() { ScoreColumnName = _scoreColumn, PredictedLabelColumnName = _predictedLabelColumn};
Scorer = new MulticlassClassificationScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
}

Expand All @@ -500,13 +517,16 @@ private protected override void SaveCore(ModelSaveContext ctx)
base.SaveCore(ctx);

ctx.SaveStringOrNull(_trainLabelColumn);
ctx.SaveStringOrNull(_scoreColumn);
ctx.SaveStringOrNull(_predictedLabelColumn);
}

private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "MC PRED",
verWrittenCur: 0x00010001, // Initial
//verWrittenCur: 0x00010001, // Initial
verWrittenCur: 0x00010002, // Score and Predicted Label column names.
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: MulticlassPredictionTransformer.LoaderSignature,
Expand Down
Loading