Enable TensorFlowTransform to work with pre-trained models that are not frozen #853

Merged 54 commits on Sep 25, 2018

Commits
1007955
building transform from ground up
abgoswam Sep 6, 2018
40fbedc
dummy transform works after fixing the getters
abgoswam Sep 6, 2018
48d14c6
SavedModel format works for Train, but fails for Save&Predict
abgoswam Sep 6, 2018
35ff43a
remove dummy transform
abgoswam Sep 6, 2018
6291d0d
remove dummy unit test
abgoswam Sep 7, 2018
57508d3
Works with non-frozen models
abgoswam Sep 7, 2018
cfcd70f
building transform from ground up
abgoswam Sep 6, 2018
236de73
dummy transform works after fixing the getters
abgoswam Sep 6, 2018
781cff0
SavedModel format works for Train, but fails for Save&Predict
abgoswam Sep 6, 2018
07b15a0
remove dummy transform
abgoswam Sep 6, 2018
47f75b5
remove dummy unit test
abgoswam Sep 7, 2018
c304257
Merge branch 'abgoswam/tf_savedmodel' of https://github.com/abgoswam/…
abgoswam Sep 7, 2018
d0430b5
merge with master
abgoswam Sep 7, 2018
950a210
fix compilation issues; verify existing tests work fine
abgoswam Sep 7, 2018
97eb497
works locally; need to refactor code
abgoswam Sep 8, 2018
173729f
refactored code; keeping only 1 version of the convenience API
abgoswam Sep 10, 2018
292140b
Merge branch 'master' into abgoswam/tf_savedmodel
abgoswam Sep 10, 2018
655a8aa
added class for directory structure
abgoswam Sep 11, 2018
be5285a
using latest nuget package (0.0.3) for Microsoft.ML.TensorFlow.TestMo…
abgoswam Sep 11, 2018
46c04a3
delete temporary files used when loading/saving models
abgoswam Sep 11, 2018
e705f93
delete local models; the updated nuget version (0.0.3) for Microsoft.…
abgoswam Sep 11, 2018
84214e2
modified logic for load/restore of models
abgoswam Sep 12, 2018
04d02b8
modified logic for load&restore of unfrozen models
abgoswam Sep 13, 2018
89693bd
merge with latest dotnet/master
abgoswam Sep 13, 2018
8c8d92e
model version update to support non-frozen models
abgoswam Sep 13, 2018
d8edc64
based on the code review comments, we now infer if the provided model…
abgoswam Sep 13, 2018
eea524e
simplify the logic in Save() related to loading of SavedModel.
abgoswam Sep 13, 2018
b609ffd
trying Eric's suggestion
abgoswam Sep 13, 2018
74b8899
revert back to previous changes
abgoswam Sep 13, 2018
3382a83
attempt to use stream copy approach instead of in-memory
abgoswam Sep 14, 2018
aa8e844
taking care of some code review comments
abgoswam Sep 14, 2018
25b1e64
deleting some commented out code
abgoswam Sep 14, 2018
e32acca
Ensure we only copy the file segment & cleanup path logic
ericstj Sep 14, 2018
ac45539
added finalizer that closes the session (if it isn't closed) and dele…
abgoswam Sep 14, 2018
ce4efef
move away from using Dictionary<string, byte[]> and instead use strea…
abgoswam Sep 14, 2018
8b8764b
cleanup + misc review comments
abgoswam Sep 15, 2018
f955488
Merge branch 'master' into abgoswam/tf_savedmodel
abgoswam Sep 17, 2018
6e11f2c
trying to create temp dir with proper ACLs for high priviledge users
abgoswam Sep 19, 2018
ed71513
create temp dir with proper ACLs for high-privilege processes
abgoswam Sep 19, 2018
7df343d
Merge branch 'master' into abgoswam/tf_savedmodel
abgoswam Sep 19, 2018
f883d78
fix build after merge with latest master
abgoswam Sep 19, 2018
ae672d6
taking care of review comments related to model versioning of TFTrans…
abgoswam Sep 19, 2018
fac8dae
remove IDisposable from the TensorFlowTransform; renaming some methods
abgoswam Sep 20, 2018
2b1a576
refactor code so we have only 1 constructor for the TensorFlowTransfo…
abgoswam Sep 20, 2018
21879f6
merge with latest master
abgoswam Sep 20, 2018
a1d912d
fix issues with nuget packaging; refactored the code + added comments
abgoswam Sep 21, 2018
f6a1c84
add checks in code to make sure that the input is not a variable leng…
abgoswam Sep 21, 2018
5957c53
merge with latest master
abgoswam Sep 22, 2018
5120bb9
fix typo in name of package
abgoswam Sep 22, 2018
a624a3b
(1) added SavedModel test for MNIST model (2) added try/finally for d…
abgoswam Sep 24, 2018
8d9fdc5
remove and sort usings in file TrainSaveModelAndPredict.cs
abgoswam Sep 24, 2018
685ab99
using spaces in nupkgproj
abgoswam Sep 25, 2018
2d0ec1e
error checking for passed in IHostEnvironment
abgoswam Sep 25, 2018
8d8b986
fix TargetFramework version (netcore 2.0) of DnnAnalyzer to match tha…
abgoswam Sep 25, 2018
2 changes: 2 additions & 0 deletions build/Dependencies.props
@@ -17,5 +17,7 @@
    <MicrosoftCodeAnalysisCSharpVersion>2.9.0</MicrosoftCodeAnalysisCSharpVersion>
    <MicrosoftCSharpVersion>4.5.0</MicrosoftCSharpVersion>
    <SystemCompositionVersion>1.2.0</SystemCompositionVersion>
    <SystemIOFileSystemAccessControl>4.5.0</SystemIOFileSystemAccessControl>
    <SystemSecurityPrincipalWindows>4.5.0</SystemSecurityPrincipalWindows>
  </PropertyGroup>
</Project>
@@ -6,7 +6,10 @@
   </PropertyGroup>

   <ItemGroup>
-    <ProjectReference Include="..\Microsoft.ML.TensorFlow.Redist\Microsoft.ML.TensorFlow.Redist.nupkgproj" />
+    <PackageReference Include="System.IO.FileSystem.AccessControl" Version="$(SystemIOFileSystemAccessControl)" />
+    <PackageReference Include="System.Security.Principal.Windows" Version="$(SystemSecurityPrincipalWindows)" />
+    <ProjectReference Include="../Microsoft.ML/Microsoft.ML.nupkgproj" />
+    <ProjectReference Include="../Microsoft.ML.TensorFlow.Redist/Microsoft.ML.TensorFlow.Redist.nupkgproj" />
   </ItemGroup>

 </Project>
24 changes: 24 additions & 0 deletions src/Microsoft.ML.Core/Utilities/Stream.cs
@@ -30,6 +30,30 @@ public static void CloseEx(this TextWriter writer)
            writer.Close();
        }

        /// <summary>
        /// Similar to Stream.CopyTo but takes a length rather than assuming copy to end. Returns amount copied.
        /// </summary>
        /// <param name="source">Source stream to copy from</param>
        /// <param name="destination">Destination stream to copy to</param>
        /// <param name="length">Number of bytes to copy</param>
        /// <param name="bufferSize">Size of buffer to use when copying, default is 81920 to match that of Stream</param>
        /// <returns>number of bytes copied</returns>
        public static long CopyRange(this Stream source, Stream destination, long length, int bufferSize = 81920)
        {
            // should use ArrayPool once we can take that dependency
            byte[] buffer = new byte[bufferSize];
            int read;
            long remaining = length;
            while (remaining != 0 &&
                (read = source.Read(buffer, 0, (int)Math.Min(buffer.Length, remaining))) != 0)
            {
                destination.Write(buffer, 0, read);
                remaining -= read;
            }

            return length - remaining;
        }

        public static void WriteBoolByte(this BinaryWriter writer, bool x)
        {
            Contracts.AssertValue(writer);
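
As a rough illustration of why a length-bounded copy is useful, the sketch below re-declares the `CopyRange` extension from this diff (so the snippet compiles on its own) and uses it to pull one segment out of a larger stream. The class name `StreamRangeSketch`, the helper `ExtractSegment`, and the in-memory data are hypothetical and not part of the PR.

```csharp
using System;
using System.IO;

public static class StreamRangeSketch
{
    // Same logic as the CopyRange added in this diff, repeated so the sketch is self-contained.
    public static long CopyRange(this Stream source, Stream destination, long length, int bufferSize = 81920)
    {
        byte[] buffer = new byte[bufferSize];
        int read;
        long remaining = length;
        while (remaining != 0 &&
            (read = source.Read(buffer, 0, (int)Math.Min(buffer.Length, remaining))) != 0)
        {
            destination.Write(buffer, 0, read);
            remaining -= read;
        }

        return length - remaining;
    }

    // Hypothetical helper: copies up to `length` bytes starting at `offset` out of `data`.
    public static byte[] ExtractSegment(byte[] data, long offset, long length)
    {
        using (var source = new MemoryStream(data))
        using (var destination = new MemoryStream())
        {
            source.Position = offset;
            source.CopyRange(destination, length);
            return destination.ToArray();
        }
    }
}
```

If the source ends before `length` bytes are read, the return value of `CopyRange` is smaller than the requested length, so a caller can detect a truncated segment.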
@@ -2,7 +2,7 @@

   <PropertyGroup>
     <OutputType>Exe</OutputType>
-    <TargetFramework>netcoreapp2.1</TargetFramework>
+    <TargetFramework>netcoreapp2.0</TargetFramework>
Member

This should be netcoreapp2.1. Note that netcoreapp and netstandard are different things. We want netcoreapp2.1 for any and all executables and tests. We want netstandard2.0 for any libraries.

Member

Moreover, we don't even want this in the lib folder as it is just a commandline tool.

     <AssemblyName>DnnAnalyzer</AssemblyName>
     <IncludeInPackage>Microsoft.ML.TensorFlow</IncludeInPackage>
   </PropertyGroup>
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Legacy/CSharpApi.cs
@@ -15787,9 +15787,9 @@ public sealed partial class TensorFlowScorer : Microsoft.ML.Runtime.EntryPoints.


         /// <summary>
-        /// This is the frozen protobuf model file. Please see https://www.tensorflow.org/mobile/prepare_models for more details.
+        /// TensorFlow model used by the transform. Please see https://www.tensorflow.org/mobile/prepare_models for more details.
         /// </summary>
-        public string ModelFile { get; set; }
+        public string Model { get; set; }

         /// <summary>
         /// The names of the model inputs
5 changes: 5 additions & 0 deletions src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj
@@ -7,6 +7,11 @@
    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
  </PropertyGroup>

  <ItemGroup>
    <PackageReference Include="System.IO.FileSystem.AccessControl" Version="$(SystemIOFileSystemAccessControl)" />
    <PackageReference Include="System.Security.Principal.Windows" Version="$(SystemSecurityPrincipalWindows)" />
  </ItemGroup>

  <ItemGroup>
    <ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
    <ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
2 changes: 1 addition & 1 deletion src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs
@@ -1182,7 +1182,7 @@ public IEnumerable<DeviceAttributes> ListDevices(TFStatus status = null)
         /// here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md
         /// </para>
         /// </remarks>
-        public TFSession FromSavedModel(TFSessionOptions sessionOptions, TFBuffer runOptions, string exportDir, string[] tags, TFGraph graph, TFBuffer metaGraphDef, TFStatus status = null)
+        public static TFSession FromSavedModel(TFSessionOptions sessionOptions, TFBuffer runOptions, string exportDir, string[] tags, TFGraph graph, TFBuffer metaGraphDef, TFStatus status = null)
         {
             if (graph == null)
                 throw new ArgumentNullException(nameof(graph));
156 changes: 152 additions & 4 deletions src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs
@@ -2,15 +2,17 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using System.Security.AccessControl;
using System.Security.Principal;

namespace Microsoft.ML.Transforms.TensorFlow
{
@@ -158,6 +160,152 @@ internal static TFSession LoadTFSession(IExceptionContext ectx, byte[] modelByte
            return new TFSession(graph);
        }

        private static TFSession LoadTFSession(IHostEnvironment env, string exportDirSavedModel)
        {
            Contracts.Check(env != null, nameof(env));
            env.CheckValue(exportDirSavedModel, nameof(exportDirSavedModel));
            var sessionOptions = new TFSessionOptions();
            var tags = new string[] { "serve" };
            var graph = new TFGraph();
            var metaGraphDef = new TFBuffer();

            return TFSession.FromSavedModel(sessionOptions, null, exportDirSavedModel, tags, graph, metaGraphDef);
        }

        // A TensorFlow frozen model is a single file. An un-frozen (SavedModel) on the other hand has a well-defined folder structure.
        // Given a modelPath, this utility method determines if we should treat it as a SavedModel or not
        internal static bool IsSavedModel(IHostEnvironment env, string modelPath)
        {
            Contracts.Check(env != null, nameof(env));
            env.CheckNonWhiteSpace(modelPath, nameof(modelPath));
            FileAttributes attr = File.GetAttributes(modelPath);
            return attr.HasFlag(FileAttributes.Directory);
        }

        // Currently used in TensorFlowTransform to protect temporary folders used when working with TensorFlow's SavedModel format.
        // Models are considered executable code, so we need to ACL the temp folders for high-rights processes (so low-rights processes can't access them).
        /// <summary>
        /// Given a folder path, create it with proper ACL if it doesn't exist.
        /// Fails if the folder name is empty, or can't create the folder.
        /// </summary>
        internal static void CreateFolderWithAclIfNotExists(IHostEnvironment env, string folder)
Contributor

@zeahmed zeahmed Sep 24, 2018

CreateFolderWithAclIfNotExists

How does creation of ACL folder work in Linux/Mac? #Closed

Member

Catches the PlatformNotSupported exception and does the normal create.

Contributor

@zeahmed zeahmed Sep 24, 2018

Thanks Eric, got it.

        {
            Contracts.Check(env != null, nameof(env));
            env.CheckNonWhiteSpace(folder, nameof(folder));

            //if directory exists, do nothing.
            if (Directory.Exists(folder))
                return;

            WindowsIdentity currentIdentity = null;
            try
            {
                currentIdentity = WindowsIdentity.GetCurrent();
            }
            catch (PlatformNotSupportedException)
            { }

            if (currentIdentity != null && new WindowsPrincipal(currentIdentity).IsInRole(WindowsBuiltInRole.Administrator))
            {
                // Create high integrity dir and set no delete policy for all files under the directory.
                // In case of failure, throw exception.
                CreateTempDirectoryWithAcl(folder, currentIdentity.User.ToString());
            }
            else
            {
                try
                {
                    Directory.CreateDirectory(folder);
                }
                catch (Exception exc)
                {
                    throw Contracts.ExceptParam(nameof(folder), $"Failed to create folder for the provided path: {folder}. \nException: {exc.Message}");
                }
            }
        }
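
The Linux/Mac question in the review thread above comes down to one pattern: probe for a Windows identity, and treat `PlatformNotSupportedException` as "not elevated" so the plain `Directory.CreateDirectory` path runs. A minimal sketch of that probe in isolation follows. The class `ElevationProbeSketch` and its methods are hypothetical names, the sketch does not apply any ACL, and it assumes a target where `System.Security.Principal.Windows` is available (this PR adds it as a package reference).

```csharp
using System;
using System.IO;
using System.Security.Principal;

public static class ElevationProbeSketch
{
    // Returns true only when running on Windows as a member of the built-in Administrators role.
    // Where WindowsIdentity is unsupported, the exception is caught and false is returned.
    public static bool IsWindowsAdministrator()
    {
        try
        {
            using (WindowsIdentity identity = WindowsIdentity.GetCurrent())
            {
                return new WindowsPrincipal(identity).IsInRole(WindowsBuiltInRole.Administrator);
            }
        }
        catch (PlatformNotSupportedException)
        {
            return false;
        }
    }

    // Hypothetical helper: ensures the folder exists and reports which branch was taken.
    // The PR applies an SDDL-based ACL in the elevated branch; this sketch only creates the folder.
    public static string EnsureFolder(string folder)
    {
        if (Directory.Exists(folder))
            return "exists";

        bool elevated = IsWindowsAdministrator();
        Directory.CreateDirectory(folder);
        return elevated ? "created-elevated" : "created-plain";
    }
}
```

Keeping the probe in its own method makes the cross-platform fallback easy to see: every non-Windows caller lands in the plain-create branch.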

        internal static void DeleteFolderWithRetries(IHostEnvironment env, string folder)
        {
            Contracts.Check(env != null, nameof(env));
            int currentRetry = 0;
            int maxRetryCount = 10;
            using (var ch = env.Start("Delete folder"))
            {
                for (; ; )
                {
                    try
                    {
                        currentRetry++;
                        Directory.Delete(folder, true);
                        break;
                    }
                    catch (IOException e)
                    {
                        if (currentRetry > maxRetryCount)
                            throw;
                        ch.Info("Error deleting folder. {0}. Retry,", e.Message);
                    }
                }
            }
        }
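
`DeleteFolderWithRetries` retries immediately after each `IOException`. For comparison, here is a small generic sketch of the same bounded-retry idea with a pause between attempts. The pause is an assumption added for illustration and is not what the PR does; `RetrySketch` and `RunWithRetries` are hypothetical names.

```csharp
using System;
using System.IO;
using System.Threading;

public static class RetrySketch
{
    // Runs `action`, retrying on IOException up to `maxAttempts` attempts in total and
    // sleeping `delay` between attempts. Returns the number of attempts that were used.
    // The final failure is not caught, so the last IOException propagates to the caller.
    public static int RunWithRetries(Action action, int maxAttempts, TimeSpan delay)
    {
        if (maxAttempts < 1)
            throw new ArgumentOutOfRangeException(nameof(maxAttempts));

        for (int attempt = 1; ; attempt++)
        {
            try
            {
                action();
                return attempt;
            }
            catch (IOException) when (attempt < maxAttempts)
            {
                Thread.Sleep(delay);
            }
        }
    }
}
```

A folder delete could then be written as `RetrySketch.RunWithRetries(() => Directory.Delete(folder, true), 10, TimeSpan.FromMilliseconds(100));`.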

        private static void CreateTempDirectoryWithAcl(string folder, string identity)
        {
            // Dacl Sddl string:
            // D:        Dacl type
            // D;        Deny access
            // OI;       Object inherit ace
            // SD;       Standard delete function
            // identity  Sid of the given user.
            // A;        Allow access
            // OICI;     Object inherit, container inherit
            // FA        File access
            // BA        Built-in administrators
            // S:        Sacl type
            // ML;;      Mandatory Label
            // NW;;;     No write policy
            // HI        High integrity processes only
            string sddl = "D:(D;OI;SD;;;" + identity + ")(A;OICI;FA;;;BA)S:(ML;OI;NW;;;HI)";

            try
            {
                var dir = Directory.CreateDirectory(folder);
                DirectorySecurity dirSec = new DirectorySecurity();
                dirSec.SetSecurityDescriptorSddlForm(sddl);
                dirSec.SetAccessRuleProtection(true, false); // disable inheritance
                dir.SetAccessControl(dirSec);

                // Cleaning out the directory, in case someone managed to sneak in between creation and setting ACL.
                DirectoryInfo dirInfo = new DirectoryInfo(folder);
                foreach (FileInfo file in dirInfo.GetFiles())
                {
                    file.Delete();
                }
                foreach (DirectoryInfo subDirInfo in dirInfo.GetDirectories())
                {
                    subDirInfo.Delete(true);
                }
            }
            catch (Exception exc)
            {
                throw Contracts.ExceptParam(nameof(folder), $"Failed to create folder for the provided path: {folder}. \nException: {exc.Message}");
            }
        }

        internal static TFSession GetSession(IHostEnvironment env, string modelPath)
        {
Contributor

@Ivanidzo4ka Ivanidzo4ka Sep 25, 2018

Contracts.Check(env, nameof(env); #Resolved

Member Author

added this check.

btw i believe you meant Contracts.Check(env != null, nameof(env));

            Contracts.Check(env != null, nameof(env));
            if (IsSavedModel(env, modelPath))
            {
                env.CheckUserArg(Directory.Exists(modelPath), nameof(modelPath));
                return LoadTFSession(env, modelPath);
            }

            env.CheckUserArg(File.Exists(modelPath), nameof(modelPath));
            var bytes = File.ReadAllBytes(modelPath);
            return LoadTFSession(env, bytes, modelPath);
        }

        internal static unsafe void FetchData<T>(IntPtr data, T[] result)
        {
            var size = result.Length;
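
`GetSession` chooses a loader based on whether the path is a directory (SavedModel) or a single file (frozen model). The sketch below shows that dispatch decision on its own, under the assumption that explicit existence checks are acceptable in place of the `File.GetAttributes` call used in `IsSavedModel`. `ModelPathSketch`, `ModelKind`, and `Classify` are hypothetical names, not part of the PR.

```csharp
using System;
using System.IO;

public static class ModelPathSketch
{
    public enum ModelKind { FrozenFile, SavedModelDirectory }

    // Treats a directory as a SavedModel and a regular file as a frozen model.
    // Throws if the path is empty or nothing exists at that location.
    public static ModelKind Classify(string modelPath)
    {
        if (string.IsNullOrWhiteSpace(modelPath))
            throw new ArgumentException("Model path must be non-empty.", nameof(modelPath));

        if (Directory.Exists(modelPath))
            return ModelKind.SavedModelDirectory;

        if (File.Exists(modelPath))
            return ModelKind.FrozenFile;

        throw new FileNotFoundException("No file or directory at the given path.", modelPath);
    }
}
```

A caller would branch on the result, for example loading bytes with `File.ReadAllBytes` for `FrozenFile` and passing the directory to a SavedModel loader otherwise.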