-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Enable TensorFlowTransform to work with pre-trained models that are not frozen #853
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1007955
40fbedc
48d14c6
35ff43a
6291d0d
57508d3
cfcd70f
236de73
781cff0
07b15a0
47f75b5
c304257
d0430b5
950a210
97eb497
173729f
292140b
655a8aa
be5285a
46c04a3
e705f93
84214e2
04d02b8
89693bd
8c8d92e
d8edc64
eea524e
b609ffd
74b8899
3382a83
aa8e844
25b1e64
e32acca
ac45539
ce4efef
8b8764b
f955488
6e11f2c
ed71513
7df343d
f883d78
ae672d6
fac8dae
2b1a576
21879f6
a1d912d
f6a1c84
5957c53
5120bb9
a624a3b
8d9fdc5
685ab99
2d0ec1e
8d8b986
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,15 +2,17 @@ | |
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints; | ||
using Microsoft.ML.Runtime.Internal.Utilities; | ||
using System; | ||
using System.Collections.Generic; | ||
using System.IO; | ||
using System.Linq; | ||
using System.Runtime.InteropServices; | ||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints; | ||
using Microsoft.ML.Runtime.Internal.Utilities; | ||
using System.Security.AccessControl; | ||
using System.Security.Principal; | ||
|
||
namespace Microsoft.ML.Transforms.TensorFlow | ||
{ | ||
|
@@ -158,6 +160,152 @@ internal static TFSession LoadTFSession(IExceptionContext ectx, byte[] modelByte | |
return new TFSession(graph); | ||
} | ||
|
||
private static TFSession LoadTFSession(IHostEnvironment env, string exportDirSavedModel) | ||
{ | ||
Contracts.Check(env != null, nameof(env)); | ||
env.CheckValue(exportDirSavedModel, nameof(exportDirSavedModel)); | ||
var sessionOptions = new TFSessionOptions(); | ||
var tags = new string[] { "serve" }; | ||
var graph = new TFGraph(); | ||
var metaGraphDef = new TFBuffer(); | ||
|
||
return TFSession.FromSavedModel(sessionOptions, null, exportDirSavedModel, tags, graph, metaGraphDef); | ||
} | ||
|
||
// A TensorFlow frozen model is a single file. An un-frozen (SavedModel) on the other hand has a well-defined folder structure. | ||
// Given a modelPath, this utility method determines if we should treat it as a SavedModel or not | ||
internal static bool IsSavedModel(IHostEnvironment env, string modelPath) | ||
{ | ||
Contracts.Check(env != null, nameof(env)); | ||
env.CheckNonWhiteSpace(modelPath, nameof(modelPath)); | ||
FileAttributes attr = File.GetAttributes(modelPath); | ||
return attr.HasFlag(FileAttributes.Directory); | ||
} | ||
|
||
// Currently used in TensorFlowTransform to protect temporary folders used when working with TensorFlow's SavedModel format. | ||
// Models are considered executable code, so we need to ACL tthe temp folders for high-rights process (so low-rights process can’t access it). | ||
/// <summary> | ||
/// Given a folder path, create it with proper ACL if it doesn't exist. | ||
/// Fails if the folder name is empty, or can't create the folder. | ||
/// </summary> | ||
internal static void CreateFolderWithAclIfNotExists(IHostEnvironment env, string folder) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
How does creation of ACL folder work in Linux/Mac? #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Catches the PlatformNotSupported exception and does the normal create. In reply to: 219903672 [](ancestors = 219903672) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
{ | ||
Contracts.Check(env != null, nameof(env)); | ||
env.CheckNonWhiteSpace(folder, nameof(folder)); | ||
|
||
//if directory exists, do nothing. | ||
if (Directory.Exists(folder)) | ||
return; | ||
|
||
WindowsIdentity currentIdentity = null; | ||
try | ||
{ | ||
currentIdentity = WindowsIdentity.GetCurrent(); | ||
} | ||
catch (PlatformNotSupportedException) | ||
{ } | ||
|
||
if (currentIdentity != null && new WindowsPrincipal(currentIdentity).IsInRole(WindowsBuiltInRole.Administrator)) | ||
{ | ||
// Create high integrity dir and set no delete policy for all files under the directory. | ||
// In case of failure, throw exception. | ||
CreateTempDirectoryWithAcl(folder, currentIdentity.User.ToString()); | ||
} | ||
else | ||
{ | ||
try | ||
{ | ||
Directory.CreateDirectory(folder); | ||
} | ||
catch (Exception exc) | ||
{ | ||
throw Contracts.ExceptParam(nameof(folder), $"Failed to create folder for the provided path: {folder}. \nException: {exc.Message}"); | ||
} | ||
} | ||
} | ||
|
||
internal static void DeleteFolderWithRetries(IHostEnvironment env, string folder) | ||
{ | ||
Contracts.Check(env != null, nameof(env)); | ||
int currentRetry = 0; | ||
int maxRetryCount = 10; | ||
using (var ch = env.Start("Delete folder")) | ||
{ | ||
for (; ; ) | ||
{ | ||
try | ||
{ | ||
currentRetry++; | ||
Directory.Delete(folder, true); | ||
break; | ||
} | ||
catch (IOException e) | ||
{ | ||
if (currentRetry > maxRetryCount) | ||
throw; | ||
ch.Info("Error deleting folder. {0}. Retry,", e.Message); | ||
} | ||
} | ||
} | ||
} | ||
|
||
private static void CreateTempDirectoryWithAcl(string folder, string identity) | ||
{ | ||
// Dacl Sddl string: | ||
// D: Dacl type | ||
// D; Deny access | ||
// OI; Object inherit ace | ||
// SD; Standard delete function | ||
// wIdentity.User Sid of the given user. | ||
// A; Allow access | ||
// OICI; Object inherit, container inherit | ||
// FA File access | ||
// BA Built-in administrators | ||
// S: Sacl type | ||
// ML;; Mandatory Label | ||
// NW;;; No write policy | ||
// HI High integrity processes only | ||
string sddl = "D:(D;OI;SD;;;" + identity + ")(A;OICI;FA;;;BA)S:(ML;OI;NW;;;HI)"; | ||
|
||
try | ||
{ | ||
var dir = Directory.CreateDirectory(folder); | ||
DirectorySecurity dirSec = new DirectorySecurity(); | ||
dirSec.SetSecurityDescriptorSddlForm(sddl); | ||
dirSec.SetAccessRuleProtection(true, false); // disable inheritance | ||
dir.SetAccessControl(dirSec); | ||
|
||
// Cleaning out the directory, in case someone managed to sneak in between creation and setting ACL. | ||
DirectoryInfo dirInfo = new DirectoryInfo(folder); | ||
foreach (FileInfo file in dirInfo.GetFiles()) | ||
{ | ||
file.Delete(); | ||
} | ||
foreach (DirectoryInfo subDirInfo in dirInfo.GetDirectories()) | ||
{ | ||
subDirInfo.Delete(true); | ||
} | ||
} | ||
catch (Exception exc) | ||
{ | ||
throw Contracts.ExceptParam(nameof(folder), $"Failed to create folder for the provided path: {folder}. \nException: {exc.Message}"); | ||
} | ||
} | ||
|
||
internal static TFSession GetSession(IHostEnvironment env, string modelPath) | ||
{ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Contracts.Check(env, nameof(env); #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added this check. btw i believe you meant Contracts.Check(env != null, nameof(env)); In reply to: 220031028 [](ancestors = 220031028) |
||
Contracts.Check(env != null, nameof(env)); | ||
if (IsSavedModel(env, modelPath)) | ||
{ | ||
env.CheckUserArg(Directory.Exists(modelPath), nameof(modelPath)); | ||
return LoadTFSession(env, modelPath); | ||
} | ||
|
||
env.CheckUserArg(File.Exists(modelPath), nameof(modelPath)); | ||
var bytes = File.ReadAllBytes(modelPath); | ||
return LoadTFSession(env, bytes, modelPath); | ||
} | ||
|
||
internal static unsafe void FetchData<T>(IntPtr data, T[] result) | ||
{ | ||
var size = result.Length; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be
netcoreapp2.1
. Note thatnetcoreapp
andnetstandard
are different things. We wantnetcoreapp2.1
for any and all executables and tests. We wantnetstandard2.0
for any libraries.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moreover, we don't even want this in the lib folder as it is just a commandline tool.