Skip to content

Add support for Mobilenet v2 in Image Classification transfer learning #4351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public static void Example()
var pipeline = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, false, "ImagePath") // false indicates we want the image as a VBuffer<byte>
.Append(mlContext.Model.ImageClassification(
"Image", "Label",
// Just by changing/selecting InceptionV3 here instead of
// Just by changing/selecting InceptionV3/MobilenetV2 here instead of
// ResnetV2101 you can try a different architecture/pre-trained
// model.
arch: ImageClassificationEstimator.Architecture.ResnetV2101,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public static void Example()

var pipeline = mlContext.Model.ImageClassification(
"Image", "Label",
// Just by changing/selecting InceptionV3 here instead of
// Just by changing/selecting InceptionV3/MobilenetV2 here instead of
// ResnetV2101 you can try a different architecture/
// pre-trained model.
arch: ImageClassificationEstimator.Architecture.ResnetV2101,
Expand Down
10 changes: 9 additions & 1 deletion src/Microsoft.ML.Dnn/DnnCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ public static ImageClassificationEstimator ImageClassification(
{
var options = new ImageClassificationEstimator.Options()
{
ModelLocation = arch == Architecture.ResnetV2101 ? @"resnet_v2_101_299.meta" : @"InceptionV3.meta",
ModelLocation = ModelLocation[arch],
InputColumns = new[] { featuresColumnName },
OutputColumns = new[] { scoreColumnName, predictedLabelColumnName },
LabelColumn = labelColumnName,
Expand Down Expand Up @@ -194,6 +194,14 @@ public static ImageClassificationEstimator ImageClassification(
client.DownloadFile(new Uri($"{baseGitPath}"), @"resnet_v2_101_299.meta");
}
}
else if(options.Arch == Architecture.MobilenetV2)
{
var baseGitPath = @"https://tlcresources.blob.core.windows.net/image/MobileNetV2TensorFlow/mobilenet_v2.meta";
using (WebClient client = new WebClient())
{
client.DownloadFile(new Uri($"{baseGitPath}"), @"mobilenet_v2.meta");
}
}
}

var env = CatalogUtils.GetEnvironment(catalog);
Expand Down
34 changes: 31 additions & 3 deletions src/Microsoft.ML.Dnn/ImageClassificationTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,8 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,

tf.train.Saver().restore(evalSess, _checkpointPath);
(evaluationStep, prediction) = AddEvaluationStep(finalTensor, groundTruthInput);
(_jpegData, _resizedImage) = AddJpegDecoding(299, 299, 3);
var imageSize = ImageClassificationEstimator.ImagePreprocessingSize[options.Arch];
(_jpegData, _resizedImage) = AddJpegDecoding(imageSize.Item1, imageSize.Item2, 3);
return (evalSess, _labelTensor, evaluationStep, prediction);
}

Expand Down Expand Up @@ -827,12 +828,18 @@ internal ImageClassificationTransformer(IHostEnvironment env, Session session, s
_bottleneckOperationName = "module_apply_default/hub_output/feature_vector/SpatialSqueeze";
_inputTensorName = "Placeholder";
}
else if(arch == ImageClassificationEstimator.Architecture.MobilenetV2)
{
_bottleneckOperationName = "import/MobilenetV2/Logits/Squeeze";
_inputTensorName = "import/input";
}

_outputs = new[] { scoreColumnName, predictedLabelColumnName };

if (loadModel == false)
{
(_jpegData, _resizedImage) = AddJpegDecoding(299, 299, 3);
var imageSize = ImageClassificationEstimator.ImagePreprocessingSize[arch];
(_jpegData, _resizedImage) = AddJpegDecoding(imageSize.Item1, imageSize.Item2, 3);
_jpegDataTensorName = _jpegData.name;
_resizedImageTensorName = _resizedImage.name;

Expand Down Expand Up @@ -1080,7 +1087,28 @@ public sealed class ImageClassificationEstimator : IEstimator<ImageClassificatio
public enum Architecture
{
ResnetV2101,
InceptionV3
InceptionV3,
MobilenetV2
};

/// <summary>
/// Dictionary mapping model architecture to model location.
/// </summary>
internal static IReadOnlyDictionary<Architecture, string> ModelLocation = new Dictionary<Architecture, string>
{
{ Architecture.ResnetV2101, @"resnet_v2_101_299.meta" },
{ Architecture.InceptionV3, @"InceptionV3.meta" },
{ Architecture.MobilenetV2, @"mobilenet_v2.meta" }
};

/// <summary>
/// Dictionary mapping model architecture to image input size supported.
/// </summary>
internal static IReadOnlyDictionary<Architecture, Tuple<int,int>> ImagePreprocessingSize = new Dictionary<Architecture, Tuple<int,int>>
{
{ Architecture.ResnetV2101, new Tuple<int, int>(299,299) },
{ Architecture.InceptionV3, new Tuple<int, int>(299,299) },
{ Architecture.MobilenetV2, new Tuple<int, int>(224,224) }
};

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
namespace Microsoft.ML.TestFramework.Attributes
{
/// <summary>
/// A theory for tests requiring TensorFlow.
/// </summary>
public sealed class TensorFlowTheoryAttribute : EnvironmentSpecificTheoryAttribute
{
public TensorFlowTheoryAttribute() : base("TensorFlow is 64-bit only")
{
}

/// <inheritdoc />
protected override bool IsEnvironmentSupported()
{
return Environment.Is64BitProcess;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1207,8 +1207,10 @@ public void TensorFlowStringTest()
Assert.Equal(string.Join(" ", input.B).Replace("/", " "), textOutput.BOut[0]);
}

[TensorFlowFact]
public void TensorFlowImageClassification()
[TensorFlowTheory]
[InlineData(ImageClassificationEstimator.Architecture.ResnetV2101)]
[InlineData(ImageClassificationEstimator.Architecture.MobilenetV2)]
public void TensorFlowImageClassification(ImageClassificationEstimator.Architecture arch)
{
string assetsRelativePath = @"assets";
string assetsPath = GetAbsolutePath(assetsRelativePath);
Expand Down Expand Up @@ -1249,10 +1251,10 @@ public void TensorFlowImageClassification()
var pipeline = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, false, "ImagePath") // false indicates we want the image as a VBuffer<byte>
.Append(mlContext.Model.ImageClassification(
"Image", "Label",
// Just by changing/selecting InceptionV3 here instead of
// Just by changing/selecting InceptionV3/MobilenetV2 here instead of
Copy link
Member

@eerhardt eerhardt Oct 18, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want to add a test for the new functionality? The CI isn't executing any of your new code. #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added unit test.


In reply to: 336686355 [](ancestors = 336686355)

// ResnetV2101 you can try a different architecture/pre-trained
// model.
arch: ImageClassificationEstimator.Architecture.ResnetV2101,
arch: arch,
epoch: 50,
batchSize: 10,
learningRate: 0.01f,
Expand Down Expand Up @@ -1384,7 +1386,7 @@ public void TensorFlowImageClassificationEarlyStoppingIncreasing()
var pipeline = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, false, "ImagePath") // false indicates we want the image as a VBuffer<byte>
.Append(mlContext.Model.ImageClassification(
"Image", "Label",
// Just by changing/selecting InceptionV3 here instead of
// Just by changing/selecting InceptionV3/MobilenetV2 here instead of
// ResnetV2101 you can try a different architecture/pre-trained
// model.
arch: ImageClassificationEstimator.Architecture.ResnetV2101,
Expand Down Expand Up @@ -1473,7 +1475,7 @@ public void TensorFlowImageClassificationEarlyStoppingDecreasing()
var pipeline = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, false, "ImagePath") // false indicates we want the image as a VBuffer<byte>
.Append(mlContext.Model.ImageClassification(
"Image", "Label",
// Just by changing/selecting InceptionV3 here instead of
// Just by changing/selecting InceptionV3/MobilenetV2 here instead of
// ResnetV2101 you can try a different architecture/pre-trained
// model.
arch: ImageClassificationEstimator.Architecture.ResnetV2101,
Expand Down